Kandinsky 2.1 MoVQ

@inproceedings{razzhigaev2023kandinsky,
  title={Kandinsky: an Improved Text-to-Image Synthesis with Image Prior and Latent Diffusion},
  author={Razzhigaev, Anton and Shakhmatov, Arseniy and Maltseva, Anastasia and Arkhipkin, Vladimir and Pavlov, Igor and Ryabov, Ilya and Kuts, Angelina and Panchenko, Alexander and Kuznetsov, Andrey and Dimitrov, Denis},
  booktitle={EMNLP},
  year={2023}
}

arXiv

Pre-trained model weights: `kandinsky-community/kandinsky-2-1`

import io, zlib
import numpy as np
import torch
import datasets
import matplotlib.pyplot as plt
from torchvision.transforms.v2.functional import pil_to_tensor

from compressors.diffusers._codec import (
    load_codec,
    encode_to_quant_and_indices,
    decode_from_quant,
    to_model_input,
    from_model_output,
)

Load the codec

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
torch_dtype = torch.float32

# load_codec hands back the model plus a metadata record (`info`) whose
# fields are reported below.
vae, info = load_codec(
    "kandinsky-community/kandinsky-2-1",
    device=device,
    torch_dtype=torch_dtype,
)

# Emit the metadata as one aligned block.
codec_summary = [
    f"class             = {info.class_name}",
    f"stride            = {info.stride}",
    f"num_vq_embeddings = {info.num_vq_embeddings}",
    f"latent_channels   = {info.latent_channels}",
    f"vq_embed_dim      = {info.vq_embed_dim}",
    f"bytes_per_index   = {info.bytes_per_index}",
    f"input_range       = {info.input_range!r}",
]
print("\n".join(codec_summary))
Skipping import of cpp extensions due to incompatible torch version 2.11.0+cu130 for torchao version 0.15.0             Please see https://github.com/pytorch/ao/issues/2919 for more info
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
/home/dgj335/g/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:205: UserWarning: The `local_dir_use_symlinks` argument is deprecated and ignored in `hf_hub_download`. Downloading to a local directory does not use symlinks anymore.
  warnings.warn(
class             = VQModel
stride            = 8
num_vq_embeddings = 16384
latent_channels   = 4
vq_embed_dim      = 4
bytes_per_index   = 2
input_range       = 'neg11'
# The codebook is the weight of the quantizer's embedding layer.
codebook = vae.quantize.embedding.weight
print(f"codebook shape = {tuple(codebook.shape)}")

# Mean absolute entry: a quick check that the weights are not all (near) zero.
codebook_mean_abs = codebook.abs().mean().item()
print(f"|codebook|.mean = {codebook_mean_abs:.4f}  (should be well above 0)")
codebook shape = (16384, 4)
|codebook|.mean = 0.1982  (should be well above 0)

Load a Kodak image

# Pick the image at index 1 of the Kodak validation split and force RGB.
ds = datasets.load_dataset("danjacobellis/kodak", split="validation")
img_pil = ds[1]["image"].convert("RGB")
W, H = img_pil.size  # PIL reports size as (width, height)
n_pixels = H * W

# Cast to the working dtype, move to the device, add a batch axis and scale
# the 0..255 pixel values down to [0, 1].
pixels = pil_to_tensor(img_pil)
x_01 = pixels.to(torch_dtype).to(device).unsqueeze(0) / 255.0
print(f"image size = {W}×{H}  ({n_pixels} px)")
x_lo, x_hi = x_01.min().item(), x_01.max().item()
print(f"x_01.shape = {tuple(x_01.shape)}  range = [{x_lo:.3f}, {x_hi:.3f}]")

# Channels-first -> channels-last for matplotlib.
fig, ax = plt.subplots(figsize=(7, 5))
ax.imshow(x_01[0].permute(1, 2, 0).cpu().numpy())
ax.axis("off")
ax.set_title("Original image")
plt.show()
image size = 768×512  (393216 px)
x_01.shape = (1, 3, 512, 768)  range = [0.000, 1.000]
_images/2f3bda0a8b433241eaa34367480409489987af991644ff9832cc73aedf016824.webp

Encode and quantize

# Convert the [0, 1] image with to_model_input, using the input range the
# codec metadata declares, and keep it in the working dtype.
x_in = to_model_input(x_01, info.input_range).to(torch_dtype)

# Encoding is pure inference, so run it without autograd bookkeeping.
with torch.inference_mode():
    quant, indices, latent_hw = encode_to_quant_and_indices(vae, x_in)

H_l, W_l = latent_hw
stride = info.stride
print(f"quant.shape   = {tuple(quant.shape)}")
print(f"indices.shape = {tuple(indices.shape)}  dtype = {indices.dtype}")
print(f"latent grid   = {H_l}×{W_l}  (= {H}//{stride} × {W}//{stride})")
idx_lo, idx_hi = int(indices.min()), int(indices.max())
print(f"index range   = [{idx_lo}, {idx_hi}]  of [0, {info.num_vq_embeddings})")
quant.shape   = (1, 4, 64, 96)
indices.shape = (6144,)  dtype = torch.int64
latent grid   = 64×96  (= 512//8 × 768//8)
index range   = [0, 16383]  of [0, 16384)

Index map

# Lay the flat index vector back out on the latent grid for display.
idx_map = indices.view(H_l, W_l).cpu().numpy()

fig, ax = plt.subplots(figsize=(7, 5))
index_image = ax.imshow(idx_map, cmap="viridis")
# Attach the colorbar to the image axes, taking a thin slice of its width.
fig.colorbar(index_image, ax=ax, fraction=0.025)
ax.axis("off")
ax.set_title(f"Index map ({H_l}×{W_l})  values ∈ [0, {info.num_vq_embeddings})")
plt.show()
_images/b0f900e87ee05f3b638cfccd616e010ec864e00d01b0d0c8c909e1307e2cf28a.webp

Pack indices to bytes

# Serialize the index tensor as fixed-width unsigned 16-bit integers.
#
# ndarray.astype(np.uint16) wraps out-of-range values silently, so a codec
# with more than 65536 codebook entries would produce a corrupt stream without
# any error. Refuse to pack indices that uint16 cannot represent.
uint16_max = int(np.iinfo(np.uint16).max)
if indices.numel() and not (0 <= int(indices.min()) and int(indices.max()) <= uint16_max):
    raise ValueError(
        f"indices span [{int(indices.min())}, {int(indices.max())}], "
        f"which does not fit in uint16 [0, {uint16_max}]"
    )

indices_np = indices.cpu().numpy().astype(np.uint16)
buff = io.BytesIO(indices_np.tobytes())
n_bytes = len(buff.getbuffer())
# Bits per pixel of the uncompressed stream: 8 bits per stored byte.
bpp_raw = 8 * n_bytes / n_pixels
print(f"raw uint16 bytes = {n_bytes}")
print(f"bpp (raw)        = {bpp_raw:.4f}")
raw uint16 bytes = 12288
bpp (raw)        = 0.2500
# Deflate the raw index bytes as a general-purpose compression baseline.
compressed = zlib.compress(buff.getvalue())
z_bytes = len(compressed)
bpp_zlib = 8 * z_bytes / n_pixels

# A fixed-length code over K codebook entries needs log2(K) bits per index;
# note this is fewer than the 16 bits each index occupies in the raw stream
# whenever K < 65536.
entropy_ceiling = np.log2(info.num_vq_embeddings)

size_report = (
    f"raw uint16 :  {n_bytes:6d} B   bpp = {bpp_raw:.4f}",
    f"zlib       :  {z_bytes:6d} B   bpp = {bpp_zlib:.4f}",
    f"entropy ceiling per index = log2({info.num_vq_embeddings}) = {entropy_ceiling:.2f} bits",
)
for report_line in size_report:
    print(report_line)
raw uint16 :   12288 B   bpp = 0.2500
zlib       :   11667 B   bpp = 0.2374
entropy ceiling per index = log2(16384) = 14.00 bits

Decode from the buffer

# 1. parse uint16 from buffer
# np.frombuffer over a bytes object yields a read-only view. astype(np.int64)
# produces a fresh writable array and widens to the integer type the embedding
# lookup needs, in a dtype that torch.from_numpy accepts on every release
# (older PyTorch releases reject uint16 arrays).
indices_recovered = np.frombuffer(buff.getvalue(), dtype=np.uint16)
indices_long = torch.from_numpy(indices_recovered.astype(np.int64)).to(device)

# Both the codebook lookup and the decoder are pure inference, so run them
# together without autograd tracking (matching how the encode step is run).
with torch.inference_mode():
    # 2. codebook lookup
    z_q = vae.quantize.embedding(indices_long)                          # (N, vq_embed_dim)
    quant_recovered = z_q.reshape(1, H_l, W_l, info.vq_embed_dim)        # (B, H_l, W_l, C)
    # Channels-last -> channels-first, the same layout as `quant`.
    quant_recovered = quant_recovered.permute(0, 3, 1, 2).contiguous().to(torch_dtype)

    # 3. decode
    recon = decode_from_quant(vae, quant_recovered)

# Convert the decoder output with from_model_output and clip it to [0, 1].
recon_01 = from_model_output(recon.float(), info.input_range).clamp(0, 1)

# Sanity check: recovering quant from indices is lossless modulo float rounding
max_diff = (quant_recovered - quant).abs().max().item()
print(f"max |quant_recovered - quant| = {max_diff:.2e}  (should be ~1e-8)")
max |quant_recovered - quant| = 2.98e-08  (should be ~1e-8)

Reconstruction quality

# Mean squared error between the original and the reconstruction; x_01 is in
# [0, 1] and recon_01 was clamped to [0, 1], so both share the same scale.
mse = torch.nn.functional.mse_loss(x_01, recon_01).item()
# PSNR in dB. With a peak signal value of 1.0, 10*log10(peak**2 / mse)
# reduces to -10*log10(mse).
psnr_db = -10 * np.log10(mse)
print(f"PSNR = {psnr_db:.2f} dB    bpp (raw uint16) = {bpp_raw:.4f}    bpp (zlib) = {bpp_zlib:.4f}")
PSNR = 29.49 dB    bpp (raw uint16) = 0.2500    bpp (zlib) = 0.2374
# Original and reconstruction side by side, one panel each.
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
panels = (
    ("Original", x_01),
    (f"Reconstruction  ({psnr_db:.2f} dB, {bpp_raw:.3f} bpp)", recon_01),
)
for ax, (panel_title, batch) in zip(axes, panels):
    # First batch item, channels-first -> channels-last for matplotlib.
    ax.imshow(batch[0].permute(1, 2, 0).cpu().numpy())
    ax.set_title(panel_title)
    ax.axis("off")
plt.tight_layout()
plt.show()
_images/7619afa532f96216d77304b2a6d9660a16e9a1ce956e1c05a0a0b84a51f6b1dc.webp