Kandinsky 2.1 MoVQ
@inproceedings{razzhigaev2023kandinsky,
title={Kandinsky: an Improved Text-to-Image Synthesis with Image Prior and Latent Diffusion},
author={Razzhigaev, Anton and Shakhmatov, Arseniy and Maltseva, Anastasia and Arkhipkin, Vladimir and Pavlov, Igor and Ryabov, Ilya and Kuts, Angelina and Panchenko, Alexander and Kuznetsov, Andrey and Dimitrov, Denis},
booktitle={EMNLP},
year={2023}
}
import io, zlib
import numpy as np
import torch
import datasets
import matplotlib.pyplot as plt
from torchvision.transforms.v2.functional import pil_to_tensor
from compressors.diffusers._codec import (
load_codec,
encode_to_quant_and_indices,
decode_from_quant,
to_model_input,
from_model_output,
)
Load the codec
# Run on the first GPU when CUDA is present; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Working dtype handed to load_codec.
torch_dtype = torch.float32

# load_codec returns the model together with a metadata record describing it.
vae, info = load_codec(
    "kandinsky-community/kandinsky-2-1",
    device=device,
    torch_dtype=torch_dtype,
)

# Dump the codec metadata, one field per line.
print(
    f"class = {info.class_name}",
    f"stride = {info.stride}",
    f"num_vq_embeddings = {info.num_vq_embeddings}",
    f"latent_channels = {info.latent_channels}",
    f"vq_embed_dim = {info.vq_embed_dim}",
    f"bytes_per_index = {info.bytes_per_index}",
    f"input_range = {info.input_range!r}",
    sep="\n",
)
Skipping import of cpp extensions due to incompatible torch version 2.11.0+cu130 for torchao version 0.15.0 Please see https://github.com/pytorch/ao/issues/2919 for more info
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with:
```
pip install accelerate
```
.
/home/dgj335/g/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:205: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.
warnings.warn(
class = VQModel
stride = 8
num_vq_embeddings = 16384
latent_channels = 4
vq_embed_dim = 4
bytes_per_index = 2
input_range = 'neg11'
# The VQ codebook is the weight matrix of the quantizer's embedding table.
# Its mean absolute value is printed as a quick check that real weights
# were loaded (the message states it should be well above 0).
codebook = vae.quantize.embedding.weight
print(
    f"codebook shape = {tuple(codebook.shape)}",
    f"|codebook|.mean = {codebook.abs().mean().item():.4f} (should be well above 0)",
    sep="\n",
)
codebook shape = (16384, 4)
|codebook|.mean = 0.1982 (should be well above 0)
Load a Kodak image
# Use image index 1 of the Kodak validation split as the test input.
ds = datasets.load_dataset("danjacobellis/kodak", split="validation")
img_pil = ds[1]["image"].convert("RGB")

# PIL reports size as (width, height).
W, H = img_pil.size
n_pixels = W * H

# uint8 image tensor -> batch of one on the target device/dtype, scaled to [0, 1].
x_01 = pil_to_tensor(img_pil).unsqueeze(0).to(device=device, dtype=torch_dtype) / 255.0
print(f"image size = {W}×{H} ({n_pixels} px)")
print(f"x_01.shape = {tuple(x_01.shape)} range = [{x_01.min().item():.3f}, {x_01.max().item():.3f}]")

# Show the input (channels-last for matplotlib).
fig, ax = plt.subplots(figsize=(7, 5))
ax.imshow(x_01[0].permute(1, 2, 0).cpu().numpy())
ax.set_title("Original image")
ax.axis("off")
plt.show()
image size = 768×512 (393216 px)
x_01.shape = (1, 3, 512, 768) range = [0.000, 1.000]
Encode and quantize
# Map [0, 1] pixels into the codec's expected input range (info.input_range).
x_in = to_model_input(x_01, info.input_range).to(torch_dtype)

# Encode + vector-quantize without autograd bookkeeping.
with torch.inference_mode():
    quant, indices, latent_hw = encode_to_quant_and_indices(vae, x_in)

H_l, W_l = latent_hw
print(
    f"quant.shape = {tuple(quant.shape)}",
    f"indices.shape = {tuple(indices.shape)} dtype = {indices.dtype}",
    f"latent grid = {H_l}×{W_l} (= {H}//{info.stride} × {W}//{info.stride})",
    f"index range = [{int(indices.min())}, {int(indices.max())}] of [0, {info.num_vq_embeddings})",
    sep="\n",
)
quant.shape = (1, 4, 64, 96)
indices.shape = (6144,) dtype = torch.int64
latent grid = 64×96 (= 512//8 × 768//8)
index range = [0, 16383] of [0, 16384)
Index map
# Lay the flat index sequence back out on the (H_l, W_l) latent grid.
idx_map = indices.view(H_l, W_l).cpu().numpy()

fig, ax = plt.subplots(figsize=(7, 5))
mappable = ax.imshow(idx_map, cmap="viridis")
fig.colorbar(mappable, ax=ax, fraction=0.025)
ax.set_title(f"Index map ({H_l}×{W_l}) values ∈ [0, {info.num_vq_embeddings})")
ax.axis("off")
plt.show()
Pack indices to bytes
# Serialize the indices as a fixed-width unsigned-integer byte stream.
# Width: the smallest unsigned type that can hold every valid index
# (uint16 for a 16384-entry codebook) instead of a hard-coded uint16.
# Byte order: pinned to little-endian so the stream does not depend on the
# machine that wrote it.
index_dtype = np.min_scalar_type(info.num_vq_embeddings - 1).newbyteorder("<")

indices_np = indices.cpu().numpy()
# astype() silently wraps out-of-range values, so validate before narrowing.
if indices_np.size and (indices_np.min() < 0 or indices_np.max() >= info.num_vq_embeddings):
    raise ValueError(
        f"indices must lie in [0, {info.num_vq_embeddings}), "
        f"got [{indices_np.min()}, {indices_np.max()}]"
    )
indices_np = indices_np.astype(index_dtype)

buff = io.BytesIO(indices_np.tobytes())
n_bytes = len(buff.getbuffer())
bpp_raw = 8 * n_bytes / n_pixels
print(f"raw {index_dtype} bytes = {n_bytes}")
print(f"bpp (raw) = {bpp_raw:.4f}")
raw uint16 bytes = 12288
bpp (raw) = 0.2500
# Lossless baseline: run zlib over the same packed byte stream.
z_bytes = len(zlib.compress(buff.getvalue()))
bpp_zlib = 8 * z_bytes / n_pixels

# log2(K) bits per index is the entropy of a uniform K-ary source, i.e. the
# rate of ideal fixed-length coding. Converting it to bits per pixel puts it
# in the same unit as the raw and zlib figures, so the three are comparable.
entropy_ceiling = np.log2(info.num_vq_embeddings)
bpp_ceiling = entropy_ceiling * indices_np.size / n_pixels

print(f"raw {indices_np.dtype} : {n_bytes:6d} B bpp = {bpp_raw:.4f}")
print(f"zlib : {z_bytes:6d} B bpp = {bpp_zlib:.4f}")
print(f"entropy ceiling per index = log2({info.num_vq_embeddings}) = {entropy_ceiling:.2f} bits")
print(f"entropy ceiling in bpp = {entropy_ceiling:.2f} × {indices_np.size} / {n_pixels} = {bpp_ceiling:.4f}")
raw uint16 : 12288 B bpp = 0.2500
zlib : 11667 B bpp = 0.2374
entropy ceiling per index = log2(16384) = 14.00 bits
Decode from the buffer
# 1. Parse the fixed-width indices from the byte stream. The element type is
#    derived from the codec metadata (smallest unsigned type covering the
#    codebook, little-endian) rather than hard-coded as uint16.
index_dtype = np.min_scalar_type(info.num_vq_embeddings - 1).newbyteorder("<")
indices_recovered = np.frombuffer(buff.getvalue(), dtype=index_dtype)
# frombuffer returns a read-only view; astype makes a writable int64 copy for torch.
indices_long = torch.from_numpy(indices_recovered.astype(np.int64)).to(device)

# Run the lookup and the decoder under inference_mode so no autograd graph is
# recorded through the embedding weights (should they require grad).
with torch.inference_mode():
    # 2. Codebook lookup: each index selects one row of the embedding table.
    z_q = vae.quantize.embedding(indices_long)  # (N, vq_embed_dim)
    # The reshape assumes indices were flattened row-major over (H_l, W_l);
    # -1 lets the batch dimension follow from N.
    quant_recovered = z_q.reshape(-1, H_l, W_l, info.vq_embed_dim)  # (B, H_l, W_l, C)
    quant_recovered = quant_recovered.permute(0, 3, 1, 2).contiguous().to(torch_dtype)

    # 3. Decode, then map the model's output range back to [0, 1].
    recon = decode_from_quant(vae, quant_recovered)
    recon_01 = from_model_output(recon.float(), info.input_range).clamp(0, 1)

# Sanity check: recovering quant from indices is lossless modulo float rounding
max_diff = (quant_recovered - quant).abs().max().item()
print(f"max |quant_recovered - quant| = {max_diff:.2e} (should be ~1e-8)")
max |quant_recovered - quant| = 2.98e-08 (should be ~1e-8)
Reconstruction quality
# Both images live in [0, 1], so the peak value is 1 and PSNR = -10·log10(MSE).
# Squared error is symmetric, so the argument order of mse_loss is immaterial.
mse = torch.nn.functional.mse_loss(recon_01, x_01).item()
psnr_db = -10 * np.log10(mse)
print(
    f"PSNR = {psnr_db:.2f} dB bpp (raw uint16) = {bpp_raw:.4f} bpp (zlib) = {bpp_zlib:.4f}"
)
PSNR = 29.49 dB bpp (raw uint16) = 0.2500 bpp (zlib) = 0.2374
# Side-by-side view: original on the left, reconstruction on the right.
panels = (
    (x_01, "Original"),
    (recon_01, f"Reconstruction ({psnr_db:.2f} dB, {bpp_raw:.3f} bpp)"),
)
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
for panel_ax, (panel_img, panel_title) in zip(axes, panels):
    # (1, C, H, W) -> (H, W, C) numpy array for imshow.
    panel_ax.imshow(panel_img[0].permute(1, 2, 0).cpu().numpy())
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")
plt.tight_layout()
plt.show()