import glob
import math
import os
import sys
import time

import librosa
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as tvt
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion
from diffusers import Mel
from PIL import Image
from torch.utils.data import DataLoader, Dataset
|
|
# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Command-line arguments: [image_size] [num_epochs] [batch_size] (all optional).
args = sys.argv[1:]
|
|
class Audio(Dataset):
    """Dataset of raw waveforms loaded from a folder tree of ``.wav`` files.

    Recursively collects every ``.wav`` under ``folder``. The class label is
    the integer value of the first character of the file name (AudioMNIST
    files are named ``<digit>_<speaker>_<take>.wav``).
    """

    def __init__(self, folder):
        self.waveforms = []  # one (channels, samples) float tensor per file
        self.labels = []     # spoken digit (0-9) per waveform
        print("Loading files...")
        # os.path.join / os.path.basename instead of hand-built '/' paths so
        # the dataset also loads correctly on Windows.
        for file in glob.iglob(os.path.join(folder, '**', '*.wav'), recursive=True):
            self.labels.append(int(os.path.basename(file)[0]))
            waveform, _ = torchaudio.load(file)
            self.waveforms.append(waveform)

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        return self.waveforms[index], self.labels[index]
|
|
|
|
# Square mel-spectrogram resolution; overridable via the first CLI argument.
image_size = 256
if len(args) >= 1:
    image_size = int(args[0])


# Converts raw audio to mel-spectrogram images and back (used in collate and
# when turning generated samples back into audio).
MEL = Mel(x_res=image_size, y_res=image_size)
img_to_tensor = tvt.PILToTensor()
|
|
def collate(batch):
    """Collate (waveform, label) pairs into a batch of spectrogram tensors.

    Each waveform is split into one or more mel-spectrogram image slices;
    every slice becomes its own training example carrying the waveform's
    label, so the returned batch can be larger than ``len(batch)``.

    Args:
        batch: iterable of (waveform, label) pairs as yielded by ``Audio``.

    Returns:
        Tuple of a float tensor of shape (N, 1, H, W) scaled to [0, 1] and an
        int64 tensor of N labels, both moved to ``device``.
    """
    spectros = []
    labels = []
    for waveform, label in batch:
        # Channel 0 only — assumes mono recordings (true for AudioMNIST).
        MEL.load_audio(raw_audio=waveform[0])
        # `idx`, not `slice`, to avoid shadowing the builtin.
        for idx in range(MEL.get_number_of_slices()):
            spectro = MEL.audio_slice_to_image(idx)
            spectro = img_to_tensor(spectro) / 255.0
            spectros.append(spectro)
            labels.append(label)

    spectros = torch.stack(spectros)
    labels = torch.tensor(labels)
    return spectros.to(device), labels.to(device)
|
|
|
|
def initialize(scheduler=None, batch_size=32):
    """Build the class-conditional diffusion model, optimizer and scheduler.

    Args:
        scheduler: any truthy value attaches a CyclicLR scheduler to the
            optimizer; otherwise the third return value is ``scheduler``
            unchanged (``None`` by default).
        batch_size: unused; kept for backward compatibility with callers.

    Returns:
        (diffusion, optim, scheduler) with ``diffusion`` moved to ``device``.
    """
    model = Unet(
        dim=64,
        num_classes=10,       # digits 0-9
        dim_mults=(1, 2, 4, 8),
        channels=1,           # single-channel (grayscale) spectrograms
    )
    diffusion = GaussianDiffusion(
        model,
        image_size=image_size,
        timesteps=1000,
        loss_type='l2',
        objective='pred_x0',
    )
    diffusion.to(device)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    if scheduler:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optim, base_lr=1e-5, max_lr=1e-3, mode="exp_range", cycle_momentum=False
        )
    return diffusion, optim, scheduler
|
|
def timeSince(since):
    """Return the elapsed time since *since* (a time.time() stamp) as 'Xm Ys'."""
    elapsed = time.time() - since
    minutes, seconds = divmod(elapsed, 60)
    return '%dm %ds' % (minutes, seconds)
|
|
| start = time.time() |
|
|
def train(model, optim, train_dl, batch_size=32, epochs=5, scheduler = None):
    """Train the diffusion model; dump sample audio/images after each epoch.

    Args:
        model: GaussianDiffusion instance (forward pass returns the loss).
        optim: optimizer over the model's parameters.
        train_dl: DataLoader yielding (spectrograms, labels) batches.
        batch_size: used only for progress reporting (actual batch size is
            determined by ``train_dl``).
        epochs: number of passes over ``train_dl``.
        scheduler: optional LR scheduler, stepped once per batch.

    Returns:
        List of accumulated losses, one entry per ~100-step reporting window.
    """
    size = len(train_dl.dataset)
    model.train()
    losses = []

    for e in range(epochs):
        batch_loss, batch_counts = 0, 0
        for step, batch in enumerate(train_dl):
            model.zero_grad()
            batch_counts += 1
            spectros, labels = batch
            loss = model(spectros, classes=labels)

            batch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()
            if scheduler is not None:
                scheduler.step()

            if (step % 100 == 0 and step != 0) or (step == len(train_dl) - 1):
                to_print = f"{e + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {timeSince(start)} | {step*batch_size:>5d}/{size:>5d}"
                print(to_print)
                losses.append(batch_loss)
                batch_loss, batch_counts = 0, 0

        # Sample conditioned on random digit classes at the end of each epoch.
        # torch.randint's upper bound is exclusive, so it must be 10 to cover
        # classes 0-9 (the previous bound of 9 could never sample class 9).
        labels = torch.randint(0, 10, (8, )).to(device)
        print(labels)
        samples = model.sample(labels)
        for i, sample in enumerate(samples):
            im = Image.fromarray(sample[0].cpu().numpy() * 255).convert('L')
            audio = torch.tensor([MEL.image_to_audio(im)])
            # .item() so file names contain the digit, not "tensor(3)".
            label = labels[i].item()
            torchaudio.save(f"audio/sample{e}_{i}_{label}.wav", audio, 48000)
            im.save(f"images/sample{e}_{i}_{label}.jpg")
    return losses
|
|
if __name__ == "__main__":
    # Remaining CLI arguments: [num_epochs] [batch_size] (image_size is
    # consumed at module level).
    num_epochs = int(args[1]) if len(args) >= 2 else 10
    batch_size = int(args[2]) if len(args) >= 3 else 32

    print(image_size, num_epochs, batch_size)

    diffusion, optim, scheduler = initialize(scheduler=True, batch_size=batch_size)

    train_data = Audio("AudioMNIST/data")
    print("Done Loading")
    train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate)

    train(diffusion, optim, train_dl, batch_size, num_epochs, scheduler)
    torch.save(diffusion.state_dict(), "diffusion_condition_model.pt")