| | |
| | |
| | |
| | import torch |
| |
|
| | from typing import Tuple, Dict, List |
| | import utils |
| |
|
| | |
| | |
| | |
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:

    '''
    Runs a single training epoch: one optimization pass over every batch.

    Args:
        model (torch.nn.Module): PyTorch model that will be trained
        dataloader (torch.utils.data.DataLoader): Dataloader containing data to train on
        loss_fn (torch.nn.Module): Loss function used as the error metric
        optimizer (torch.optim.Optimizer): Optimization method used to update model parameters per batch
        device (torch.device): Device to train on

    Returns:
        train_loss (float): The average per-sample loss over the training set.
        train_acc (float): The fraction of correctly classified training samples.
    '''

    model.train()

    # Accumulate on-device to avoid a host sync on every batch.
    running_loss = torch.zeros((), device=device)
    running_correct = torch.zeros((), device=device)
    total = len(dataloader.dataset)

    for batch, targets in dataloader:
        batch, targets = batch.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(batch)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the final division yields a per-sample mean
        # even when the last batch is smaller; detach to drop the graph.
        running_loss += loss.detach() * batch.shape[0]
        running_correct += (logits.argmax(dim=1) == targets).sum()

    return running_loss.item() / total, running_correct.item() / total
| |
|
| |
|
| | def test_step(model: torch.nn.Module, |
| | dataloader: torch.utils.data.DataLoader, |
| | loss_fn: torch.nn.Module, |
| | device: torch.device) -> Tuple[float, float]: |
| | |
| | ''' |
| | Performs a testing step for a PyTorch model. |
| | |
| | Args: |
| | model (torch.nn.Module): PyTorch model that will be tested. |
| | dataloader (torch.utils.data.DataLoader): Dataloader containing data to test on. |
| | loss_fn (torch.nn.Module): Loss function used as the error metric. |
| | device (torch.device): Device to compute on. |
| | |
| | Returns: |
| | test_loss (float): The average loss calculated over batches. |
| | test_acc (float): The average accuracy calculated over batches. |
| | ''' |
| | |
| | model.eval() |
| | test_loss = torch.tensor(0.0, device = device) |
| | test_acc = torch.tensor(0.0, device = device) |
| | num_samps = len(dataloader.dataset) |
| |
|
| | with torch.inference_mode(): |
| | |
| | for X, y in dataloader: |
| | X, y = X.to(device), y.to(device) |
| |
|
| | y_logits = model(X) |
| |
|
| | test_loss += loss_fn(y_logits, y) * X.shape[0] |
| |
|
| | y_pred = y_logits.argmax(dim = 1) |
| |
|
| | test_acc += (y_pred == y).sum() |
| |
|
| | |
| | test_loss = test_loss.item() / num_samps |
| | test_acc = test_acc.item() / num_samps |
| | |
| | return test_loss, test_acc |
| |
|
| |
|
def train(model: torch.nn.Module,
          train_dl: torch.utils.data.DataLoader,
          test_dl: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          num_epochs: int,
          patience: int,
          min_delta: float,
          device: torch.device,
          save_mod: bool = True,
          save_dir: str = '',
          mod_name: str = '') -> Dict[str, List[float]]:
    '''
    Performs the training and testing steps for a PyTorch model,
    with early stopping applied for test loss.

    Args:
        model (torch.nn.Module): PyTorch model to train.
        train_dl (torch.utils.data.DataLoader): DataLoader for training.
        test_dl (torch.utils.data.DataLoader): DataLoader for testing.
        loss_fn (torch.nn.Module): Loss function used as the error metric.
        optimizer (torch.optim.Optimizer): Optimizer used to update model parameters per batch.

        num_epochs (int): Max number of epochs to train.
        patience (int): Number of epochs to wait before early stopping.
        min_delta (float): Minimum decrease in loss to reset counter.

        device (torch.device): Device to train on.
        save_mod (bool, optional): If True, saves the model after each improvement epoch. Default is True.
        save_dir (str, optional): Directory to save the model to. Must be nonempty if save_mod is True.
        mod_name (str, optional): Filename for the saved model. Must be nonempty if save_mod is True.

    Returns:
        res (dict): A results dictionary containing lists of train and test metrics for each epoch.

    Raises:
        ValueError: If save_mod is True but save_dir or mod_name is empty.
    '''

    bold_start, bold_end = '\033[1m', '\033[0m'

    # Validate eagerly with real exceptions: `assert` is stripped under
    # `python -O`, which would silently disable these guards.
    if save_mod:
        if not save_dir:
            raise ValueError('save_dir cannot be None or empty.')
        if not mod_name:
            raise ValueError('mod_name cannot be None or empty.')

    res = {'train_loss': [],
           'train_acc': [],
           'test_loss': [],
           'test_acc': []
           }

    # best_loss is None until the first epoch completes; counter tracks
    # consecutive epochs without an adequate improvement.
    best_loss, counter = None, 0

    for epoch in range(num_epochs):

        train_loss, train_acc = train_step(model, train_dl, loss_fn, optimizer, device)
        test_loss, test_acc = test_step(model, test_dl, loss_fn, device)

        res['train_loss'].append(train_loss)
        res['train_acc'].append(train_acc)
        res['test_loss'].append(test_loss)
        res['test_acc'].append(test_acc)

        print(f'Epoch: {epoch + 1} | ' +
              f'train_loss = {train_loss:.4f} | train_acc = {train_acc:.4f} | ' +
              f'test_loss = {test_loss:.4f} | test_acc = {test_acc:.4f}')

        if best_loss is None:
            # First epoch: establish the baseline and save the initial checkpoint.
            best_loss = test_loss
            if save_mod:
                utils.save_model(model, save_dir, mod_name)

        elif test_loss < best_loss - min_delta:
            # Improvement beyond min_delta: reset patience and checkpoint.
            best_loss = test_loss
            counter = 0

            if save_mod:
                utils.save_model(model, save_dir, mod_name)
                print(f'{bold_start}[SAVED]{bold_end} Adequate improvement in test loss; model saved.')

        else:
            counter += 1
            if counter > patience:
                print(f'{bold_start}[ALERT]{bold_end} No improvement in test loss after {counter} epochs; early stopping triggered.')
                break

    return res