Image Compression

import os
import torch
import json
import matplotlib.pyplot as plt
import numpy as np
from types import SimpleNamespace
from PIL import Image, ImageEnhance
from IPython.display import display
from torchvision.transforms import ToPILImage, PILToTensor
from walloc import walloc
from walloc.walloc import latent_to_pil, pil_to_latent

Load the model from a pre-trained checkpoint

# Download the pre-trained RGB_16x codec weights (.pth) and the JSON file of
# constructor arguments that the model-loading cell reads alongside them.
wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_16x.pth

wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_16x.json

device = "cpu"

# Constructor arguments for the codec, exposed as attributes. The context
# manager closes the file handle instead of leaking it until garbage collection.
with open("RGB_16x.json", encoding="utf-8") as config_file:
    codec_config = SimpleNamespace(**json.load(config_file))

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints from a trusted source;
# try weights_only=True if this checkpoint holds only tensors and plain
# containers. map_location uses `device` so the tensors land where the model
# will live.
checkpoint = torch.load("RGB_16x.pth", map_location=device, weights_only=False)

# Rebuild the architecture from the saved arguments, then load trained weights.
codec = walloc.Codec2D(
    channels=codec_config.channels,
    J=codec_config.J,
    Ne=codec_config.Ne,
    Nd=codec_config.Nd,
    latent_dim=codec_config.latent_dim,
    latent_bits=codec_config.latent_bits,
    lightweight_encode=codec_config.lightweight_encode,
)
codec.load_state_dict(checkpoint['model_state_dict'])
codec = codec.to(device)
# Inference mode; the trailing semicolon suppresses the notebook's echo of the
# returned module.
codec.eval();

Load an example image

# Fetch kodim05.png from the Kodak image set mirrored at r0k.us; it is opened
# as the example input image.
wget "https://r0k.us/graphics/kodak/kodak/kodim05.png"

# Force 3-band RGB so the tensor fed to the codec has a predictable channel
# count even if a palette, grayscale or RGBA image is substituted here.
# NOTE(review): assumes the RGB_16x codec expects 3 input channels
# (codec_config.channels) — confirm against RGB_16x.json.
img = Image.open("kodim05.png").convert("RGB")
img
_images/d69f75f606233935cf37e7e89d4127e5258084bf96e03ff9b23fe393991683e7.png

Full encoding and decoding pipeline with .forward()

  • If codec.eval() is called, the latent is rounded to the nearest integer.

  • If codec.train() is called, uniform noise is added instead of rounding.

# Encode and decode the example image in a single call to the codec's forward().
with torch.no_grad():
    codec.eval()  # round the latent instead of adding uniform noise
    # PIL image -> (C, H, W) float tensor, shifted from [0, 255] to
    # [-0.5, 0.5], with a leading batch dimension.
    x = PILToTensor()(img).to(torch.float)
    x = (x/255 - 0.5).unsqueeze(0).to(device)
    # forward() returns a 3-tuple; only the reconstruction is used here.
    x_hat, _, _ = codec(x)
# Undo the shift and clamp to [0, 1]. ToPILImage multiplies float tensors by
# 255 and casts to uint8, so reconstruction overshoot outside [0, 1] would
# wrap around (or be undefined) instead of saturating.
ToPILImage()((x_hat[0] + 0.5).clamp(0, 1))
_images/ee8d99bd1ea3cb2d0bf03a0896ac2fe676aaef09edb764b3725af24983b75705.png

Accessing latents

# Run the codec stage by stage (instead of one forward() call) to expose the
# intermediate tensors.
with torch.no_grad():
    # Wavelet analysis of the input; J is taken from the codec
    # (presumably the number of decomposition levels — confirm in walloc).
    X = codec.wavelet_analysis(x,J=codec.J)
    # Output of the first two encoder modules: the latent before the final
    # encoder module.
    z = codec.encoder[0:2](X)
    # Final encoder module. z_hat is the tensor that is later stored
    # losslessly; presumably it holds the rounded latent in eval mode — TODO
    # confirm against walloc.
    z_hat = codec.encoder[2](z)
    # Decode to the wavelet domain, then apply the wavelet synthesis.
    X_hat = codec.decoder(z_hat)
    x_rec = codec.wavelet_synthesis(X_hat,J=codec.J)
# Number of input values per latent value (a count ratio, not a bit-rate ratio).
print(f"dimensionality reduction: {x.numel()/z.numel()}×")
dimensionality reduction: 16.0×
# Empirical (density-normalized) distribution of the pre-final-stage latent z.
plt.figure(figsize=(5,3),dpi=150)
plt.hist(
    # .cpu() keeps this cell working when `device` is not the CPU, because
    # .numpy() only accepts CPU tensors; on a CPU tensor it is a no-op.
    z.flatten().cpu().numpy(),
    range=(-25,25),
    bins=151,
    density=True,
);
plt.title("Histogram of latents")
plt.xlim([-25,25]);
_images/8d3dd038eecd9f8273d4469aa732cc954126f8a5992f50161b488d068a70b0b1.png

Lossless compression of latents

def _display_lut(n_bits):
    """Return the 256-entry lookup table used by ``scale_for_display``.

    Code ``i`` in ``[0, 2**n_bits - 1]`` maps to
    ``floor(i * 255 / (2**n_bits - 1))``. Byte values ``>= 2**n_bits`` wrap
    modulo ``2**n_bits``.

    Raises:
        ValueError: If ``n_bits`` is outside ``1..8``.
    """
    if not 1 <= n_bits <= 8:
        raise ValueError(f"n_bits must be between 1 and 8, got {n_bits}")
    levels = 2**n_bits
    # Exact integer arithmetic; no dependence on float rounding.
    ramp = [(i * 255) // (levels - 1) for i in range(levels)]
    # Image.point on an 8-bit band needs one entry per possible byte value.
    return [ramp[v % levels] for v in range(256)]


def scale_for_display(img, n_bits):
    """Stretch an image of ``n_bits``-bit codes to the full 8-bit range.

    Each band is remapped so that the largest code ``2**n_bits - 1`` becomes
    255. Without this, low-bit-depth latents look almost black when displayed.

    Args:
        img: PIL image with 8-bit bands (e.g. mode "L", "RGB" or "CMYK")
            whose samples hold ``n_bits``-bit codes.
        n_bits: Bits per code, between 1 and 8.

    Returns:
        A new PIL image with the same mode as ``img``.

    Raises:
        ValueError: If ``n_bits`` is outside ``1..8``.
    """
    lut = _display_lut(n_bits)
    scaled_bands = [band.point(lut) for band in img.split()]
    return Image.merge(img.mode, scaled_bands)

Single-channel PNG (L)

# Pad the third-from-last axis of the latent (presumably channels of a
# (B, C, H, W) tensor) with 4 zero slices before packing into one-band images.
z_padded = torch.nn.functional.pad(z_hat, (0, 0, 0, 0, 0, 4))
z_pil = latent_to_pil(z_padded,codec.latent_bits,1)
display(scale_for_display(z_pil[0], codec.latent_bits))
z_pil[0].save('latent.png')
png = [Image.open("latent.png")]
# Take the channel count from the padded tensor instead of hard-coding 16, so
# the round trip stays correct if the padding or latent size changes.
z_rec = pil_to_latent(png,z_padded.shape[-3],codec.latent_bits,1)
assert z_rec.equal(z_padded), "PNG round trip is not lossless"
# Raw input size (one byte per 8-bit value in x) over the compressed file size.
print("compression_ratio: ", x.numel()/os.path.getsize("latent.png"))
_images/50e38751610b35578c90925fc3b6c0a6179235001e172b20fbb5749cb2578e12.png
compression_ratio:  26.729991842653856

Three-channel WebP (RGB)

# Pack the (unpadded) latent into 3-band images and store losslessly as WebP.
z_pil = latent_to_pil(z_hat,codec.latent_bits,3)
display(scale_for_display(z_pil[0], codec.latent_bits))
z_pil[0].save('latent.webp',lossless=True)
webp = [Image.open("latent.webp")]
# Take the channel count from the latent itself instead of hard-coding 12.
z_rec = pil_to_latent(webp,z_hat.shape[-3],codec.latent_bits,3)
assert z_rec.equal(z_hat), "WebP round trip is not lossless"
# Raw input size (one byte per 8-bit value in x) over the compressed file size.
print("compression_ratio: ", x.numel()/os.path.getsize("latent.webp"))
_images/ff39249f219ca35fbb59c296ebf2b00dd80f9b8cdb4889c4b244fa98b1caad25.png
compression_ratio:  28.811254396248536

Four-channel TIFF (CMYK)

# Pad the third-from-last axis of the latent (presumably channels) with 4 zero
# slices, then pack into 4-band (CMYK) images stored as deflate-compressed TIFF.
z_padded = torch.nn.functional.pad(z_hat, (0, 0, 0, 0, 0, 4))
z_pil = latent_to_pil(z_padded,codec.latent_bits,4)
display(scale_for_display(z_pil[0], codec.latent_bits))
z_pil[0].save('latent.tif',compression="tiff_adobe_deflate")
tif = [Image.open("latent.tif")]
# Take the channel count from the padded tensor instead of hard-coding 16.
z_rec = pil_to_latent(tif,z_padded.shape[-3],codec.latent_bits,4)
assert z_rec.equal(z_padded), "TIFF round trip is not lossless"
# Raw input size (one byte per 8-bit value in x) over the compressed file size.
print("compression_ratio: ", x.numel()/os.path.getsize("latent.tif"))
_images/f763a4a4d1a6df907646ba1fdfef8be5794a513ab2653709ceb9df5d246e24d2.jpg
compression_ratio:  21.04034530731638