DGML

@inproceedings{cheng2020learned,
  title={Learned image compression with discretized gaussian mixture likelihoods and attention modules},
  author={Cheng, Zhengxue and Sun, Heming and Takeuchi, Masaru and Katto, Jiro},
  booktitle={CVPR},
  year={2020}
}

arXiv

import io, struct, zlib

import numpy as np
import torch
import torch.nn.functional as F
import datasets
import matplotlib.pyplot as plt
import PIL.Image
from torchvision.transforms.v2.functional import to_pil_image

from compressors.compressai_baselines.zoo import cheng2020_anchor

# Evaluation settings: everything runs on CPU (the eval harness convention),
# with PyTorch limited to a single intra-op thread.
device = "cpu"
torch.set_num_threads(1)

Load the codec

# Instantiate the pretrained cheng2020-anchor codec at a fixed quality level
# and report its size (total number of parameter elements).
quality = 3
model = cheng2020_anchor(quality=quality, pretrained=True, progress=False)
model = model.eval().to(device)  # nn.Module.to returns the same module

n_params = sum(map(torch.numel, model.parameters()))
print(f"cheng2020-anchor (quality={quality})  N=M={model.N}  params={n_params/1e6:.2f} M")
cheng2020-anchor (quality=3)  N=M=128  params=11.83 M

Load a Kodak image

# Pick one image from the Kodak validation split and force 3-channel RGB.
img_idx = 22
dataset = datasets.load_dataset("danjacobellis/kodak", split="validation")
img_pil = dataset[img_idx]["image"].convert("RGB")

print(f"Kodak[{img_idx}]  size = {img_pil.size}")
display(img_pil)
Kodak[22]  size = (768, 512)
_images/e3e3a8bfc8532ca69a837c954ab148f40291bf94bcc1f6003bd0d52b4e40f547.webp

Preprocessing

def pil_to_01_tensor(img):
    """Convert an image to a batched float32 tensor in channels-first layout.

    ``img`` is converted to RGB, read as an (H, W, C) array, divided by 255
    (so 8-bit input lands in [0, 1]) and returned with shape (1, C, H, W).
    """
    hwc = np.asarray(img.convert("RGB")).astype("float32") / 255.0
    # Move channels first and make the memory layout match (C, H, W).
    chw = np.ascontiguousarray(hwc.transpose(2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0)


def pad_to_multiple(t, multiple):
    """Replicate-pad the last two dims of ``t`` up to a multiple of ``multiple``.

    Padding is added only on the bottom and right edges, so the original
    content stays at the top-left. Returns ``(padded, (H, W))`` where (H, W)
    is the spatial size before padding.
    """
    height, width = t.shape[-2:]
    # -n % m is the distance from n up to the next multiple of m (0 if aligned).
    pad_bottom = -height % multiple
    pad_right = -width % multiple
    padded = F.pad(t, (0, pad_right, 0, pad_bottom), mode="replicate")
    return padded, (height, width)


# Image -> (1, 3, H, W) tensor in [0, 1], then pad H and W up to a multiple
# of 64. Rates are normalised by the ORIGINAL pixel count, not the padded one.
x = pil_to_01_tensor(img_pil).to(device)
x_pad, (H_orig, W_orig) = pad_to_multiple(x, 64)
n_pixels = H_orig * W_orig
print(f"native HxW = {H_orig}x{W_orig}; padded = {x_pad.shape[-2]}x{x_pad.shape[-1]}; n_pixels = {n_pixels}")
native HxW = 512x768; padded = 512x768; n_pixels = 393216

Forward pass

# One full forward pass. The theoretical rate is the total information
# content, sum of -log2(likelihood) over every entry of out['likelihoods'],
# divided by the original pixel count.
with torch.no_grad():
    out = model(x_pad)

bpp_theoretical = 0.0
for lk in out['likelihoods'].values():
    bpp_theoretical += float((-torch.log2(lk)).sum().item())
bpp_theoretical /= n_pixels

print(f"bpp_theoretical = {bpp_theoretical:.4f}")
bpp_theoretical = 0.1129

Per-channel bit cost

# Split the y-latent rate by channel: sum -log2(likelihood) over batch and
# spatial dims (0, 2, 3), leaving one bit count per channel, then normalise
# per original pixel.
y_lk = out['likelihoods']['y']
per_ch_bits_y = torch.sum(-torch.log2(y_lk), dim=(0, 2, 3)).cpu().numpy()
per_ch_bpp_y = per_ch_bits_y / n_pixels

fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(np.arange(per_ch_bpp_y.size), per_ch_bpp_y, width=1.0, color="#3a7ca5")
ax.set(
    xlabel="y channel index",
    ylabel="bpp",
    title="Per-channel bpp contribution from latent y",
    xlim=(-0.5, per_ch_bpp_y.size - 0.5),
)
ax.grid(True, alpha=0.3, axis="y")
fig.tight_layout()
plt.show()

# Channel indices sorted by descending cost; keep the five most expensive.
top5 = np.argsort(-per_ch_bpp_y)[:5]
print("\n".join(
    ["Top-5 most expensive y channels:"]
    + [f"  ch {int(i):3d}   bpp = {float(per_ch_bpp_y[i]):.5f}" for i in top5]
))
_images/93862fe8fbb18897f3097d51e4d09ea532e57bad09dcf952b9894060d7d8e369.webp
Top-5 most expensive y channels:
  ch 119   bpp = 0.00941
  ch 100   bpp = 0.00746
  ch  91   bpp = 0.00694
  ch  12   bpp = 0.00662
  ch  53   bpp = 0.00644

Top-K latent feature maps

# Show the spatial maps of the most expensive y channels, taken from the
# analysis transform output before any quantization. Channels are ranked by
# the per-channel bpp computed from the forward-pass likelihoods.
with torch.no_grad():
    y_pre = model.g_a(x_pad)

top_k = 16
top_idx = np.argsort(-per_ch_bpp_y)[:top_k]
n_show = len(top_idx)  # smaller than top_k only if y has fewer channels

ncols = 4
# Ceil division: `top_k // ncols` would drop the last partial row, and would
# ask subplots for 0 rows (an error) whenever top_k < ncols.
nrows = max(1, -(-n_show // ncols))
# squeeze=False keeps `axes` 2-D for any nrows/ncols, so `.flat` always works.
fig, axes = plt.subplots(
    nrows, ncols, figsize=(ncols * 2.0, nrows * 1.4), squeeze=False
)
for ax, ch in zip(axes.flat, top_idx):
    fmap = y_pre[0, int(ch)].cpu().numpy()
    ax.imshow(fmap, cmap="gray")
    ax.set_title(f"ch {int(ch)}", fontsize=8)
    ax.axis("off")
# Blank out grid cells left over when n_show is not a multiple of ncols.
for ax in axes.ravel()[n_show:]:
    ax.axis("off")
fig.suptitle(f"Top-{n_show} y channels by bit cost (pre-quantization)", fontsize=10)
fig.tight_layout()
plt.show()
_images/66ec165fd553824ddd2fdcf08f742249fa5f98b939024280a510939e2080fd74.webp

Encode: extract integer latents

# Re-run the analysis side step by step to obtain the latents a real encoder
# would store.
#
# NOTE(review): the previous version split h_s(z_hat) into (scales, means)
# and took y_hat from gaussian_conditional(y, scales, means). That yields
# round(y - m) + m, which is not integer-valued; casting it to int16 for the
# buffer then lost precision and the buffer-decoded PSNR fell well below the
# forward-pass PSNR at the same quality. In upstream compressai,
# cheng2020-anchor inherits JointAutoregressiveHierarchicalPriors.forward,
# which feeds g_s with round(y). Its h_s output is combined with
# context-model features in entropy_parameters, not used directly as
# (scales, means). Confirm that the local zoo build matches upstream.
with torch.no_grad():
    y = model.g_a(x_pad)
    z = model.h_a(y)
    # Hyper-latent as the model's entropy bottleneck returns it. In eval mode
    # upstream compressai returns round(z - medians) + medians, so this is
    # presumably not integer-valued. The decode step below does not use it.
    z_hat, _ = model.entropy_bottleneck(z)
    # Integer latent consumed by the synthesis transform g_s.
    y_hat = torch.round(y)

print(f"y_hat  shape = {tuple(y_hat.shape)}    range = [{y_hat.min().item():+.2f}, {y_hat.max().item():+.2f}]")
print(f"z_hat  shape = {tuple(z_hat.shape)}    range = [{z_hat.min().item():+.2f}, {z_hat.max().item():+.2f}]")
y_hat  shape = (1, 128, 32, 48)    range = [-24.86, +63.98]
z_hat  shape = (1, 128, 8, 12)    range = [-5.54, +3.46]

Pack into a byte buffer

def _to_int16(t, name):
    """Round a latent tensor to the nearest integer and return it as an int16 ndarray.

    Rounding is required because ``Tensor.to(torch.int16)`` truncates toward
    zero, which silently corrupts any value that is not already an exact
    integer. Raises OverflowError if a rounded value does not fit in int16.
    """
    rounded = torch.round(t.detach()).cpu()
    info = np.iinfo(np.int16)
    lo, hi = rounded.min().item(), rounded.max().item()
    if lo < info.min or hi > info.max:
        raise OverflowError(f"{name}: rounded range [{lo}, {hi}] does not fit in int16")
    return rounded.to(torch.int16).numpy()


y_int = _to_int16(y_hat, "y_hat")
z_int = _to_int16(z_hat, "z_hat")

# Buffer layout (no entropy coding, so this is a naive storage cost):
#   header: 8 little-endian int32 -- y shape (N C H W), then z shape (N C H W)
#   body:   y as int16 bytes (C order), then z as int16 bytes (C order)
buff = io.BytesIO()
buff.write(struct.pack("<8i", *y_int.shape, *z_int.shape))
buff.write(y_int.tobytes())
buff.write(z_int.tobytes())
payload_raw = buff.getvalue()

bpp_buffer_raw = 8 * len(payload_raw) / n_pixels

# Honest middle-ground rate: pipe the same buffer through zlib.
payload_z = zlib.compress(payload_raw, 9)
bpp_buffer_zlib = 8 * len(payload_z) / n_pixels

print(f"buffer size           = {len(payload_raw):>8d} bytes")
print(f"zlib(level=9)         = {len(payload_z):>8d} bytes")
print()
print(f"bpp_theoretical       = {bpp_theoretical:.4f}    (entropy lower bound)")
print(f"bpp_buffer_raw        = {bpp_buffer_raw:.4f}    ({bpp_buffer_raw/bpp_theoretical:.1f}x of theoretical)")
print(f"bpp_buffer_zlib       = {bpp_buffer_zlib:.4f}    ({bpp_buffer_zlib/bpp_theoretical:.1f}x of theoretical)")
buffer size           =   417824 bytes
zlib(level=9)         =    16908 bytes

bpp_theoretical       = 0.1129    (entropy lower bound)
bpp_buffer_raw        = 8.5007    (75.3x of theoretical)
bpp_buffer_zlib       = 0.3440    (3.0x of theoretical)

Decode from the buffer

# Read the buffer back in the order it was written: two "<4i" shape headers,
# then the int16 y payload, then the int16 z payload. Only y is fed to the
# synthesis transform.
buff.seek(0)
y_shape, z_shape = (
    struct.unpack("<4i", buff.read(struct.calcsize("<4i"))) for _ in range(2)
)
y_n = np.frombuffer(
    buff.read(np.dtype(np.int16).itemsize * int(np.prod(y_shape))), dtype=np.int16
).reshape(y_shape)
z_n = np.frombuffer(
    buff.read(np.dtype(np.int16).itemsize * int(np.prod(z_shape))), dtype=np.int16
).reshape(z_shape)

# astype copies into a writable float32 array (int16 -> float32 is exact).
y_hat_decoded = torch.from_numpy(y_n.astype(np.float32)).to(device)

with torch.no_grad():
    x_hat_padded = model.g_s(y_hat_decoded)

# Crop the padding away, clamp to the valid range, and score against the input.
x_hat = torch.clamp(x_hat_padded[..., :H_orig, :W_orig], 0.0, 1.0)
mse = F.mse_loss(x[..., :H_orig, :W_orig], x_hat).item()
if mse > 0:
    psnr = -10 * np.log10(mse)
else:
    psnr = float('inf')
print(f"PSNR (buffer-decoded) = {psnr:.2f} dB")
PSNR (buffer-decoded) = 30.86 dB

Original / reconstruction / difference

# Convert the first batch element to (H, W, C) arrays for imshow. The error
# map is the absolute difference amplified 8x and clipped to [0, 1] so that
# small residuals are visible.
orig_np = x[0, :, :H_orig, :W_orig].cpu().permute(1, 2, 0).numpy()
recon_np = x_hat[0].cpu().permute(1, 2, 0).numpy()
diff_np = np.clip(8.0 * np.abs(orig_np - recon_np), 0.0, 1.0)

fig, axes = plt.subplots(1, 3, figsize=(13, 4.5))
for ax, (im, title) in zip(axes, (
    (orig_np, "original"),
    (recon_np, f"reconstruction (PSNR = {psnr:.2f} dB)"),
    (diff_np, "|x − x̂| × 8"),
)):
    ax.imshow(im)
    ax.set_title(title, fontsize=10)
    ax.axis("off")
fig.tight_layout()
plt.show()
_images/df89ffeeeb82a06c9c0a238d422a96e81bf154204a2f6d62bd3d02b7990866ce.webp

Rate-distortion staircase

# Sweep the six pretrained quality levels on the same padded image and record
# (quality, bpp, PSNR). bpp is the likelihood-based estimate (sum of -log2 p
# over every latent, per original pixel). PSNR is computed on the cropped,
# clamped reconstruction returned by each model's forward pass.
rd_points = []
for q in range(1, 7):
    m = cheng2020_anchor(quality=q, pretrained=True, progress=False)
    m = m.eval().to(device)
    with torch.no_grad():
        out_q = m(x_pad)
    bpp_q = sum(
        [float((-torch.log2(lk)).sum().item()) for lk in out_q['likelihoods'].values()]
    ) / n_pixels
    x_hat_q = torch.clamp(out_q['x_hat'][..., :H_orig, :W_orig], 0.0, 1.0)
    mse_q = F.mse_loss(x[..., :H_orig, :W_orig], x_hat_q).item()
    if mse_q > 0:
        psnr_q = -10 * np.log10(mse_q)
    else:
        psnr_q = float('inf')
    rd_points.append((q, bpp_q, psnr_q))
    print(f"  q={q}  N={m.N}  bpp={bpp_q:.4f}  PSNR={psnr_q:.2f} dB")
    del m  # drop the model before loading the next quality level

bpps = [bpp for _, bpp, _ in rd_points]
psnrs = [db for _, _, db in rd_points]

fig, ax = plt.subplots(figsize=(6, 4))
ax.semilogx(bpps, psnrs, 'o-', color="#3a7ca5")
for q, b, p in rd_points:
    ax.annotate(f"q={q}", (b, p), textcoords="offset points", xytext=(6, -3), fontsize=8)
ax.set(
    xlabel="bpp (theoretical, from likelihoods)",
    ylabel="PSNR (dB)",
    title=f"cheng2020-anchor RD on Kodak[{img_idx}]",
)
ax.grid(True, alpha=0.3, which="both")
fig.tight_layout()
plt.show()
  q=1  N=128  bpp=0.0592  PSNR=32.08 dB
  q=2  N=128  bpp=0.0810  PSNR=33.33 dB
  q=3  N=128  bpp=0.1129  PSNR=34.54 dB
  q=4  N=192  bpp=0.1640  PSNR=36.38 dB
  q=5  N=192  bpp=0.2318  PSNR=37.92 dB
  q=6  N=192  bpp=0.3202  PSNR=39.19 dB
_images/9d429c7db96b03d10ef410ba0d61aa0a1ba47c46395b2ec762ed9b32c3154c9d.webp