Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
from torch.utils.data import Dataset
import pandas as pd
import torchaudio
import torch
from sounds import sounds
from cnn import CNNNetwork
from torch.utils.data import DataLoader
from torch import nn
import warnings
def create_data_loader(train_data, batch_size):
    """Wrap ``train_data`` in a ``DataLoader`` that yields ``batch_size`` batches.

    No other ``DataLoader`` options are set, so the loader uses PyTorch's
    defaults (e.g. samples are visited in dataset order, without shuffling).

    Args:
        train_data: dataset to iterate over.
        batch_size: number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``train_data``.
    """
    loader = DataLoader(train_data, batch_size=batch_size)
    return loader
def train_single_epoch(model, data_loader, loss_fn, optimiser, device):
    """Run one pass over ``data_loader``, updating ``model`` after every batch.

    Each batch is moved to ``device``, fed through ``model``, scored with
    ``loss_fn`` and back-propagated; ``optimiser`` then applies the update.
    The loss of the final batch is printed and returned.

    Args:
        model: module being trained; it is switched to training mode first.
        data_loader: iterable yielding ``(inputs, target)`` batch pairs.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a
            single-element loss tensor (``.backward()`` / ``.item()`` are
            called on it).
        optimiser: optimiser presumably built over ``model``'s parameters.
        device: device the batches are moved to (presumably the same device
            ``model`` lives on — TODO confirm at call sites).

    Returns:
        float: loss of the last batch processed in this epoch.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Ensure dropout / batch-norm layers use training behaviour even if the
    # model was previously put in eval mode.
    model.train()

    loss = None
    for inputs, target in data_loader:
        inputs, target = inputs.to(device), target.to(device)

        # calculate loss
        prediction = model(inputs)
        loss = loss_fn(prediction, target)

        # backpropagate error and update weights
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    # Without this guard an empty loader would crash below with an opaque
    # UnboundLocalError on ``loss``.
    if loss is None:
        raise ValueError("data_loader yielded no batches; nothing to train on")

    # Convert only once, after the loop, to avoid a device sync per batch.
    last_loss = loss.item()
    print(f"loss: {last_loss}")
    return last_loss
def train(model, data_loader, loss_fn, optimiser, device, epochs):
    """Train ``model`` for ``epochs`` full passes over ``data_loader``.

    Each epoch delegates to ``train_single_epoch`` and is framed by a
    1-based "Epoch N" header and a dashed separator line; a completion
    message is printed once all epochs are done.

    Args:
        model: module being trained.
        data_loader: iterable of ``(inputs, target)`` batches.
        loss_fn: loss callable passed through to ``train_single_epoch``.
        optimiser: optimiser passed through to ``train_single_epoch``.
        device: device passed through to ``train_single_epoch``.
        epochs: number of passes over ``data_loader``.
    """
    for epoch_number in range(1, epochs + 1):
        print(f"Epoch {epoch_number}")
        train_single_epoch(model, data_loader, loss_fn, optimiser, device)
        print("---------------------------")
    print("Finished training")