"""
an adversarial autoencoder for raw audio
Notes on discriminators:
the model uses two discriminators, one for distinguishing between real and fake prior distributions
and one for distinguishing between real and fake audio.
Notes on dilated convolutions:
the model combines dilated convolution with regular convolution.
It does so my creating branched paths for each convolution step where dilation is greater than 0,
the branched paths then join again by adding the results of the regular convolution and dilated convolution.
"""
# Imports
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import OrderedDict
import simpleaudio as sa
from scipy.io.wavfile import read as audio_read
from scipy.io.wavfile import write as audio_write
import numpy as np
import glob
from matplotlib import pyplot as plt
import os, time
import json
from common import utils
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))
#device = 'cpu'
# Configurations
# Audio Configuration
audio_file_path = "../data/sonification_audio"
#audio_file_path = "../data/improvisation_audio"
audio_save_sample_rate = 32000 # sample rate for saving example sonifications
audio_sample_rate = 32000 # number of audio samples per sec
audio_window_length = 256 # size of audio window in number of samples
audio_window_rate = 250 # number of audio windows per sec
standardize_audio = True
# Training Configuration
audio_window_offset = 32
batch_size = 128
epochs = 200
train_percentage = 0.9
test_percentage = 0.1
ae_learning_rate = 1e-4
disc_audio_learning_rate = 4e-4
disc_prior_learning_rate = 4e-4
ae_rec_loss_scale = 1.0
ae_disc_audio_loss_scale = 1.0
ae_disc_prior_loss_scale = 0.1
ae_l1_loss_scale = 0.0
ae_l2_loss_scale = 0.01
disc_prior_label_smoothing = 0.9
disc_audio_label_smoothing = 0.9
model_save_interval = 50
save_history = False
# Model Configuration
# Autoencoder Configuration
latent_dim = 32
ae_activation_function = "LeakyReLU"
ae_conv_filter_counts = [8, 16, 32, 64]
ae_conv_kernel_sizes = [7, 7, 7, 7] # important: use odd kernel sizes, otherwise the (kernel_size - 1) // 2 'same' padding calculation fails
ae_conv_strides = [4, 4, 4, 4]
ae_conv_dilations = [0, 0, 0, 0]
ae_dense_layer_sizes = [32]
ae_dropout_prob = 0.0
ae_use_batch_normalization = True
ae_use_layer_normalization = False
# Prior Discriminator Configuration
disc_prior_activation_function = "LeakyReLU"
disc_prior_dense_layer_sizes = [32, 32]
disc_prior_dropout_prob = 0.0
disc_prior_use_batch_normalization = True
disc_prior_use_layer_normalization = False
# Audio Discriminator Configuration
disc_audio_activation_function = "LeakyReLU"
disc_audio_conv_filter_counts = [8, 16, 32, 64]
disc_audio_conv_kernel_sizes = [7, 7, 7, 7] # important: use odd kernel sizes, otherwise the (kernel_size - 1) // 2 'same' padding calculation fails
disc_audio_conv_strides = [4, 4, 4, 4]
disc_audio_conv_dilations = [0, 0, 0, 0]
disc_audio_dense_layer_sizes = [32]
disc_audio_dropout_prob = 0.0
disc_audio_use_batch_normalization = True
disc_audio_use_layer_normalization = False
# Save / Load Model Weights
save_models = True
save_tscript = True
save_weights = False
load_weights = False
disc_prior_model_file = "../results_pytorch/sonification/aae/models/disc_prior"
disc_audio_model_file = "../results_pytorch/sonification/aae/models/disc_audio"
ae_encoder_model_file = "../results_pytorch/sonification/aae/models/ae_encoder"
ae_decoder_model_file = "../results_pytorch/sonification/aae/models/ae_decoder"
disc_prior_weights_file = "../results_pytorch/sonification/aae/weights/disc_prior_weights_epoch_200"
disc_audio_weights_file = "../results_pytorch/sonification/aae/weights/disc_audio_weights_epoch_200"
ae_encoder_weights_file = "../results_pytorch/sonification/aae/weights/ae_encoder_weights_epoch_200"
ae_decoder_weights_file = "../results_pytorch/sonification/aae/weights/ae_decoder_weights_epoch_200"
# Save Audio Examples
save_audio = False
audio_save_interval = 100
audio_save_start_times = [ 40.0, 120.0, 300.0, 480.0 ] # in seconds
audio_save_duration = 10.0
audio_traverse_start_window = 0
audio_traverse_end_window = 1000
audio_traverse_window_count = 10
audio_traverse_interpolation_count = 100
audio_result_path = "../results_pytorch/sonification/aae/audio"
# Prepare Dataset
# load audio; audio_read returns a (sample_rate, samples) tuple
audio = audio_read(audio_file_path + ".wav")
# convert audio from int16 to normalized float
audio_sample_scale = 2**15
audio_float_array = np.array(audio[1], dtype=np.float32)
audio_float_array = audio_float_array / audio_sample_scale
# create standardized audio float array
if standardize_audio:
audio_mean = np.mean(audio_float_array)
audio_std = np.std(audio_float_array)
audio_standardized = (audio_float_array - audio_mean) / (audio_std)
else:
audio_standardized = audio_float_array
# gather audio training data
audio_training_data = []
max_audio_idx = audio_standardized.shape[0] - audio_window_length - 1
for audio_idx in range(0, max_audio_idx, audio_window_offset):
audio_training_data.append( audio_standardized[audio_idx:audio_idx+audio_window_length] )
audio_training_data = np.array(audio_training_data)
audio_count = audio_training_data.shape[0]
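# audio_training_data has shape (audio_count, audio_window_length); with the settings above,
# consecutive windows overlap by audio_window_length - audio_window_offset = 224 samples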
# create dataset
class AudioDataset(Dataset):
def __init__(self, audio_windows):
self.audio_windows = audio_windows
def __len__(self):
return self.audio_windows.shape[0]
def __getitem__(self, idx):
return self.audio_windows[idx, ...]
full_dataset = AudioDataset(audio_training_data)
dataset_size = len(full_dataset)
test_size = int(test_percentage * dataset_size)
train_size = dataset_size - test_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# Save Original Audio Examples
def create_ref_sonifications(start_time, duration, file_name):
start_sample = int(start_time * audio_save_sample_rate)
sample_count = int(duration * audio_save_sample_rate)
ref_audio = audio_standardized[start_sample:start_sample + sample_count]
if standardize_audio:
ref_audio = audio_mean + ref_audio * audio_std
audio_write(file_name, audio_save_sample_rate, ref_audio)
if save_audio:
for audio_start_time in audio_save_start_times:
file_name = "{}/ref_audio_{}.wav".format(audio_result_path, audio_start_time)
create_ref_sonifications(audio_start_time, audio_save_duration, file_name)
# Create Models
# function returning normally distributed random data
# serves as the reference prior that the discriminator learns to distinguish from the encoder's latent codes
def sample_normal(shape):
return torch.tensor(np.random.normal(size=shape), dtype=torch.float32).to(device)
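# note: torch.randn(shape, device=device) should be an equivalent (and typically faster) way to sample this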
class DiscriminatorPrior(nn.Module):
def __init__(self):
super(DiscriminatorPrior, self).__init__()
dense_layer_count = len(disc_prior_dense_layer_sizes)
layers = []
in_size = latent_dim
for layer_index in range(dense_layer_count):
out_size = disc_prior_dense_layer_sizes[layer_index]
if disc_prior_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
layers.append(nn.Linear(in_size, out_size))
activation_function = getattr(nn, disc_prior_activation_function)()
layers.append(activation_function)
if disc_prior_use_layer_normalization:
layers.append(nn.LayerNorm())
if disc_prior_dropout_prob > 0.0:
layers.append(nn.Dropout(p=disc_prior_dropout_prob))
in_size = out_size
layers.append(nn.Linear(in_size, 1))
layers.append(nn.Sigmoid())
self.model = nn.Sequential(*layers)
def forward(self, x):
yhat = self.model(x)
return yhat
disc_prior = DiscriminatorPrior().to(device)
print(disc_prior)
x = torch.rand((batch_size, latent_dim), dtype=torch.float32).to(device)
x.shape
x = disc_prior(x)
x.shape
"""
for name, param in disc_prior.named_parameters():
print(f"Layer: {name} | Size: {param.size()}")
"""
if save_models == True:
# save using pickle
torch.save(disc_prior, "{}.pth".format(disc_prior_model_file))
# save using onnx
x = torch.zeros((1, latent_dim)).to(device)
torch.onnx.export(disc_prior, x, "{}.onnx".format(disc_prior_model_file))
if save_tscript == True:
disc_prior.eval()
# save using TorchScript
#x = torch.rand((batch_size, latent_dim), dtype=torch.float32).to(device)
x = torch.rand((1, latent_dim), dtype=torch.float32).to(device)
script_module = torch.jit.trace(disc_prior, x)
script_module.save("{}.pt".format(disc_prior_model_file))
disc_prior.train()
if load_weights and disc_prior_weights_file:
disc_prior.load_state_dict(torch.load(disc_prior_weights_file))
# Dilation Block
class DilationBlock(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size, stride, dilation, activation):
super().__init__()
branch1 = []
branch1.append(nn.Conv1d(in_channels=in_filters, out_channels=out_filters, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2))
if activation != None:
activation_function = getattr(nn, activation)()
branch1.append(activation_function)
self.branch1 = nn.Sequential(*branch1)
branch2 = []
branch2.append(nn.Conv1d(in_channels=in_filters, out_channels=out_filters, kernel_size=kernel_size, dilation=dilation, padding=dilation * (kernel_size - 1) // 2))
if activation != None:
activation_function = getattr(nn, activation)()
branch2.append(activation_function)
branch2.append(nn.Conv1d(in_channels=out_filters, out_channels=out_filters, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2))
if activation != None:
activation_function = getattr(nn, activation)()
branch2.append(activation_function)
self.branch2 = nn.Sequential(*branch2)
def forward(self, x):
#print("dilation x ", x.shape)
x1 = self.branch1(x)
#print("dilation x1 ", x1.shape)
x2 = self.branch2(x)
#print("dilation x2 ", x2.shape)
x = x1 + x2
#print("dilation x3 ", x.shape)
return x
"""
# debug
in_size = 256
in_filters = 32
out_filters = 32
x = torch.rand((batch_size, in_filters, in_size), dtype=torch.float32)
print("x1 ", x.shape)
kernel_size = 7
stride = 4
dilation = 32
padding = (kernel_size - 1) // 2 * dilation
x = DilationBlock(in_filters=in_filters, out_filters=out_filters, kernel_size=kernel_size, stride=stride, dilation=dilation, activation="LeakyReLU")(x)
print("x2 ", x.shape)
"""
# Transposed Dilation Block
class TransposedDilationBlock(nn.Module):
def __init__(self, in_filters, out_filters, kernel_size, stride, dilation, activation):
super().__init__()
branch1 = []
branch1.append(nn.ConvTranspose1d(in_channels=in_filters, out_channels=out_filters, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, output_padding=(kernel_size - 1) // 2))
if activation != None:
activation_function = getattr(nn, activation)()
branch1.append(activation_function)
self.branch1 = nn.Sequential(*branch1)
branch2 = []
branch2.append(nn.ConvTranspose1d(in_channels=in_filters, out_channels=in_filters, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, output_padding=(kernel_size - 1) // 2))
if activation != None:
activation_function = getattr(nn, activation)()
branch2.append(activation_function)
branch2.append(nn.ConvTranspose1d(in_channels=in_filters, out_channels=out_filters, kernel_size=kernel_size, dilation=dilation, padding=dilation * (kernel_size - 1) // 2))
if activation != None:
activation_function = getattr(nn, activation)()
branch2.append(activation_function)
self.branch2 = nn.Sequential(*branch2)
def forward(self, x):
#print("dilation x ", x.shape)
x1 = self.branch1(x)
#print("dilation x1 ", x1.shape)
x2 = self.branch2(x)
#print("dilation x2 ", x2.shape)
x = x1 + x2
#print("dilation x3 ", x.shape)
return x
"""
# debug
in_size = 256
in_filters = 32
out_filters = 32
x = torch.rand((batch_size, in_filters, in_size), dtype=torch.float32)
print("x1 ", x.shape)
kernel_size = 7
stride = 4
dilation = 32
padding = (kernel_size - 1) // 2 * dilation
x = TransposedDilationBlock(in_filters=in_filters, out_filters=out_filters, kernel_size=kernel_size, stride=stride, dilation=dilation, activation="LeakyReLU")(x)
print("x2 ", x.shape)
"""
# Create Audio Discriminator
class DiscriminatorAudio(nn.Module):
def __init__(self):
super(DiscriminatorAudio, self).__init__()
self.conv_model = None
self.dense_model = None
activation = disc_audio_activation_function
conv_layer_count = len(disc_audio_conv_filter_counts)
dense_layer_count = len(disc_audio_dense_layer_sizes)
in_size = audio_window_length
in_channels = 1
if conv_layer_count > 0:
layers = []
for layer_index in range(conv_layer_count):
out_channels = disc_audio_conv_filter_counts[layer_index]
kernel_size = disc_audio_conv_kernel_sizes[layer_index]
stride = disc_audio_conv_strides[layer_index]
dilation = disc_audio_conv_dilations[layer_index]
if disc_audio_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_channels))
if dilation > 0:
layers.append(DilationBlock(in_channels, out_channels, kernel_size, stride, dilation, activation ))
else:
layers.append(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2))
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if disc_audio_use_layer_normalization:
layers.append(nn.LayerNorm())
if disc_audio_dropout_prob > 0.0:
layers.append(nn.Dropout(p=disc_audio_dropout_prob))
#print("conv ", layer_index, " size_i ", in_size, " size_o ", (in_size // stride), " channels_i ", in_channels, " channels_o ", out_channels )
in_channels = out_channels
in_size = in_size // stride # only correct for padding == "same"
self.conv_model = nn.Sequential(*layers)
# flattened size
in_size *= out_channels
if dense_layer_count > 0:
layers = []
for layer_index in range(dense_layer_count):
out_size = disc_audio_dense_layer_sizes[layer_index]
if disc_audio_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
layers.append(nn.Linear(in_size, out_size))
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if disc_audio_use_layer_normalization:
layers.append(nn.LayerNorm())
if disc_audio_dropout_prob > 0.0:
layers.append(nn.Dropout(p=disc_audio_dropout_prob))
#print("dense ", layer_index, " size_i ", in_size, " size_o ", out_size )
in_size = out_size
if disc_audio_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
layers.append(nn.Linear(in_size, 1))
layers.append(nn.Sigmoid())
self.dense_model = nn.Sequential(*layers)
def forward(self, x):
#print("x1 ", x.shape)
if self.conv_model != None:
x = torch.unsqueeze(x, dim=1)
#print("x2 ", x.shape)
x = self.conv_model(x)
#print("x3 ", x.shape)
x = x.view(-1, x.shape[1] * x.shape[2])
#print("x4 ", x.shape)
yhat = self.dense_model(x)
#print("x5 ", yhat.shape)
return yhat
disc_audio = DiscriminatorAudio().to(device)
print(disc_audio)
x = torch.rand((batch_size, audio_window_length), dtype=torch.float32).to(device)
x.shape
x = disc_audio(x)
x.shape
"""
for name, param in disc_audio.named_parameters():
print(f"Layer: {name} | Size: {param.size()}")
"""
if save_models == True:
# save using pickle
torch.save(disc_audio, "{}.pth".format(disc_audio_model_file))
# save using onnx
x = torch.zeros((1, audio_window_length)).to(device)
torch.onnx.export(disc_audio, x, "{}.onnx".format(disc_audio_model_file))
if save_tscript == True:
disc_audio.eval()
# save using TorchScript
x = torch.rand((1, audio_window_length), dtype=torch.float32).to(device)
script_module = torch.jit.trace(disc_audio, x)
script_module.save("{}.pt".format(disc_audio_model_file))
disc_audio.train()
if load_weights and disc_audio_weights_file:
disc_audio.load_state_dict(torch.load(disc_audio_weights_file))
# Create Audio Encoder
class AudioEncoder(nn.Module):
def __init__(self):
super(AudioEncoder, self).__init__()
self.conv_model = None
self.dense_model = None
activation = ae_activation_function
conv_layer_count = len(ae_conv_filter_counts)
dense_layer_count = len(ae_dense_layer_sizes)
in_size = audio_window_length
in_channels = 1
if conv_layer_count > 0:
layers = []
for layer_index in range(conv_layer_count):
out_channels = ae_conv_filter_counts[layer_index]
kernel_size = ae_conv_kernel_sizes[layer_index]
stride = ae_conv_strides[layer_index]
dilation = ae_conv_dilations[layer_index]
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_channels))
if dilation > 0:
layers.append(DilationBlock(in_channels, out_channels, kernel_size, stride, dilation, activation ))
else:
layers.append(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2))
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if ae_use_layer_normalization:
layers.append(nn.LayerNorm())
if ae_dropout_prob > 0.0:
layers.append(nn.Dropout(p=ae_dropout_prob))
#print("conv ", layer_index, " size_i ", in_size, " size_o ", (in_size // stride), " channels_i ", in_channels, " channels_o ", out_channels )
in_channels = out_channels
in_size = in_size // stride # only correct for padding == "same"
self.conv_model = nn.Sequential(*layers)
# flattened size
in_size *= out_channels
if dense_layer_count > 0:
layers = []
for layer_index in range(dense_layer_count):
out_size = ae_dense_layer_sizes[layer_index]
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
layers.append(nn.Linear(in_size, out_size))
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if ae_use_layer_normalization:
layers.append(nn.LayerNorm())
if ae_dropout_prob > 0.0:
layers.append(nn.Dropout(p=ae_dropout_prob))
#print("dense ", layer_index, " size_i ", in_size, " size_o ", out_size )
in_size = out_size
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
layers.append(nn.Linear(in_size, latent_dim))
self.dense_model = nn.Sequential(*layers)
def forward(self, x):
#print("AudioEncoder forward")
#print("x1 ", x.shape)
if self.conv_model != None:
x = torch.unsqueeze(x, dim=1)
#print("x2 ", x.shape)
x = self.conv_model(x)
#print("x3 ", x.shape)
x = x.view(-1, x.shape[1] * x.shape[2])
#print("x4 ", x.shape)
yhat = self.dense_model(x)
#print("x5 ", yhat.shape)
return yhat
ae_encoder = AudioEncoder().to(device)
print(ae_encoder)
x = torch.rand((batch_size, audio_window_length), dtype=torch.float32).to(device)
x.shape
x = ae_encoder(x)
x.shape
"""
for name, param in ae_encoder.named_parameters():
print(f"Layer: {name} | Size: {param.size()}")
"""
if save_models == True:
# save using pickle
torch.save(ae_encoder, "{}.pth".format(ae_encoder_model_file))
# save using onnx
x = torch.zeros((1, audio_window_length)).to(device)
torch.onnx.export(ae_encoder, x, "{}.onnx".format(ae_encoder_model_file))
if save_tscript == True:
ae_encoder.eval()
# save using TorchScript
x = torch.rand((1, audio_window_length), dtype=torch.float32).to(device)
script_module = torch.jit.trace(ae_encoder, x)
script_module.save("{}.pt".format(ae_encoder_model_file))
ae_encoder.train()
if load_weights and ae_encoder_weights_file:
ae_encoder.load_state_dict(torch.load(ae_encoder_weights_file))
# Create Audio Decoder
rev_ae_dense_layer_sizes = ae_dense_layer_sizes.copy()
rev_ae_dense_layer_sizes.reverse()
rev_ae_conv_filter_counts = ae_conv_filter_counts.copy()
rev_ae_conv_kernel_sizes = ae_conv_kernel_sizes.copy()
rev_ae_conv_strides = ae_conv_strides.copy()
rev_ae_conv_dilations = ae_conv_dilations.copy()
rev_ae_conv_filter_counts.reverse()
rev_ae_conv_kernel_sizes.reverse()
rev_ae_conv_strides.reverse()
rev_ae_conv_dilations.reverse()
rev_ae_conv_filter_counts = rev_ae_conv_filter_counts[1:]
rev_ae_conv_filter_counts.append(1)
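# probe the encoder's convolutional stack with a dummy batch to obtain the feature map shape
# before and after flattening; the decoder's dense and transposed-convolution stages mirror this shape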
x = torch.rand((batch_size, 1, audio_window_length), dtype=torch.float32).to(device)
x = ae_encoder.conv_model(x)
shape_before_flattening = list(x.shape)
shape_after_flattening = [shape_before_flattening[0], shape_before_flattening[1] * shape_before_flattening[2]]
class AudioDecoder(nn.Module):
def __init__(self):
super(AudioDecoder, self).__init__()
self.dense_model = None
self.conv_model = None
activation = ae_activation_function
dense_layer_count = len(rev_ae_dense_layer_sizes)
conv_layer_count = len(ae_conv_filter_counts)
in_size = latent_dim
if dense_layer_count > 0:
layers = []
for layer_index in range(dense_layer_count):
out_size = rev_ae_dense_layer_sizes[layer_index]
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
layers.append(nn.Linear(in_size, out_size))
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if ae_use_layer_normalization:
layers.append(nn.LayerNorm())
if ae_dropout_prob > 0.0:
layers.append(nn.Dropout(p=ae_dropout_prob))
#print("dense ", layer_index, " size_i ", in_size, " size_o ", out_size )
in_size = out_size
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_size))
out_size = shape_after_flattening[1]
layers.append(nn.Linear(in_size, out_size))
in_size = out_size
if conv_layer_count > 0:
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if ae_use_layer_normalization:
layers.append(nn.LayerNorm())
self.dense_model = nn.Sequential(*layers)
if conv_layer_count > 0:
layers = []
in_channels = shape_before_flattening[1]
in_size = shape_before_flattening[2]
for layer_index in range(conv_layer_count - 1):
out_channels = rev_ae_conv_filter_counts[layer_index]
kernel_size = rev_ae_conv_kernel_sizes[layer_index]
stride = rev_ae_conv_strides[layer_index]
dilation = rev_ae_conv_dilations[layer_index]
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_channels))
if dilation > 0:
layers.append(TransposedDilationBlock(in_channels, out_channels, kernel_size, stride, dilation, activation ))
else:
layers.append(nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, output_padding=(kernel_size - 1) // 2))
activation_function = getattr(nn, activation)()
layers.append(activation_function)
if ae_use_layer_normalization:
layers.append(nn.LayerNorm())
if ae_dropout_prob > 0.0:
layers.append(nn.Dropout(p=ae_dropout_prob))
#print("conv ", layer_index, " size_i ", in_size, " size_o ", in_size * stride, " channels_i ", in_channels, " channels_o ", out_channels, " kernel ", kernel_size, " stride ", stride, " dilation ", dilation )
in_channels = out_channels
in_size = in_size * stride # only correct for padding == "same"
out_channels = rev_ae_conv_filter_counts[-1]
kernel_size = rev_ae_conv_kernel_sizes[-1]
stride = rev_ae_conv_strides[-1]
dilation = rev_ae_conv_dilations[-1]
if ae_use_batch_normalization:
layers.append(nn.BatchNorm1d(in_channels))
if dilation > 0:
layers.append(TransposedDilationBlock(in_channels, out_channels, kernel_size, stride, dilation, None ))
else:
layers.append(nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, output_padding=(kernel_size - 1) // 2))
#print("conv ", conv_layer_count - 1, " size_i ", in_size, " size_o ", in_size * stride, " channels_i ", in_channels, " channels_o ", out_channels, " kernel ", kernel_size, " stride ", stride, " dilation ", dilation )
self.conv_model = nn.Sequential(*layers)
def forward(self, x):
#print("AudioDecoder forward")
#print("x1 ", x.shape)
x = self.dense_model(x)
#print("x2 ", x.shape)
if self.conv_model != None:
#x = x.view(shape_before_flattening)
x = x.view([x.shape[0]] + shape_before_flattening[1:])
#print("x2 ", x.shape)
x = self.conv_model(x)
#print("x3 ", x.shape)
x = x.view(-1, x.shape[1] * x.shape[2])
#print("x4 ", x.shape)
yhat = x
#print("x5 ", yhat.shape)
return yhat
ae_decoder = AudioDecoder().to(device)
print(ae_decoder)
x = torch.rand((batch_size, latent_dim), dtype=torch.float32).to(device)
x.shape
x = ae_decoder(x)
x.shape
if save_models == True:
# save using pickle
torch.save(ae_decoder, "{}.pth".format(ae_decoder_model_file))
# save using onnx
x = torch.zeros((1, latent_dim)).to(device)
torch.onnx.export(ae_decoder, x, "{}.onnx".format(ae_decoder_model_file))
if save_tscript == True:
ae_decoder.eval()
# save using TorchScript
x = torch.rand((1, latent_dim), dtype=torch.float32).to(device)
script_module = torch.jit.trace(ae_decoder, x)
script_module.save("{}.pt".format(ae_decoder_model_file))
ae_decoder.train()
if load_weights and ae_decoder_weights_file:
ae_decoder.load_state_dict(torch.load(ae_decoder_weights_file))
# Training
# loss functions
cross_entropy = nn.BCELoss()
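# nn.BCELoss expects probabilities, which is why both discriminators end in nn.Sigmoid()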
# discriminator prior loss
def disc_prior_loss(d_x, g_z, smoothing_factor = 0.9):
ones = torch.ones_like(d_x).to(device)
zeros = torch.zeros_like(g_z).to(device)
real_loss = cross_entropy(d_x, ones * smoothing_factor) # samples drawn from the true prior are real --> label 1 (smoothed)
fake_loss = cross_entropy(g_z, zeros) # latent codes produced by the encoder are fake --> label 0
total_loss = real_loss + fake_loss
return total_loss
# discriminator audio loss
def disc_audio_loss(d_x, g_z, smoothing_factor = 0.9):
ones = torch.ones_like(d_x).to(device)
zeros = torch.zeros_like(g_z).to(device)
real_loss = cross_entropy(d_x, ones * smoothing_factor) # real audio windows --> label 1 (smoothed)
fake_loss = cross_entropy(g_z, zeros) # reconstructed audio windows are fake --> label 0
total_loss = real_loss + fake_loss
return total_loss
# autoencoder audio reconstruction loss
def ae_rec_loss(y, yhat):
flat_y = torch.flatten(y)
flat_yhat = torch.flatten(yhat)
_loss = torch.mean((flat_y-flat_yhat)**2)
return _loss
# autoencoder disc prior loss
def ae_disc_prior_loss(dx_of_gx):
ones = torch.ones_like(dx_of_gx).to(device)
return cross_entropy(dx_of_gx, ones)
# autoencoder disc audio loss
def ae_disc_audio_loss(dx_of_gx):
ones = torch.ones_like(dx_of_gx).to(device)
return cross_entropy(dx_of_gx, ones)
# autoencoder l1 loss
def ae_l1_loss(z):
flat_z = torch.flatten(z)
_loss = torch.mean(torch.abs(flat_z))
return _loss
# autoencoder l2 loss
def ae_l2_loss(z):
flat_z = torch.flatten(z)
_loss = torch.mean(flat_z**2)
return _loss
# optimizers
disc_prior_optimizer = torch.optim.Adam(disc_prior.parameters(), lr=disc_prior_learning_rate)
disc_audio_optimizer = torch.optim.Adam(disc_audio.parameters(), lr=disc_audio_learning_rate)
ae_optimizer = torch.optim.Adam(list(ae_encoder.parameters()) + list(ae_decoder.parameters()), lr=ae_learning_rate)
# Train Step
def train_step(ae_encoder, ae_decoder, disc_prior, disc_audio, audio_batch, batch_size = 32):
# train disc_prior
with torch.no_grad():
fake_normal = ae_encoder(audio_batch)
real_normal = sample_normal(fake_normal.shape)
disc_prior_real_output = disc_prior(real_normal)
disc_prior_fake_output = disc_prior(fake_normal)
#print("disc_prior_real_output ",disc_prior_real_output)
#print("disc_prior_fake_output ",disc_prior_fake_output)
_disc_prior_loss = disc_prior_loss(disc_prior_real_output, disc_prior_fake_output)
# Backpropagation
disc_prior_optimizer.zero_grad()
_disc_prior_loss.backward()
disc_prior_optimizer.step()
# train disc_audio
with torch.no_grad():
fake_audio = ae_decoder(ae_encoder(audio_batch))
real_audio = audio_batch
disc_audio_real_output = disc_audio(real_audio)
disc_audio_fake_output = disc_audio(fake_audio)
#print("disc_audio_real_output ",disc_audio_real_output)
#print("disc_audio_fake_output ",disc_audio_fake_output)
_disc_audio_loss = disc_audio_loss(disc_audio_real_output, disc_audio_fake_output)
# Backpropagation
disc_audio_optimizer.zero_grad()
_disc_audio_loss.backward()
disc_audio_optimizer.step()
# train autoencoder
encoder_out = ae_encoder(audio_batch)
decoder_out = ae_decoder(encoder_out)
disc_prior_fake_output = disc_prior(encoder_out)
disc_audio_fake_output = disc_audio(decoder_out)
_ae_rec_loss = ae_rec_loss(audio_batch, decoder_out)
_ae_disc_prior_loss = ae_disc_prior_loss(disc_prior_fake_output)
_ae_disc_audio_loss = ae_disc_audio_loss(disc_audio_fake_output)
_ae_l1_loss = ae_l1_loss(encoder_out)
_ae_l2_loss = ae_l2_loss(encoder_out)
_ae_loss = _ae_rec_loss * ae_rec_loss_scale + _ae_disc_prior_loss * ae_disc_prior_loss_scale + _ae_disc_audio_loss * ae_disc_audio_loss_scale + _ae_l1_loss * ae_l1_loss_scale + _ae_l2_loss * ae_l2_loss_scale
#print("train_step _ae_loss ", _ae_loss.detach().cpu().numpy(), "_ae_rec_loss ", _ae_rec_loss.detach().cpu().numpy(), " _ae_disc_prior_loss ", _ae_disc_prior_loss.detach().cpu().numpy(), " _ae_disc_audio_loss ", _ae_disc_audio_loss.detach().cpu().numpy(), " _ae_l1_loss ", _ae_l1_loss.detach().cpu().numpy(), " _ae_l2_loss ", _ae_l2_loss.detach().cpu().numpy())
# Backpropagation
ae_optimizer.zero_grad()
_ae_loss.backward()
ae_optimizer.step()
return _ae_loss, _disc_prior_loss, _disc_audio_loss, _ae_rec_loss, _ae_disc_prior_loss, _ae_disc_audio_loss, _ae_l1_loss, _ae_l2_loss
# Test Step
def test_step(ae_encoder, ae_decoder, disc_prior, disc_audio, audio_batch, batch_size = 32):
# test disc_prior
with torch.no_grad():
fake_normal = ae_encoder(audio_batch)
real_normal = sample_normal(fake_normal.shape)
disc_prior_real_output = disc_prior(real_normal)
disc_prior_fake_output = disc_prior(fake_normal)
_disc_prior_loss = disc_prior_loss(disc_prior_real_output, disc_prior_fake_output)
# test disc_audio
with torch.no_grad():
fake_audio = ae_decoder(ae_encoder(audio_batch))
real_audio = audio_batch
disc_audio_real_output = disc_audio(real_audio)
disc_audio_fake_output = disc_audio(fake_audio)
_disc_audio_loss = disc_audio_loss(disc_audio_real_output, disc_audio_fake_output)
# test autoencoder
with torch.no_grad():
encoder_out = ae_encoder(audio_batch)
decoder_out = ae_decoder(encoder_out)
_ae_rec_loss = ae_rec_loss(audio_batch, decoder_out)
_ae_disc_prior_loss = ae_disc_prior_loss(disc_prior_fake_output)
_ae_disc_audio_loss = ae_disc_audio_loss(disc_audio_fake_output)
_ae_l1_loss = ae_l1_loss(encoder_out)
_ae_l2_loss = ae_l2_loss(encoder_out)
_ae_loss = _ae_rec_loss * ae_rec_loss_scale + _ae_disc_prior_loss * ae_disc_prior_loss_scale + _ae_disc_audio_loss * ae_disc_audio_loss_scale + _ae_l1_loss * ae_l1_loss_scale + _ae_l2_loss * ae_l2_loss_scale
return _ae_loss, _disc_prior_loss, _disc_audio_loss, _ae_rec_loss, _ae_disc_prior_loss, _ae_disc_audio_loss, _ae_l1_loss, _ae_l2_loss
# Training Loop
def train(train_dataset, test_dataset, epochs):
loss_history = {}
loss_history["ae train_loss"] = []
loss_history["disc prior train loss"] = []
loss_history["disc audio train loss"] = []
loss_history["ae rec train loss"] = []
loss_history["ae disc prior train loss"] = []
loss_history["ae disc audio train loss"] = []
loss_history["ae l1 train loss"] = []
loss_history["ae l2 train loss"] = []
loss_history["ae test loss"] = []
loss_history["disc prior test loss"] = []
loss_history["disc audio test loss"] = []
loss_history["ae rec test loss"] = []
loss_history["ae disc prior test loss"] = []
loss_history["ae disc audio test loss"] = []
loss_history["ae l1 test loss"] = []
loss_history["ae l2 test loss"] = []
for epoch in range(epochs):
start = time.time()
ae_train_loss_per_epoch = []
disc_prior_train_loss_per_epoch = []
disc_audio_train_loss_per_epoch = []
ae_rec_train_loss_per_epoch = []
ae_disc_prior_train_loss_per_epoch = []
ae_disc_audio_train_loss_per_epoch = []
ae_l1_train_loss_per_epoch = []
ae_l2_train_loss_per_epoch = []
for train_batch in train_dataset:
train_batch = train_batch.to(device)
_ae_loss, _disc_prior_loss, _disc_audio_loss, _ae_rec_loss, _ae_disc_prior_loss, _ae_disc_audio_loss, _ae_l1_loss, _ae_l2_loss = train_step(ae_encoder, ae_decoder, disc_prior, disc_audio, train_batch, batch_size = batch_size)
_ae_loss = _ae_loss.detach().cpu().numpy()
_disc_prior_loss = _disc_prior_loss.detach().cpu().numpy()
_disc_audio_loss = _disc_audio_loss.detach().cpu().numpy()
_ae_rec_loss = _ae_rec_loss.detach().cpu().numpy()
_ae_disc_prior_loss = _ae_disc_prior_loss.detach().cpu().numpy()
_ae_disc_audio_loss = _ae_disc_audio_loss.detach().cpu().numpy()
_ae_l1_loss = _ae_l1_loss.detach().cpu().numpy()
_ae_l2_loss = _ae_l2_loss.detach().cpu().numpy()
#print("batch ae_loss ", _ae_loss, " disc_prior_loss ", _disc_prior_loss, " disc_audio_loss ", _disc_audio_loss, " ae_rec_loss ", _ae_rec_loss, " ae_disc_prior_loss ", _ae_disc_prior_loss, " ae_disc_audio_loss ", _ae_disc_audio_loss, " ae_l1_loss ", _ae_l1_loss, " ae_l2_loss ", _ae_l2_loss)
ae_train_loss_per_epoch.append(_ae_loss)
disc_prior_train_loss_per_epoch.append(_disc_prior_loss)
disc_audio_train_loss_per_epoch.append(_disc_audio_loss)
ae_rec_train_loss_per_epoch.append(_ae_rec_loss)
ae_disc_prior_train_loss_per_epoch.append(_ae_disc_prior_loss)
ae_disc_audio_train_loss_per_epoch.append(_ae_disc_audio_loss)
ae_l1_train_loss_per_epoch.append(_ae_l1_loss)
ae_l2_train_loss_per_epoch.append(_ae_l2_loss)
ae_train_loss_per_epoch = np.mean(np.array(ae_train_loss_per_epoch))
disc_prior_train_loss_per_epoch = np.mean(np.array(disc_prior_train_loss_per_epoch))
disc_audio_train_loss_per_epoch = np.mean(np.array(disc_audio_train_loss_per_epoch))
ae_rec_train_loss_per_epoch = np.mean(np.array(ae_rec_train_loss_per_epoch))
ae_disc_prior_train_loss_per_epoch = np.mean(np.array(ae_disc_prior_train_loss_per_epoch))
ae_disc_audio_train_loss_per_epoch = np.mean(np.array(ae_disc_audio_train_loss_per_epoch))
ae_l1_train_loss_per_epoch = np.mean(np.array(ae_l1_train_loss_per_epoch))
ae_l2_train_loss_per_epoch = np.mean(np.array(ae_l2_train_loss_per_epoch))
"""
print("epoch ae_loss ", ae_train_loss_per_epoch, " disc_prior_loss ", disc_prior_train_loss_per_epoch, " disc_audio_loss ", disc_audio_train_loss_per_epoch,
" ae_rec_loss ", ae_rec_train_loss_per_epoch, " ae_disc_prior_loss ", ae_disc_prior_train_loss_per_epoch, " ae_disc_audio_loss ", ae_disc_audio_train_loss_per_epoch,
" ae_l1_loss ", ae_l1_train_loss_per_epoch, " ae_l2_loss ", ae_l2_train_loss_per_epoch)
"""
ae_test_loss_per_epoch = []
disc_prior_test_loss_per_epoch = []
disc_audio_test_loss_per_epoch = []
ae_rec_test_loss_per_epoch = []
ae_disc_prior_test_loss_per_epoch = []
ae_disc_audio_test_loss_per_epoch = []
ae_l1_test_loss_per_epoch = []
ae_l2_test_loss_per_epoch = []
for test_batch in test_dataset:
test_batch = test_batch.to(device)
_ae_loss, _disc_prior_loss, _disc_audio_loss, _ae_rec_loss, _ae_disc_prior_loss, _ae_disc_audio_loss, _ae_l1_loss, _ae_l2_loss = test_step(ae_encoder, ae_decoder, disc_prior, disc_audio, test_batch, batch_size = batch_size)
_ae_loss = _ae_loss.detach().cpu().numpy()
_disc_prior_loss = _disc_prior_loss.detach().cpu().numpy()
_disc_audio_loss = _disc_audio_loss.detach().cpu().numpy()
_ae_rec_loss = _ae_rec_loss.detach().cpu().numpy()
_ae_disc_prior_loss = _ae_disc_prior_loss.detach().cpu().numpy()
_ae_disc_audio_loss = _ae_disc_audio_loss.detach().cpu().numpy()
_ae_l1_loss = _ae_l1_loss.detach().cpu().numpy()
_ae_l2_loss = _ae_l2_loss.detach().cpu().numpy()
ae_test_loss_per_epoch.append(_ae_loss)
disc_prior_test_loss_per_epoch.append(_disc_prior_loss)
disc_audio_test_loss_per_epoch.append(_disc_audio_loss)
ae_rec_test_loss_per_epoch.append(_ae_rec_loss)
ae_disc_prior_test_loss_per_epoch.append(_ae_disc_prior_loss)
ae_disc_audio_test_loss_per_epoch.append(_ae_disc_audio_loss)
ae_l1_test_loss_per_epoch.append(_ae_l1_loss)
ae_l2_test_loss_per_epoch.append(_ae_l2_loss)
ae_test_loss_per_epoch = np.mean(np.array(ae_test_loss_per_epoch))
disc_prior_test_loss_per_epoch = np.mean(np.array(disc_prior_test_loss_per_epoch))
disc_audio_test_loss_per_epoch = np.mean(np.array(disc_audio_test_loss_per_epoch))
ae_rec_test_loss_per_epoch = np.mean(np.array(ae_rec_test_loss_per_epoch))
ae_disc_prior_test_loss_per_epoch = np.mean(np.array(ae_disc_prior_test_loss_per_epoch))
ae_disc_audio_test_loss_per_epoch = np.mean(np.array(ae_disc_audio_test_loss_per_epoch))
ae_l1_test_loss_per_epoch = np.mean(np.array(ae_l1_test_loss_per_epoch))
ae_l2_test_loss_per_epoch = np.mean(np.array(ae_l2_test_loss_per_epoch))
if epoch % model_save_interval == 0 and save_weights == True:
torch.save(disc_prior.state_dict(), "aae/weights/disc_prior_weights_epoch_{}".format(epoch))
torch.save(disc_audio.state_dict(), "aae/weights/disc_audio_weights_epoch_{}".format(epoch))
torch.save(ae_encoder.state_dict(), "aae/weights/ae_encoder_weights_epoch_{}".format(epoch))
torch.save(ae_decoder.state_dict(), "aae/weights/ae_decoder_weights_epoch_{}".format(epoch))
if epoch % audio_save_interval == 0 and save_audio == True:
create_epoch_sonifications(epoch)
loss_history["ae train_loss"].append(ae_train_loss_per_epoch)
loss_history["disc prior train loss"].append(disc_prior_train_loss_per_epoch)
loss_history["disc audio train loss"].append(disc_audio_train_loss_per_epoch)
loss_history["ae rec train loss"].append(ae_rec_train_loss_per_epoch)
loss_history["ae disc prior train loss"].append(ae_disc_prior_train_loss_per_epoch)
loss_history["ae disc audio train loss"].append(ae_disc_audio_train_loss_per_epoch)
loss_history["ae l1 train loss"].append(ae_l1_train_loss_per_epoch)
loss_history["ae l2 train loss"].append(ae_l2_train_loss_per_epoch)
loss_history["ae test loss"].append(ae_test_loss_per_epoch)
loss_history["disc prior test loss"].append(disc_prior_test_loss_per_epoch)
loss_history["disc audio test loss"].append(disc_audio_test_loss_per_epoch)
loss_history["ae rec test loss"].append(ae_rec_test_loss_per_epoch)
loss_history["ae disc prior test loss"].append(ae_disc_prior_test_loss_per_epoch)
loss_history["ae disc audio test loss"].append(ae_disc_audio_test_loss_per_epoch)
loss_history["ae l1 test loss"].append(ae_l1_test_loss_per_epoch)
loss_history["ae l2 test loss"].append(ae_l2_test_loss_per_epoch)
print ('epoch {} : ae train {:01.4f} dprior {:01.4f} daudio {:01.4f} dprior2 {:01.4f} daudio2 {:01.4f} rec {:01.4f} l1 {:01.4f} l2 {:01.4f} test {:01.4f} dprior {:01.4f} daudio {:01.4f} rec {:01.4f} l1 {:01.4f} l2 {:01.4f} time {:01.2f}'.format(epoch + 1, ae_train_loss_per_epoch, ae_disc_prior_train_loss_per_epoch, ae_disc_audio_train_loss_per_epoch, disc_prior_train_loss_per_epoch, disc_audio_train_loss_per_epoch,
ae_rec_train_loss_per_epoch, ae_l1_train_loss_per_epoch, ae_l2_train_loss_per_epoch, ae_test_loss_per_epoch, ae_disc_prior_test_loss_per_epoch, ae_disc_audio_test_loss_per_epoch,
ae_rec_test_loss_per_epoch, ae_l1_test_loss_per_epoch, ae_l2_test_loss_per_epoch, time.time()-start))
return loss_history
# Save Reconstructed Audio Examples
def create_pred_sonification(start_time, duration, file_name):
grain_env = np.hanning(audio_window_length)
grain_offset = audio_sample_rate // audio_window_rate
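# with audio_sample_rate = 32000 and audio_window_rate = 250, grain_offset is 128 samples, so consecutive
# Hann-windowed grains of length audio_window_length = 256 overlap by 50% and are summed (overlap-add resynthesis)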
predict_start_window = int(start_time * audio_save_sample_rate // audio_window_length)
predict_window_count = int(duration * audio_save_sample_rate // grain_offset)
pred_audio_sequence_length = (predict_window_count - 1) * grain_offset + audio_window_length
pred_audio_sequence = np.zeros(shape=(pred_audio_sequence_length), dtype=np.float32)
for i in range(predict_window_count):
target_audio = audio_standardized[predict_start_window * audio_window_length + i*grain_offset:predict_start_window * audio_window_length + i*grain_offset + audio_window_length]
target_audio = torch.from_numpy(target_audio).to(device)
target_audio = torch.unsqueeze(target_audio, 0)
with torch.no_grad():
enc_audio = ae_encoder(target_audio)
pred_audio = ae_decoder(enc_audio)
pred_audio = np.squeeze(pred_audio.detach().cpu().numpy())
pred_audio = pred_audio * grain_env
pred_audio_sequence[i*grain_offset:i*grain_offset + audio_window_length] = pred_audio_sequence[i*grain_offset:i*grain_offset + audio_window_length] + pred_audio
if standardize_audio:
pred_audio_sequence = audio_mean + pred_audio_sequence * audio_std
# convert audio float array to an int16 array with samples between -2**15 and 2**15
pred_audio_int16_sequence = np.clip(pred_audio_sequence, -1.0, 1.0)
pred_audio_int16_sequence = pred_audio_int16_sequence * audio_sample_scale
pred_audio_int16_sequence = pred_audio_int16_sequence.astype(np.int16)
audio_write(file_name, audio_save_sample_rate, pred_audio_int16_sequence)
# Save Interpolated Audio Examples
def create_interpol_sonification(start_time1, start_time2, duration, file_name):
grain_env = np.hanning(audio_window_length)
grain_offset = audio_sample_rate // audio_window_rate
start_alpha = -1.0
end_alpha = 2.0
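# alpha sweeps from -1.0 to 2.0, so the interpolation extrapolates beyond both source windows
# before and after passing through them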
audio_window_count = int(duration * audio_window_rate)
pred_audio_sequence_length = int((audio_window_count - 1) * grain_offset + audio_window_length)
pred_audio_sequence = np.zeros(shape=(pred_audio_sequence_length), dtype=np.float32)
for i in range(audio_window_count):
target1_audio_start = int(start_time1 * audio_save_sample_rate + i * grain_offset)
target1_audio_end = target1_audio_start + audio_window_length
target2_audio_start = int(start_time2 * audio_save_sample_rate + i * grain_offset)
target2_audio_end = target2_audio_start + audio_window_length
seq1_audio_window = audio_standardized[target1_audio_start:target1_audio_end]
seq2_audio_window = audio_standardized[target2_audio_start:target2_audio_end]
seq1_audio_window = np.expand_dims(seq1_audio_window, axis=0)
seq2_audio_window = np.expand_dims(seq2_audio_window, axis=0)
seq1_audio_window = torch.from_numpy(seq1_audio_window).to(device)
seq2_audio_window = torch.from_numpy(seq2_audio_window).to(device)
with torch.no_grad():
seq1_latent_vector = ae_encoder(seq1_audio_window)
seq2_latent_vector = ae_encoder(seq2_audio_window)
alpha = start_alpha + (end_alpha - start_alpha) * i / (audio_window_count - 1)
mix_latent_vector = seq1_latent_vector * (1.0 - alpha) + seq2_latent_vector * alpha
with torch.no_grad():
pred_audio = ae_decoder(mix_latent_vector)
pred_audio = (torch.squeeze(pred_audio)).detach().cpu().numpy()
pred_audio *= grain_env
pred_audio_start = int(i * grain_offset)
pred_audio_end = pred_audio_start + audio_window_length
pred_audio_sequence[pred_audio_start:pred_audio_end] += pred_audio
if standardize_audio:
pred_audio_sequence = audio_mean + pred_audio_sequence * audio_std
# convert audio float array to an int16 array with samples between -2**15 and 2**15
pred_audio_int16_sequence = np.clip(pred_audio_sequence, -1.0, 1.0)
pred_audio_int16_sequence = pred_audio_int16_sequence * audio_sample_scale
pred_audio_int16_sequence = pred_audio_int16_sequence.astype(np.int16)
audio_write(file_name, audio_save_sample_rate, pred_audio_int16_sequence)
# Save Latent Space Traversal Audio Examples
def create_traverse_sonification(start_audio_window_index, end_audio_window_index, window_count, window_interpolation_count, file_name):
# interpolate audio latent space
grain_env = np.hanning(audio_window_length)
grain_offset = audio_sample_rate // audio_window_rate
target_windows = list(np.linspace(start_audio_window_index, end_audio_window_index, window_count, dtype=np.int32))
interpolation_window_count = window_interpolation_count
target_audio = []
audio_window_count = len(target_windows)
pred_audio_sequence_length = int((window_interpolation_count * audio_window_count - 1) * grain_offset + audio_window_length)
pred_audio_sequence = np.zeros(shape=(pred_audio_sequence_length), dtype=np.float32)
for wI in range(audio_window_count - 1):
target_start_window = target_windows[wI]
target_end_window = target_windows[wI + 1]
target_start_audio = np.expand_dims(audio_standardized[target_start_window*audio_window_length:(target_start_window + 1) * audio_window_length], axis=0)
target_end_audio = np.expand_dims(audio_standardized[target_end_window*audio_window_length:(target_end_window + 1) * audio_window_length], axis=0)
target_start_audio = torch.from_numpy(target_start_audio).to(device)
target_end_audio = torch.from_numpy(target_end_audio).to(device)
with torch.no_grad():
start_enc = ae_encoder(target_start_audio)
end_enc = ae_encoder(target_end_audio)
for i in range(interpolation_window_count):
inter_enc = start_enc + (end_enc - start_enc) * i / (interpolation_window_count - 1.0)
with torch.no_grad():
pred_audio = ae_decoder(inter_enc)
pred_audio = (torch.squeeze(pred_audio)).detach().cpu().numpy()
pred_audio *= grain_env
i2 = (wI * interpolation_window_count + i)
pred_audio_start = int(i2 * grain_offset)
pred_audio_end = pred_audio_start + audio_window_length
pred_audio_sequence[pred_audio_start:pred_audio_end] += pred_audio
#print("predict window ", wI, " step ", i)
if standardize_audio:
pred_audio_sequence = audio_mean + pred_audio_sequence * audio_std
# convert audio float array to an int16 array with samples between -2**15 and 2**15
pred_audio_int16_sequence = np.clip(pred_audio_sequence, -1.0, 1.0)
pred_audio_int16_sequence = pred_audio_int16_sequence * audio_sample_scale
pred_audio_int16_sequence = pred_audio_int16_sequence.astype(np.int16)
#play_obj = sa.play_buffer(pred_audio_int16_sequence, 1, 2, audio_save_sample_rate)
#play_obj.stop()
audio_write(file_name, audio_save_sample_rate, pred_audio_int16_sequence)
# Create Audio Examples for Each Epoch
def create_epoch_sonifications(epoch):
# put models into evaluation mode (disables dropout; batchnorm uses running statistics)
disc_prior.eval()
disc_audio.eval()
ae_encoder.eval()
ae_decoder.eval()
# pred sonifications
for audio_start_time in audio_save_start_times:
file_name = "{}/pred_audio_{}_epoch_{}.wav".format(audio_result_path, audio_start_time, epoch)
create_pred_sonification(audio_start_time, audio_save_duration, file_name)
# two audio region interpolation
for time_index in range(len(audio_save_start_times) - 1):
audio_start_time1 = audio_save_start_times[time_index]
audio_start_time2 = audio_save_start_times[time_index + 1]
file_name = "{}/interpol_audio_{}_{}_epoch_{}.wav".format(audio_result_path, audio_start_time1, audio_start_time2, epoch)
create_interpol_sonification(audio_start_time1, audio_start_time2, audio_save_duration, file_name)
# audio window sequence traversal
file_name = "{}/traverse_audio_{}_{}_{}_{}_epoch_{}.wav".format(audio_result_path, audio_traverse_start_window, audio_traverse_end_window, audio_traverse_window_count, audio_traverse_interpolation_count, epoch)
create_traverse_sonification(audio_traverse_start_window, audio_traverse_end_window, audio_traverse_window_count, audio_traverse_interpolation_count, file_name)
# put models back into training mode
disc_prior.train()
disc_audio.train()
ae_encoder.train()
ae_decoder.train()
# Train Model
loss_history = train(train_dataloader, test_dataloader, epochs)
# save history
if save_history:
utils.save_loss_as_csv(loss_history, "../results_pytorch/sonification/aae/history.csv")
utils.save_loss_as_image(loss_history, "../results_pytorch/sonification/aae/history.png")
# final weights save
if save_weights:
torch.save(disc_prior.state_dict(), "../results_pytorch/sonification/aae/disc_prior_weights_epoch_{}".format(epochs))
torch.save(disc_audio.state_dict(), "../results_pytorch/sonification/aae/disc_audio_weights_epoch_{}".format(epochs))
torch.save(ae_encoder.state_dict(), "../results_pytorch/sonification/aae/ae_encoder_weights_epoch_{}".format(epochs))
torch.save(ae_decoder.state_dict(), "../results_pytorch/sonification/aae/ae_decoder_weights_epoch_{}".format(epochs))
if save_audio:
create_epoch_sonifications(epochs)