Training: RGB Image

import torch
import warnings
import torch.nn.functional as F
import numpy as np 
import io
from datasets import load_dataset
from IPython.display import display, Image, update_display, HTML
from torchvision.transforms import ToPILImage, RandomCrop, PILToTensor
from fastprogress.fastprogress import master_bar, progress_bar
from walloc import walloc
device = "cuda"
# Training and validation image datasets pulled from the Hugging Face hub.
# LSDIR_540 supplies 540p training images; imagenet_hq supplies validation images.
train_dataset = load_dataset("danjacobellis/LSDIR_540", split="train")
valid_dataset = load_dataset("danjacobellis/imagenet_hq", split="validation")

def collate_fn(batch):
    """Collate HF dataset samples into a single image tensor.

    Takes a random 256x256 crop of each sample's PIL image, converts it
    to a uint8 CHW tensor, and stacks the results into one
    (batch, channels, 256, 256) tensor.
    """
    crop = RandomCrop(256)
    to_tensor = PILToTensor()
    crops = [to_tensor(crop(sample['image'])) for sample in batch]
    return torch.stack(crops)

# Fixed 4-image validation batch (hand-picked indices) used for visual
# monitoring throughout training.
dataloader_valid = torch.utils.data.DataLoader(
    valid_dataset.select([1,7,30,33]),
    batch_size=4,
    num_workers=12,
    drop_last=True,
    shuffle=False,
    collate_fn=collate_fn
)

valid_batch = next(iter(dataloader_valid))
# Map uint8 pixels [0, 255] to floats centered at zero in [-0.5, 0.5]
# (torch's true division promotes the uint8 tensor to float).
valid_batch = valid_batch/255
valid_batch = valid_batch - 0.5
valid_batch = valid_batch.to(device)

# Preview the validation crops (shift back to [0, 1] for display).
for img in valid_batch:
    display(ToPILImage()(img+0.5))
_images/6f6cd6ba970940e68eaee8caca9cf3c6c6d5ebedda9788faf561e46dc848fa27.png _images/fc97a55f2bf5133f5a58c2d1ee2fb585309df7b98e96d74efa1131dd67ea70b5.png _images/9ccf7febce94a4cd2abc9798f9cddcb83fecb721ad8f7a1e26e737f6c18ae2d4.png _images/3236a50952b05293817337318553a6134c4cc705563eeb011994d36ee7d466a3.png
class Config: pass
config = Config()

# Optimization / logging hyperparameters.
vars(config).update(
    batch_size=64,
    num_workers=12,
    grad_accum_steps=1,   # optimizer step every N mini-batches
    plot_update=128,      # steps between plot refreshes / scheduler ticks
    patience=64,
    min_lr=1e-7,
    max_lr=3e-5,
    warmup_steps=50000,
    weight_decay=0.,
    epochs=150,
    ϕ=0.,                 # weight of the auxiliary tf loss term
)

# Codec architecture hyperparameters.
vars(config).update(
    channels=3,
    J=3,                  # wavelet decomposition levels
    Ne=None,
    Nd=768,
    latent_dim=16,
    latent_bits=8,
    lightweight_encode=True,
)
# Build the WaLLoC codec on the GPU from the config above.
# NOTE(review): Codec2D internals (wavelet analysis + learned encoder/decoder)
# are defined in the walloc package — confirm details there.
codec = walloc.Codec2D(
    channels=config.channels,
    J=config.J,
    Ne=config.Ne,
    Nd=config.Nd,
    latent_dim=config.latent_dim,
    latent_bits=config.latent_bits,
    lightweight_encode=config.lightweight_encode
).to(device)

# AdamW starting at the warmup floor; the schedulers below raise/adjust it.
optimizer = torch.optim.AdamW(
    params=codec.parameters(),
    weight_decay=0.0,  # NOTE(review): hard-coded; config.weight_decay (0.) is ignored here
    lr = config.min_lr
)

# Total parameter count in millions (notebook cell output below).
sum(p.numel() for p in codec.parameters())/1e6
56.914384
def minus_cosine_warmup(i_step, min_lr=None, max_lr=None, total_updates=None):
    """Log-space half-cosine warmup multiplier for ``LambdaLR``.

    Sweeps the learning rate from ``min_lr`` up to ``max_lr`` along a
    half cosine in log10 space and returns the multiplier relative to
    ``min_lr`` (the optimizer's base lr), which is what ``LambdaLR``
    expects.

    Parameters default to the global ``config`` for backward
    compatibility, but may be passed explicitly to decouple the
    schedule from the notebook globals.

    i_step        -- scheduler tick index (warmup.step() is called once
                     per ``plot_update`` training steps, hence the
                     ``warmup_steps // plot_update`` default length).
    min_lr        -- lr at tick 0 (default: config.min_lr).
    max_lr        -- lr reached at ``total_updates`` ticks (default: config.max_lr).
    total_updates -- number of ticks spanning the half cosine
                     (default: config.warmup_steps // config.plot_update).
    """
    if min_lr is None:
        min_lr = config.min_lr
    if max_lr is None:
        max_lr = config.max_lr
    if total_updates is None:
        total_updates = config.warmup_steps // config.plot_update
    scale = 0.5 * (np.log10(max_lr) - np.log10(min_lr))
    angle = np.pi * i_step / total_updates
    log_lr = np.log10(min_lr) + scale * (1 - np.cos(angle))
    lr = 10 ** log_lr
    return lr / min_lr
    
# Warmup phase: LambdaLR multiplies the optimizer's base lr (config.min_lr)
# by minus_cosine_warmup at every scheduler tick. The function can be passed
# directly — wrapping it in `lambda i_step: minus_cosine_warmup(i_step)`
# was redundant.
warmup = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=minus_cosine_warmup
)

# After warmup: decay the lr gently (x0.98) whenever the smoothed distortion
# loss fails to improve for `patience` consecutive plot updates.
reduce_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.98,
    patience=config.patience,
    threshold=1e-5,
    min_lr=config.min_lr
)
# Silence PIL's truncated-TIFF read warning (presumably emitted while
# decoding some dataset images — confirm against the training data).
warnings.filterwarnings("ignore", message="Truncated File Read", category=UserWarning, module="PIL.TiffImagePlugin")
# --- Training loop state ---
dist_losses, rate_losses = [], []  # per-step log10 losses; NOTE(review): rate_losses is never appended to below
learning_rates = [optimizer.param_groups[0]['lr']]  # lr history, sampled once per plot update
img_displays = []   # IPython display handles for live reconstruction previews
text_display = None
codec.train()
optimizer.zero_grad()
mb = master_bar(range(config.epochs))
mb.names = ['Distortion Loss', 'Smoothed']
i_step = 0  # global step counter across epochs (drives the warmup/plateau switch)
for i_epoch in mb:
    # Rebuilt every epoch; shuffle=True reshuffles the training crops.
    dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        drop_last=True,
        shuffle=True,
        collate_fn=collate_fn
    )

    for i, batch in enumerate(progress_bar(dataloader, parent=mb)):

        # Map uint8 pixels to floats centered at zero in [-0.5, 0.5],
        # matching the valid_batch preprocessing above.
        x = batch.to(device)
        x = x.to(torch.float)
        x = x/255
        x = x-0.5

        # Codec forward pass returns (reconstruction, mse, tf) — the
        # reconstruction is discarded during training.
        _, mse_loss, tf_loss = codec(x)

        # ϕ weights the auxiliary tf loss; with ϕ == 0 the distortion
        # objective is pure MSE.
        dist_loss = mse_loss + config.ϕ*tf_loss
        dist_losses.append(np.log10(dist_loss.item()))
        loss = dist_loss
        loss.backward()
        # Gradient accumulation: apply the optimizer step only every
        # grad_accum_steps mini-batches (1 here, i.e. every step).
        if (i + 1) % config.grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # plotting and lr scheduler
        if len(dist_losses) % config.plot_update == 0:
            plot_n = len(dist_losses) // config.plot_update
            # Smoothed curve: mean of each plot_update-sized window,
            # plotted at the window centers.
            smoothed_x = (0.5+torch.arange(plot_n)) * config.plot_update
            smoothed_y = torch.tensor(dist_losses).reshape(plot_n, -1).mean(dim=1)
            dist_x = range(len(dist_losses))
            dist_y = dist_losses
            mb.update_graph([[dist_x, dist_y],[smoothed_x, smoothed_y]])
            mb.child.comment = f'loss {smoothed_y[-1]:.4g}; lr {learning_rates[-1]:.4g}'

            # lr scheduler
            # The schedulers tick once per plot_update steps: log-cosine
            # warmup until warmup_steps, then plateau-based decay driven
            # by the smoothed loss.
            if i_step < config.warmup_steps:
                warmup.step()
            else:
                reduce_plateau.step(smoothed_y[-1])
            learning_rates.append(optimizer.param_groups[0]['lr'])

            # Render current validation reconstructions inline (eval mode
            # for the forward pass, then back to train mode).
            with torch.no_grad():
                codec.eval()
                y_valid, _, _ = codec(valid_batch)
                codec.train()

            for img_idx, img in enumerate(y_valid):
                buffer = io.BytesIO()
                ToPILImage()(img + 0.5).save(buffer, format="PNG")
                # ToPILImage()(img + 0.5).save(f"video/{img_idx}_{i_epoch}_{i}.png")
                buffer.seek(0)
                # First plot update creates a display handle per image;
                # subsequent updates refresh the same handle in place.
                if len(img_displays) <= img_idx:
                    img_displays.append(display(Image(buffer.read()), display_id=True))
                else:
                    update_display(Image(buffer.read()), display_id=img_displays[img_idx].display_id)
        i_step+=1

    # Checkpoint once per epoch (overwrites the same file).
    torch.save({
        'model_state_dict': codec.state_dict(),
        'i_epoch': i_epoch,
        'learning_rates': learning_rates,
        'dist_losses': dist_losses,
        'config': config
    }, f"log_{device}.pth")
/home/dgj335/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Attempt to open cnn_infer failed: handle=0 error: libcudnn_cnn_infer.so.8: cannot open shared object file: No such file or directory (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:78.)
  return F.conv2d(input, weight, bias, self.stride,
_images/5e71be6b2cf0625159137f7e00943f754d339714c76a32f9699e24fd0663cfd9.png _images/a907f1491219fc1e4a544bf70325120dd63ea9303d710319732e1a1208020e95.png _images/c24469981348d73d9815cbb83db66aa598233285abf13f08c5d23263df7f8b55.png _images/df83d0df2bafcd69690a43acdbf7a8f3b2dfaf8c1db6bb655cf8ba0cb310fe7d.png _images/68953558f9c85db64064586791044ca36a79ccce683da8f143d409cc96d4da12.png
from IPython.display import display, Image, update_display, HTML
import matplotlib.pyplot as plt
# Re-render the final state of the master/child progress bars as HTML.
display(HTML(mb.main_bar.progress))
display(HTML(mb.child.progress))
100.00% [150/150 21:42:48<00:00]
100.00% [1327/1327 08:40<00:00 loss -2.758; lr 2.881e-05]
# Save the final checkpoint (weights + training history + config) to the
# local walloc model-release directory.
torch.save({
    'model_state_dict': codec.state_dict(),
    'i_epoch': i_epoch,
    'learning_rates': learning_rates,
    'dist_losses': dist_losses,
    'config': config
}, "../../hf/walloc/RGB_Li_16c_J3_nf8_v1.0.2.pth")
# Final qualitative check: reconstruct the validation batch in eval mode
# and display the results (shifted back to [0, 1]).
codec.eval()
with torch.no_grad():
    y_valid, _, _ = codec(valid_batch)
for img_idx, img in enumerate(y_valid):
    display(ToPILImage()(img + 0.5))
_images/08f97caeadf837c35d7626aad769311223378bfbc9fdb1324e26fdc995ea30db.png _images/21b5c301e5cbdbc9df28e23e30973f0ca62f4279c513e1f7d7b98979d1321c42.png _images/74bbc460082b8144ad23435d3cf4782bc925ff6d625bba2d4e8220ff510e6566.png _images/4f498e06714aef18311afdd22236d4521a054e35f6655ed73509541c92083fde.png
# Measure latent compression: compare the 8-bit latent byte size — raw and
# zlib-compressed — against the original pixel count of the batch.
with torch.no_grad():
    X = codec.wavelet_analysis(valid_batch, codec.J)
    Z = codec.encoder[:1](X)   # output of the first encoder stage — presumably pre-companding latents; confirm in walloc
    Z2 = codec.encoder(X)      # full encoder output
import zlib
# Renamed from `bytes` — the original shadowed the builtin type.
latent_bytes = Z.round().to(torch.int8).detach().cpu().numpy().tobytes()
print(valid_batch.numel()/len(latent_bytes))                       # raw ratio (pixels per latent byte)
print(valid_batch.numel()/len(zlib.compress(latent_bytes, level=9)))  # ratio after lossless entropy coding
12.0
21.569128658018155
# Histogram of the first-stage latents (range clipped to [-20, 20] for display).
plt.hist(Z.to("cpu").flatten().numpy(), density=True,range=(-20,20),bins=100);
_images/5e4fa7c02ac01002152b107295660ff1f9263d155a2c00419a1009e0555845a3.png
# Histogram of the rounded full-encoder output over 63 unit-width integer bins.
plt.hist(Z2.to("cpu").round().flatten().numpy(), density=True,range=(-31.5,31.5), bins=63, width=1.);
_images/449c89dc0e5e9a18e841b55469f8f55e43f2553cb38c9c757bfbfbe159a8be84.png