WaLLoC

@inproceedings{jacobellis2025learned,
  title={Learned Compression for Compressed Learning},
  author={Jacobellis, Dan and Yadwadkar, Neeraja J.},
  booktitle={Data Compression Conference},
  year={2025},
  organization={IEEE}
}

arXiv

GitHub

Pre-trained model weights

import io
import torch
import matplotlib.pyplot as plt
import PIL.Image
from datasets import load_dataset
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from compressors.walloc._codec import (
    load_codec,
    encode_to_latent,
    decode_from_blob,
    latent_to_webp_bytes,
    to_model_input,
    from_model_output,
)

# Execution device and floating-point dtype used when loading the codec,
# building input tensors, and decoding.
device, torch_dtype = 'cpu', torch.float32

Load the codec

codec, info = load_codec(device=device, torch_dtype=torch_dtype)

# Summary of the codec configuration, one aligned "name = value" line each.
report_lines = (
    f'channels                 = {info.channels}',
    f'J (wavelet depth)        = {info.J}',
    f'latent_dim               = {info.latent_dim}',
    f'latent_bits              = {info.latent_bits}',
    f'dimensionality_reduction = {info.dimensionality_reduction:.2f}x',
    f'input_range              = {info.input_range}  # [-0.5, 0.5]',
)
print('\n'.join(report_lines))
Skipping import of cpp extensions due to incompatible torch version 2.11.0+cu130 for torchao version 0.15.0             Please see https://github.com/pytorch/ao/issues/2919 for more info
/home/dgj335/g/lib/python3.12/site-packages/pytorch_wavelets/dtcwt/coeffs.py:7: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import resource_stream
channels                 = 3
J (wavelet depth)        = 3
latent_dim               = 12
latent_bits              = 8
dimensionality_reduction = 16.00x
input_range              = centered  # [-0.5, 0.5]

Load a Kodak image

dataset = load_dataset('danjacobellis/kodak', split='validation')
img_pil = dataset[22]['image'].convert('RGB')
W, H = img_pil.size          # PIL reports (width, height)
n_pixels = H * W

# A J-level wavelet analysis shrinks each spatial dimension by 2**J
# (8 for J=3), so report divisibility by that factor instead of a
# hard-coded 8. For this codec the printed text is unchanged.
stride = 2 ** info.J
print(f'image size = {W}x{H}  '
      f'(H % {stride} = {H % stride}, W % {stride} = {W % stride})')
img_pil
image size = 768x512  (H % 8 = 0, W % 8 = 0)
_images/e3e3a8bfc8532ca69a837c954ab148f40291bf94bcc1f6003bd0d52b4e40f547.webp
# uint8 image tensor -> batched float tensor in [0, 1] on the target device,
# then shifted into the codec's centered input range.
x_01 = pil_to_tensor(img_pil).unsqueeze(0).to(device=device, dtype=torch_dtype) / 255.0
x_in = to_model_input(x_01)

lo_01, hi_01 = x_01.min(), x_01.max()
lo_in, hi_in = x_in.min(), x_in.max()
print(f'x_01 range = [{lo_01:.3f}, {hi_01:.3f}]')
print(f'x_in range = [{lo_in:.3f}, {hi_in:.3f}]   shape = {tuple(x_in.shape)}')
x_01 range = [0.000, 1.000]
x_in range = [-0.500, 0.500]   shape = (1, 3, 512, 768)

Wavelet analysis

with torch.inference_mode():
    X = codec.wavelet_analysis(x_in, J=codec.J)

# Expected coefficient layout, derived from the codec configuration instead
# of hard-coded numbers: channels * 4**J subband channels, with each spatial
# dimension reduced by 2**J (channels=3, J=3 gives 192, H/8, W/8).
n_subbands = info.channels * 4 ** codec.J
downsample = 2 ** codec.J
print(f'X shape = {tuple(X.shape)}  (expected: 1, {info.channels} * 4**J = '
      f'{n_subbands}, H/{downsample}, W/{downsample})')
X shape = (1, 192, 64, 96)  (expected: 1, 3 * 4**J = 192, H/8, W/8)

Encode and round

with torch.inference_mode():
    # The first two encoder stages produce the continuous latent; stage 2
    # applies the rounding step.
    z = codec.encoder[0:2](X)
    z_hat = codec.encoder[2](z)

z_lo, z_hi = z.min(), z.max()
q_lo, q_hi = z_hat.min().item(), z_hat.max().item()
print(f'z shape     = {tuple(z.shape)}')
print(f'z range     = [{z_lo:.2f}, {z_hi:.2f}]')
print(f'z_hat shape = {tuple(z_hat.shape)}  (1, latent_dim, H/8, W/8)')
print(f'z_hat range = [{q_lo:.0f}, {q_hi:.0f}]  '
      f'(integer-valued, dtype={z_hat.dtype})')
z shape     = (1, 12, 64, 96)
z range     = [-28.90, 32.35]
z_hat shape = (1, 12, 64, 96)  (1, latent_dim, H/8, W/8)
z_hat range = [-29, 32]  (integer-valued, dtype=torch.float32)
# Distribution of the continuous latent values before rounding.
fig, ax = plt.subplots(figsize=(5, 3), dpi=120)
ax.hist(z.flatten().cpu().numpy(), bins=121, range=(-30, 30), density=True)
ax.set_title('Histogram of pre-rounding latents (z)')
ax.set_xlabel('latent value')
ax.set_ylabel('density')
ax.grid(alpha=0.3)
plt.show()
_images/d6d0325c3bf2d1023fc843f9e3276317b3c0bb29e0e438c651f1f4b5f29e2d70.webp

Pack to bytes (lossless WebP)

blob = latent_to_webp_bytes(z_hat, info.latent_bits)
buff = io.BytesIO(blob)        # in-memory stream wrapping the encoded bytes

n_bytes = len(blob)            # identical to the size of buff's contents
bpp = 8 * n_bytes / n_pixels   # bits per pixel
ratio = 24 / bpp               # relative to uncompressed 24-bit RGB
print(f'len(buff)        = {n_bytes} bytes')
print(f'bpp              = {bpp:.4f}')
print(f'compression ratio = {ratio:.2f}x  (vs 24-bit RGB)')
len(buff)        = 23276 bytes
bpp              = 0.4736
compression ratio = 50.68x  (vs 24-bit RGB)

Visualize the packed latent

def scale_for_display(img, n_bits):
    """Rescale an ``n_bits``-quantized PIL image to span the full 8-bit range.

    Every band is remapped through one 256-entry lookup table. Code value
    ``v`` in ``[0, 2**n_bits - 1]`` becomes
    ``round(v * 255 / (2**n_bits - 1))``, so the largest code maps to exactly
    255. Values above the largest code (not present in a correctly quantized
    image) saturate at 255.

    Args:
        img: PIL image whose bands hold integer codes in
            ``[0, 2**n_bits - 1]``. Assumed to use an 8-bit-per-band mode
            such as ``'L'`` or ``'RGB'``.
        n_bits: Bit depth of the codes, ``1 <= n_bits <= 8``.

    Returns:
        A new PIL image with the same mode as ``img``.

    Raises:
        ValueError: If ``n_bits`` is outside ``[1, 8]``.
    """
    if not 1 <= n_bits <= 8:
        raise ValueError(f'n_bits must be in [1, 8], got {n_bits}')
    max_code = 2 ** n_bits - 1
    # Multiply as integers before dividing so the top code lands exactly on
    # 255. Round (rather than truncate) to the nearest display level.
    lut = [round(min(v, max_code) * 255 / max_code) for v in range(256)]
    return PIL.Image.merge(img.mode, [band.point(lut) for band in img.split()])

# Re-open the encoded WebP bytes as an ordinary image and stretch its
# n_bits codes to the full 8-bit range for viewing.
latent_stream = io.BytesIO(blob)
latent_pil = PIL.Image.open(latent_stream)
latent_size, latent_mode = latent_pil.size, latent_pil.mode
print(f'latent image: size = {latent_size}, mode = {latent_mode}')
scale_for_display(latent_pil, info.latent_bits)
latent image: size = (192, 128), mode = RGB
_images/2aa900a47d9f28c2019a6bb9ad5992150e0f1bf924cf76fcdbd35c575578809a.webp

Decode from the buffer

# Decode from the serialized bytes rather than from the in-memory z_hat.
encoded_bytes = buff.getvalue()
X_hat = decode_from_blob(
    codec,
    encoded_bytes,
    info.latent_dim,
    info.latent_bits,
    device,
    torch_dtype,
)
# Map the decoder output back to [0, 1] and clip any out-of-range pixels.
x_hat = torch.clamp(from_model_output(X_hat), 0, 1)

print(f'X_hat range = [{X_hat.min():.3f}, {X_hat.max():.3f}]  '
      f'(decoder output, approx [-0.5, 0.5]; outliers clamped below)')
print(f'x_hat range = [{x_hat.min():.3f}, {x_hat.max():.3f}]  '
      f'(after from_model_output + clamp, [0, 1])')
print(f'x_hat shape = {tuple(x_hat.shape)}')
X_hat range = [-0.699, 0.568]  (decoder output, approx [-0.5, 0.5]; outliers clamped below)
x_hat range = [0.000, 1.000]  (after from_model_output + clamp, [0, 1])
x_hat shape = (1, 3, 512, 768)

Reconstruction and PSNR

# PSNR for images in [0, 1] (peak value 1): 10*log10(1/MSE) = -10*log10(MSE).
# log10 is taken on the MSE tensor itself, so the computation stays in
# torch_dtype. Re-wrapping the Python float with torch.tensor() would use the
# default float32 and silently downcast when torch_dtype is float64.
mse_t = torch.nn.functional.mse_loss(x_hat, x_01)
mse = mse_t.item()
psnr = (-10 * torch.log10(mse_t)).item()
print(f'bpp  = {bpp:.4f}')
print(f'PSNR = {psnr:.2f} dB')
to_pil_image(x_hat[0].cpu())
bpp  = 0.4736
PSNR = 34.24 dB
_images/980c94f17c2f2b999a5c43824cfdba10a44ab5c173c79bde0b4b15d9783e3dbb.webp

Sanity check: `codec.forward` matches the manual pipeline

with torch.inference_mode():
    # codec.forward returns a 3-tuple; only the reconstruction
    # (in [-0.5, 0.5]) is compared here.
    x_hat_fwd, _, _ = codec(x_in)
x_hat_fwd_01 = torch.clamp(from_model_output(x_hat_fwd), 0, 1)
max_diff = torch.max(torch.abs(x_hat_fwd_01 - x_hat)).item()
print(f'max | codec.forward - manual decode | = {max_diff:.6e}')
max | codec.forward - manual decode | = 0.000000e+00