Training: Stereo Audio

import torch
import warnings
import torch.nn.functional as F
import io
import torchaudio
import torchvision
import numpy as np
import random
import PIL
from datasets import load_dataset
from torchvision.transforms import ToPILImage
from IPython.display import display, Audio, Image, update_display, HTML
from fastprogress.fastprogress import master_bar, progress_bar
from walloc import walloc
from itertools import combinations
device = "cuda:3"
class Config: pass
config = Config()
config.batch_size = 32
config.num_workers = 24
config.grad_accum_steps = 1
config.plot_update = 64
config.patience = 64
config.min_lr = 1e-7
config.max_lr = 3e-5
config.warmup_steps = 5000
config.weight_decay = 0.
config.epochs = 30
config.epoch_len = 10000
config.ϕ = 0.

config.audio_length = 2**19
config.channels = 2
config.J = 8
config.Ne = None
config.Nd = 768
config.latent_dim = 108
config.latent_bits = 8
config.lightweight_encode = True
config.post_filter = True
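The warmup_steps, min_lr, and max_lr fields suggest a ramped learning-rate schedule. The actual schedule is not shown in this excerpt; the sketch below is one plausible linear-warmup/cosine-decay interpretation, and total_steps is an assumed quantity (e.g. config.epochs * config.epoch_len).

import math

def lr_at_step(step, total_steps, cfg=config):
    # Assumed schedule: linear warmup from min_lr to max_lr, then cosine decay back to min_lr
    if step < cfg.warmup_steps:
        frac = step / max(1, cfg.warmup_steps)
        return cfg.min_lr + frac * (cfg.max_lr - cfg.min_lr)
    progress = (step - cfg.warmup_steps) / max(1, total_steps - cfg.warmup_steps)
    return cfg.min_lr + 0.5 * (cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))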
train_dataset = load_dataset("danjacobellis/musdb_segments", split="train")
valid_dataset = load_dataset("danjacobellis/musdb_segments", split="validation")
mixing_weights = []

# Original mixture
mixing_weights.append(torch.ones(4))

# Amplify two channels, attenuate two
for indices in combinations(range(4), 2):
    weights = torch.full((4,), 0.5)
    weights[list(indices)] = 1.5
    mixing_weights.append(weights)

# Combinations of 3 channels (set to 1.0), one channel set to 0.0
for indices in combinations(range(4), 3):
    weights = torch.zeros(4)
    weights[list(indices)] = 1.0
    mixing_weights.append(weights)

# Combinations of 2 channels (set to 1.0), others set to 0.0
for indices in combinations(range(4), 2):
    weights = torch.zeros(4)
    weights[list(indices)] = 1.0
    mixing_weights.append(weights)

# Combinations of 1 channel
for index in range(4):
    weights = torch.zeros(4)
    weights[index] = 1.0
    mixing_weights.append(weights)
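Altogether this yields 1 + 6 + 4 + 6 + 4 = 21 candidate weight vectors over the four stems (vocal, bass, drums, other). A quick sanity check of the list built above:

assert len(mixing_weights) == 21
assert all(w.shape == (4,) for w in mixing_weights)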
L = config.audio_length
C = config.channels
crop = torchvision.transforms.RandomCrop((4, L))
center_crop = torchvision.transforms.CenterCrop((1, L))
def collate_fn(batch):
    B = len(batch)
    x = torch.zeros((B, C, L), dtype=torch.float)
    for i_sample, sample in enumerate(batch):
        xi = torch.zeros((C,4,2**21), dtype=torch.int16)
        for i_instr, instr in enumerate(['vocal', 'bass', 'drums', 'other']):
            audio, fs = torchaudio.load(io.BytesIO(sample[f'audio_{instr}']['bytes']), normalize=False)
            xi[:,i_instr,:] = audio
        # Random 2**19-sample crop of the 2**21-sample segment, then mix the four stems
        xi = crop(xi).to(torch.float)
        w = random.choice(mixing_weights).view(1, -1, 1)
        xi = (w*xi).sum(dim=1)
        # Remove the DC offset and peak-normalize into [-0.5, 0.5]
        xi = xi - xi.mean()
        max_abs = xi.abs().max()
        xi = xi / (max_abs + 1e-8)
        xi = xi/2
        # if random.random() < 0.5:
        #     xi = -xi
        # if random.random() < 0.5:
        #     xi = xi.flip(0)
        x[i_sample,:,:] = xi
    return x

def valid_collate_fn(batch):
    B = len(batch)
    x = torch.zeros((B, C, L), dtype=torch.float)
    for i_sample, sample in enumerate(batch):
        xi = torch.zeros((C, 1, 2**21), dtype=torch.int16)
        audio_mix, fs = torchaudio.load(io.BytesIO(sample['audio_mix']['bytes']), normalize=False)
        xi[:, 0, :] = audio_mix
        xi = center_crop(xi).to(torch.float)
        xi = xi.squeeze(1)
        xi = xi - xi.mean()
        max_abs = xi.abs().max()
        xi = xi / (max_abs + 1e-8)
        xi = xi / 2
        x[i_sample, :, :] = xi
    return x
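The dataloaders that consume these collate functions are not shown in this excerpt; a minimal sketch, assuming standard torch.utils.data.DataLoader usage with the config fields defined above:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=True,
    collate_fn=collate_fn,        # random crop + random stem mixing
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    collate_fn=valid_collate_fn,  # center crop of the provided mixture
)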
def make_spectrogram(X):
    # Log mel spectrogram, standardized, then mapped from roughly ±3σ into [0, 1] for display
    X = spectrogram(X).log()
    X = X - X.mean()
    X = X/X.std()
    X = X/3
    X = X.clamp(-0.5,0.5)
    X = X + 0.5
    return ToPILImage()(X)
# Convert the column-oriented slice into a list of per-sample dicts for the collate function
valid_batch = valid_dataset[200:201]
valid_batch = [dict(zip(valid_batch.keys(), values)) for values in zip(*valid_batch.values())]
x_valid = valid_collate_fn(valid_batch).to(device)
spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=44100,
    n_fft=4096,
).to(device)
SG = make_spectrogram(x_valid[0,0].to(device))
display(SG.resize((1024,256)).transpose(method=PIL.Image.FLIP_TOP_BOTTOM))
SG = make_spectrogram(x_valid[0,1].to(device))
display(SG.resize((1024,256)).transpose(method=PIL.Image.FLIP_TOP_BOTTOM))
[Output: mel spectrograms of the left and right channels of the validation excerpt]
Audio(x_valid.to("cpu").numpy()[0],rate=44100)