torchsight.trainers.trainer module

Abstract trainer module.

Quick start

A trainer has getters and methods:
- The getters are used to get the different modules needed for the training, like datasets, dataloaders, model, criterion and optimizer. Optionally you can provide a logger and a learning rate scheduler.
- The methods are used to optimize and evaluate the model.
  - The train method runs the classic training loop.
  - The validate method validates the model over the validation dataset.
  - The eval method puts the model into evaluation mode.
  - The forward method does a forward pass over the model and returns the loss tensor. You must implement this method.
  - The backward method does the backward propagation of the loss.
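The order in which `train` calls these methods, together with the `batch_callback` and `epoch_callback` hooks defined in the module source, can be traced with a few lines of plain Python. `trace_train` is a hypothetical helper that only lists hook names; it mirrors the loop in `Trainer.train` but does not run a model.

```python
def trace_train(epochs, n_batches, validate=True, save_only_best=True, has_scheduler=True):
    """Return the hook names a train()-style loop calls, in order (illustrative sketch)."""
    if not validate and save_only_best:
        # Trainer.train() raises here: the best model can only be found by validating
        raise ValueError('You cannot disable validation and save only the best model.')
    calls = []
    for _ in range(epochs):
        calls.append('model.train')
        for _ in range(n_batches):
            # One optimization step, then one log line and the batch callback
            calls += ['optimizer.zero_grad', 'forward', 'backward', 'optimizer.step',
                      'logger.log', 'batch_callback']
        calls.append('epoch_callback')
        if not save_only_best:
            calls.append('save')  # Every epoch gets its own checkpoint file
        if validate:
            # validate() calls eval() and then forward() for each validation batch
            calls += ['validate', 'save']
            if has_scheduler:
                calls.append('scheduler.step')  # Receives the validation loss
    return calls


print(trace_train(epochs=1, n_batches=1))
```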

To use a trainer you must implement the getter methods and the forward method.

A good practice is to use the hyperparameters dict to store the parameters for the getters, so anyone can change the hyperparameters without changing the code.
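A minimal sketch of that contract, written without PyTorch so it runs anywhere. `MiniTrainer` is a hypothetical stand-in for `Trainer` (it keeps only two getters, `forward` and a loss-collecting `train`, with no optimizer or backward pass), and `LineTrainer` is a hypothetical subclass whose getters read all their settings from `hyperparameters`.

```python
class MiniTrainer:
    """Hypothetical, torch-free stand-in that mirrors the structure of Trainer."""

    hyperparameters = {}

    def __init__(self, hyperparameters=None):
        # Shallow override for brevity; Trainer deep-merges with merge_dicts instead
        self.hyperparameters = {**self.hyperparameters, **(hyperparameters or {})}
        # Like Trainer.__init__, the getters are called while building the instance
        self.dataset = self.get_datasets()
        self.model = self.get_model()

    def get_datasets(self):
        raise NotImplementedError('You must provide your own datasets.')

    def get_model(self):
        raise NotImplementedError('You must provide the model to train.')

    def forward(self, *args):
        raise NotImplementedError('You must implement the forward pass over the model.')

    def train(self, epochs=1):
        # No optimizer here: the sketch only collects the loss of each batch
        return [self.forward(*data) for _ in range(epochs) for data in self.dataset]


class LineTrainer(MiniTrainer):
    """Models y = weight * x; every setting comes from the hyperparameters dict."""

    hyperparameters = {'datasets': {'points': [(1.0, 2.0), (2.0, 4.0)]},
                       'model': {'weight': 2.0}}

    def get_datasets(self):
        return self.hyperparameters['datasets']['points']

    def get_model(self):
        weight = self.hyperparameters['model']['weight']
        return lambda x: weight * x

    def forward(self, x, y):
        # The squared error of one (x, y) pair plays the role of the loss tensor
        return (self.model(x) - y) ** 2


# Changing a hyperparameter needs no change to LineTrainer's code
print(LineTrainer().train())
print(LineTrainer({'model': {'weight': 1.0}}).train())
```

In a real subclass, the getters would return the torch objects listed in their docstrings (datasets, dataloaders, an `nn.Module`, a criterion and an optimizer) and `forward` would return a loss tensor.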

Source code
"""Abstract trainer module.

## Quick start

A trainer has getters and methods:
- The getters are used to get the different modules needed for the training, like *datasets*, *dataloaders*,
*model*, *criterion* and *optimizer*. Optionally you can provide a *logger* and a *learning rate scheduler*.
- The methods are used to optimize and evaluate the model.
  - The **train** method does the classic training algorithm.
  - The **validate** method does the validation of the model over the validation dataset.
  - The **eval** method puts the model into evaluation mode.
  - The **forward** method does a forward pass over the model and returns the loss tensor. **You must implement
    this method**.
  - The **backward** method does the backward propagation of the loss.

To use a trainer you must implement the getter methods and the *forward* method.

A good practice is to use the hyperparameters dict to store the parameters for the getters, so anyone can change
the hyperparameters without changing the code.
"""
import json
import os
import time

import torch

from torchsight.loggers import PrintLogger
from torchsight.utils import merge_dicts

LOGS_DIR = './logs'


class Trainer():
    """Base Trainer class, all the trainers must extend this class."""
    # A dict with all the hyperparameters for the different components of the training
    hyperparameters = {}

    def __init__(self, hyperparameters=None, checkpoint=None, device=None, save_only_best=True):
        """Initialize the trainer.

        Arguments:
            hyperparameters (dict, optional): A dict to change the base hyperparameters.
                If it's present, it will be deeply merged with the base hyperparameters.
            checkpoint (str, optional): A path to a checkpoint (generated by the same trainer)
                to resume the training.
            device (str, optional): The device where to run the training.
            save_only_best (bool, optional): If True, it will save only the best checkpoint, not all
                the checkpoints for each epoch.
        """
        base_hyperparameters = {'checkpoint': {'dir': LOGS_DIR, 'verbose': True},
                                'logger': {'dir': LOGS_DIR}}
        # Add the base hyperparameters to the trainer hyperparameters
        self.hyperparameters = merge_dicts(self.hyperparameters, base_hyperparameters)
        # Add the modified hyperparameters given in the initialization
        self.hyperparameters = merge_dicts(self.hyperparameters, hyperparameters, verbose=True)
        # Set the device of the trainer
        self.device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Flag to save only the best model
        self.save_only_best = save_only_best
        self.best_loss = 1e10

        # Get the modules for the training
        print('Loading datasets ...')
        self.dataset, self.valid_dataset = self.get_datasets()
        print('Loading dataloaders ...')
        self.dataloader, self.valid_dataloader = self.get_dataloaders()
        print('Loading model ...')
        self.model = self.get_model()
        print('Loading criterion, optimizer and scheduler ...')
        self.criterion = self.get_criterion()
        self.optimizer = self.get_optimizer()
        self.scheduler = self.get_scheduler()

        # Load the checkpoint
        self.checkpoint = self.resume(checkpoint)
        # Get the logger
        self.logger = self.get_logger()
        # We log only once per batch, so keep here the extra elements to add to that log
        self.current_log = {}

    ####################################
    ###           GETTERS            ###
    ####################################

    def get_datasets(self):
        """Get the training and validation datasets.

        Returns:
            tuple: A tuple with the torch.utils.data.Dataset instances for training and validation.
        """
        raise NotImplementedError('You must provide your own datasets.')

    def get_dataloaders(self):
        """Get the dataloaders for training and validation.

        Returns:
            tuple: The tuple with the torch.utils.data.DataLoader for training and validation.
        """
        raise NotImplementedError('You must provide the dataloaders for the datasets.')

    def get_model(self):
        """Get the model to train.

        Returns:
            torch.nn.Module: The model to train.
        """
        raise NotImplementedError('You must provide the model to train.')

    def get_criterion(self):
        """Get the criterion to use to train the model.

        Returns:
            torch.nn.Module: The criterion to use in the training.
        """
        raise NotImplementedError('You must provide the criterion to train the model.')

    def get_optimizer(self):
        """Get the optimizer of the model.

        Returns:
            torch.optim.Optimizer: The optimizer of the model's parameters.
        """
        raise NotImplementedError('You must provide the optimizer to use during the training.')

    def get_scheduler(self):
        """Get the (optional) scheduler for the learning rate of the optimizer.

        Returns:
            torch.optim.lr_scheduler._LRScheduler: A scheduler for the learning rate.
        """
        # No error because the scheduler is optional.

    def get_logger(self):
        """Get the (optional) logger to use during the training to show the information about the process.

        This base implementation uses the PrintLogger that will print the log to the console.

        Returns:
            torchsight.loggers.PrintLogger: The logger to use during the training. Subclasses can
                return any other logger with a log(dict) method.
        """
        description = 'Hyperparameters:\n{}'.format(json.dumps(self.hyperparameters, indent=2))
        return PrintLogger(description, self.hyperparameters['logger']['dir'])

    ####################################
    ###           METHODS            ###
    ####################################

    def train(self, epochs=100, validate=True):
        """Train the model during the giving epochs.

        Arguments:
            epochs (int, optional): The number of epochs to run the model.
            validate (bool, optional): If it's True the trainer will validate the training using
                the validate() method. And if there is a scheduler it gives the validation loss
                generated by the validate() method to the scheduler to adjust the learning rate.
        """
        if not validate and self.save_only_best:
            raise ValueError('You cannot disable validation and save only the best model. '
                             'The model must be validated to know if it\'s the best one.')

        self.model.to(self.device)

        # The criterion could be inside the model for example and in that case it could be None
        if self.criterion is not None:
            self.criterion.to(self.device)

        # The number of batches in the training dataloader
        n_batches = len(self.dataloader)

        # The start time of the training and the last batch's end time
        start_time = time.time()
        last_endtime = start_time

        # We start from the next epoch of the checkpoint (if there is any)
        start_epoch = 1 if self.checkpoint is None else self.checkpoint['epoch'] + 1

        for epoch in range(start_epoch, start_epoch + epochs):
            # Indicate to the model that we are in training mode, useful for batch normalization or dropout modules.
            # For more info see:
            # https://discuss.pytorch.org/t/trying-to-understand-the-meaning-of-model-train-and-model-eval/20158
            self.model.train()

            for batch, data in enumerate(self.dataloader):
                # Optimize
                self.optimizer.zero_grad()
                loss = self.forward(*data)
                self.backward(loss)
                self.optimizer.step()

                # Log the batch
                learning_rates = [str(param_group['lr']) for param_group in self.optimizer.param_groups]

                total_time = time.time() - start_time
                batch_time = time.time() - last_endtime
                last_endtime = time.time()

                self.logger.log(merge_dicts({
                    'Training': None,
                    'Epoch': epoch,
                    'Batch': '{}/{}'.format(batch + 1, n_batches),
                    'LR': ' '.join(learning_rates),
                    'Loss': '{:.7f}'.format(float(loss)),
                    'Time': '{:.3f} s'.format(batch_time),
                    'Total': '{:.1f} s'.format(total_time)
                }, self.current_log))
                self.current_log = {}  # Restart the log dict for the next batch

                # Call the callback for the batch
                self.batch_callback(batch, epoch)

            # Call the callback for the epoch
            self.epoch_callback(epoch)

            # Save the checkpoint for this epoch
            if not self.save_only_best:
                self.save(epoch)

            if validate:
                loss = self.validate(epoch)
                self.save(epoch, loss)
                if self.scheduler is not None:
                    self.scheduler.step(loss)

    def forward(self, *args):
        """Do a forward pass over the model with the model and get the loss value.

        Arguments:
            *args: All the data that the dataloader generates while iterating over it.

        Returns:
            torch.Tensor: The loss value of the forward pass.
        """
        raise NotImplementedError('You must implement the forward pass over the model.')

    def backward(self, loss):
        """Do the backward pass over the network.

        This is a separate method because each experiment could do something different during the backward pass, like:
        ```python
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
        ```

        This base implementation only calls loss.backward().

        Arguments:
            loss (torch.Tensor): The loss value computed during the forward pass.
        """
        loss.backward()

    def eval(self):
        """Set the model into evaluation mode.

        Override this method to provide a custom eval() call.
        """
        self.model.eval()

    def validate(self, epoch):
        """Run the model over the validation dataset and return the mean loss over it."""
        self.model.to(self.device)
        self.eval()

        start_time = time.time()
        last_endtime = start_time

        n_batches = len(self.valid_dataloader)

        losses = []

        with torch.no_grad():
            for batch, data in enumerate(self.valid_dataloader):
                loss = float(self.forward(*data))

                batch_time = time.time() - last_endtime
                last_endtime = time.time()
                total_time = time.time() - start_time

                self.logger.log(merge_dicts({
                    'Validating': None,
                    'Epoch': epoch,
                    'Batch': '{}/{}'.format(batch + 1, n_batches),
                    'Loss': '{:.7f}'.format(float(loss)),
                    'Time': '{:.3f} s'.format(batch_time),
                    'Total': '{:.1f} s'.format(total_time)
                }, self.current_log))
                self.current_log = {}  # Restart the log dict for the next batch

                losses.append(loss)

        return torch.Tensor(losses).mean()

    def batch_callback(self, batch, epoch):
        """Method that is called after a batch has finished its process."""

    def epoch_callback(self, epoch):
        """Method that is called after an epoch has finished its process."""

    def get_checkpoint_name(self, epoch):
        """Get the name of the checkpoint file.

        If all the checkpoints are saved, the name includes the epoch; otherwise
        the same name is always used.
        """
        if self.save_only_best:
            return 'checkpoint.pth.tar'

        return 'checkpoint_epoch_{}.pth.tar'.format(epoch)

    def save(self, epoch, current_loss=None):
        """Save the checkpoint of the trainer.

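        The checkpoint is a dict like:
        {'epoch': int, 'best_loss': loss, 'model': state_dict, 'optimizer': state_dict,
         'hyperparameters': dict, 'scheduler': state_dict}
        where the scheduler is optional.

        Arguments:
            epoch (int): The epoch that has finished.
            current_loss (optional): The validation loss of the epoch. If save_only_best is True,
                the checkpoint is saved only when this loss is not greater than the best loss so far.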
        """
        if self.save_only_best and current_loss > self.best_loss:
            return

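        # Keep the lowest loss seen so far (current_loss is None when the epoch was not validated)
        if current_loss is not None and current_loss < self.best_loss:
            self.best_loss = current_loss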

        params = self.hyperparameters['checkpoint']
        path = os.path.join(params['dir'], self.get_checkpoint_name(epoch))

        if params['verbose']:
            print('[Epoch {}] Saving checkpoint to: {}'.format(epoch, path))

        checkpoint = {'epoch': epoch,
                      'best_loss': self.best_loss,
                      'model': self.model.state_dict(),
                      'optimizer': self.optimizer.state_dict(),
                      'hyperparameters': self.hyperparameters}

        if self.scheduler is not None:
            checkpoint['scheduler'] = self.scheduler.state_dict()

        torch.save(checkpoint, path)

    def resume(self, checkpoint):
        """Resume the training based on a last checkpoint and get the checkpoint dict.

        This method returns only the epoch value in the dict, so the state_dicts are not
        kept in memory after they are loaded.
        You can override it in your own trainer to return more values.

        Arguments:
            checkpoint (str): The path to the checkpoint file.

        Returns:
            dict: A dict with the epoch only. The state_dicts are not returned, so they are not
                kept in memory.
        """
        if checkpoint is None:
            return None

        verbose = self.hyperparameters['checkpoint']['verbose']

        if verbose:
            print('Loading checkpoint from {}'.format(checkpoint))

        checkpoint_path = checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        for state in self.optimizer.state.values():
            for k, val in state.items():
                if torch.is_tensor(val):
                    state[k] = val.to(self.device)

        if 'scheduler' in checkpoint and self.scheduler is not None:
            # Load the scheduler state without map_location: mapping it to the GPU raises errors
            checkpoint = torch.load(checkpoint_path)
            self.scheduler.load_state_dict(checkpoint['scheduler'])

        self.best_loss = checkpoint.get('best_loss', 1e10)

        return {'epoch': checkpoint['epoch']}

    @classmethod
    def from_checkpoint(cls, checkpoint, new_params=None, device=None, verbose=True):
        """Get an instance of the trainer based on the given checkpoint file.

        This is useful because the checkpoint also stores the hyperparameters,
        so you get a trainer with the same hyperparameters as the one that saved the checkpoint.

        You can also use this method just to load the model: `trainer.model` gives you
        the model instance.

        Arguments:
            checkpoint (str): The path to the checkpoint file.
            new_params (dict, optional): A dict with new hyperparameters to change the ones
                in the checkpoint. Useful, for example, to change the batch size, the dataset root,
                etc.
            device (str, optional): The device where to run the trainer.
            verbose (bool, optional): The verbose flag given to merge_dicts when merging new_params.

        Returns:
            Trainer: An instance of the trainer with the hyperparameters from the checkpoint
                (updated with new_params) and with the modules loaded from the checkpoint's state_dicts.
        """
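        # Only the hyperparameters are needed here, so map any tensors to the CPU
        saved_hyperparameters = torch.load(checkpoint, map_location='cpu')['hyperparameters']
        hyperparameters = merge_dicts(saved_hyperparameters, new_params, verbose)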

        return cls(hyperparameters=hyperparameters, checkpoint=checkpoint, device=device)
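In `train` and `validate`, each batch produces a single `logger.log` call whose dict is the base entries (epoch, batch, loss, times) merged with `self.current_log`, which is then reset. A subclass can therefore add its own values from `forward` or a callback. The plain-Python sketch below shows that pattern; `ListLogger`, `LoggingSketch` and the `'Input'` key are hypothetical, and `{**base, **self.current_log}` stands in for `merge_dicts`, whose exact semantics are assumed.

```python
class ListLogger:
    """Hypothetical logger that keeps every logged dict instead of printing it."""

    def __init__(self):
        self.records = []

    def log(self, data):
        self.records.append(data)


class LoggingSketch:
    """Mimics how Trainer merges self.current_log into the single per-batch log call."""

    def __init__(self):
        self.logger = ListLogger()
        self.current_log = {}

    def forward(self, x):
        # A subclass can stash extra values to show next to the loss
        self.current_log['Input'] = x
        return 0.5 * x

    def run(self, batches):
        for batch, x in enumerate(batches):
            loss = self.forward(x)
            # {**base, **self.current_log} stands in for merge_dicts(base, self.current_log)
            self.logger.log({'Batch': '{}/{}'.format(batch + 1, len(batches)),
                             'Loss': '{:.3f}'.format(loss),
                             **self.current_log})
            self.current_log = {}  # Restart the log dict for the next batch


sketch = LoggingSketch()
sketch.run([2.0, 4.0])
print(sketch.logger.records)
```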

Classes

class Trainer

Base Trainer class. All the trainers must extend this class.
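With `save_only_best=True` (the default), the trainer writes a single `checkpoint.pth.tar` that is overwritten only when the validation loss does not exceed the best loss so far; with `save_only_best=False` it writes one file per epoch. The sketch below reproduces that decision in plain Python; `CheckpointPolicy` is hypothetical and records file names instead of calling `torch.save`.

```python
import math


class CheckpointPolicy:
    """Hypothetical sketch of the save_only_best behaviour of Trainer.save()."""

    def __init__(self, save_only_best=True):
        self.save_only_best = save_only_best
        self.best_loss = math.inf  # Trainer starts from 1e10 instead
        self.saved = []  # (epoch, file name) of every checkpoint that would be written

    def checkpoint_name(self, epoch):
        # Same naming scheme as Trainer.get_checkpoint_name()
        if self.save_only_best:
            return 'checkpoint.pth.tar'
        return 'checkpoint_epoch_{}.pth.tar'.format(epoch)

    def save(self, epoch, current_loss):
        if self.save_only_best and current_loss > self.best_loss:
            return  # Worse than the best so far: keep the previous file
        self.best_loss = min(self.best_loss, current_loss)  # Lowest loss seen so far
        self.saved.append((epoch, self.checkpoint_name(epoch)))


policy = CheckpointPolicy()
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    policy.save(epoch, loss)
print(policy.saved, policy.best_loss)
```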

Class variables

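var hyperparameters

The base hyperparameters of the trainer. Subclasses set this dict; __init__ merges it with the default checkpoint and logger settings and then with the hyperparameters given at initialization.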

Static methods

def from_checkpoint(cls, checkpoint, new_params=None, device=None, verbose=True)

Get an instance of the trainer based on the given checkpoint file.

This is very useful because the checkpoint also stores the hyperparameters, so you get a trainer with the same hyperparameters as the one that created the checkpoint.

You can also use this method just to load the model: the returned trainer exposes it as trainer.model.

Arguments

checkpoint : str
The path to the checkpoint file.
new_params : dict, optional
A dict with new hyperparameters that override the ones in the checkpoint. Useful, for example, to change the batch size, the dataset root, etc.
device : str, optional
The device where to run the training.
verbose : bool, optional
Passed as the verbose flag of merge_dicts when merging new_params into the checkpoint's hyperparameters.

Returns

Trainer
An instance of the trainer with the same hyperparameters and with the state_dicts of its modules loaded from the checkpoint.
Source code
@classmethod
def from_checkpoint(cls, checkpoint, new_params=None, device=None, verbose=True):
    """Get an instance of the trainer based on the given checkpoint file.

    This is very useful because the checkpoint also stores the hyperparameters,
    so you get a trainer with the same hyperparameters as the one that created the checkpoint.

    You can also use this method just to load the model: the returned trainer exposes it
    as `trainer.model`.

    Arguments:
        checkpoint (str): The path to the checkpoint file.
        new_params (dict, optional): A dict with new hyperparameters that override the ones
            in the checkpoint. Useful, for example, to change the batch size, the dataset root,
            etc.
        device (str, optional): The device where to run the training.
        verbose (bool, optional): Passed as the verbose flag of merge_dicts when merging
            new_params into the checkpoint's hyperparameters.

    Returns:
        Trainer: An instance of the trainer with the same hyperparameters and with the
            state_dicts of its modules loaded from the checkpoint.
    """
    # Only the hyperparameters are needed here, so load the tensors on the CPU; this also
    # works for checkpoints saved on a GPU when no GPU is available
    hyperparameters = merge_dicts(torch.load(checkpoint, map_location='cpu')['hyperparameters'],
                                  new_params, verbose)

    return cls(hyperparameters=hyperparameters, checkpoint=checkpoint, device=device)
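The effect of new_params can be pictured with a small sketch. It assumes, as the __init__ documentation states for the hyperparameters, that merge_dicts merges nested dicts deeply and that the values of its second argument win; deep_merge below is a hypothetical stand-in written for illustration, not torchsight's implementation, and the hyperparameter keys are made up.

```python
import copy


def deep_merge(base, updates):
    """Hypothetical stand-in for merge_dicts: recursively merge `updates` into a copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in (updates or {}).items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them key by key instead of replacing the whole dict
            merged[key] = deep_merge(merged[key], value)
        else:
            # Otherwise the new value overrides the saved one
            merged[key] = copy.deepcopy(value)
    return merged


# Hyperparameters as they could be stored in a checkpoint (hypothetical keys)
saved = {
    'datasets': {'root': '/data/old', 'image_size': 512},
    'dataloaders': {'batch_size': 8, 'num_workers': 4},
    'optimizer': {'lr': 0.01},
}

# Change only the batch size and the dataset root, keep everything else
new_params = {'datasets': {'root': '/data/new'}, 'dataloaders': {'batch_size': 2}}

hyperparameters = deep_merge(saved, new_params)
print(hyperparameters['dataloaders'])  # {'batch_size': 2, 'num_workers': 4}
```

Because the merge is deep, new_params only needs to contain the leaves that change; the rest of each nested dict is kept from the checkpoint.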

Methods

def __init__(self, hyperparameters=None, checkpoint=None, device=None, save_only_best=True)

Initialize the trainer.

Arguments

hyperparameters : dict, optional
A dict to change the base hyperparameters. If it's present, it will be deeply merged with the base hyperparameters.
checkpoint : str, optional
A path to a checkpoint (generated by the same trainer) to resume the training.
device : str, optional
The device where to run the training. If it's None, it uses 'cuda:0' when CUDA is available and 'cpu' otherwise.
save_only_best : bool, optional
If True, it will save only the best checkpoint, not all the checkpoints for each epoch.
Source code
def __init__(self, hyperparameters=None, checkpoint=None, device=None, save_only_best=True):
    """Initialize the trainer.

    Arguments:
        hyperparameters (dict, optional): A dict to change the base hyperparameters.
            If it's present, it will be deeply merged with the base hyperparameters.
        checkpoint (str, optional): A path to a checkpoint (generated by the same trainer)
            to resume the training.
        device (str, optional): The device where to run the training. If it's None, it uses
            'cuda:0' when CUDA is available and 'cpu' otherwise.
        save_only_best (bool, optional): If True, it will save only the best checkpoint, not all
            the checkpoints for each epoch.
    """
    base_hyperparameters = {'checkpoint': {'dir': LOGS_DIR, 'verbose': True},
                            'logger': {'dir': LOGS_DIR}}
    # Add the base hyperparameters to the trainer hyperparameters
    self.hyperparameters = merge_dicts(self.hyperparameters, base_hyperparameters)
    # Add the modified hyperparameters given in the initialization
    self.hyperparameters = merge_dicts(self.hyperparameters, hyperparameters, verbose=True)
    # Set the device of the trainer
    self.device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Flag to save only the best model
    self.save_only_best = save_only_best
    self.best_loss = 1e10

    # Get the modules for the training
    print('Loading datasets ...')
    self.dataset, self.valid_dataset = self.get_datasets()
    print('Loading dataloaders ...')
    self.dataloader, self.valid_dataloader = self.get_dataloaders()
    print('Loading model ...')
    self.model = self.get_model()
    print('Loading criterion, optimizer and scheduler ...')
    self.criterion = self.get_criterion()
    self.optimizer = self.get_optimizer()
    self.scheduler = self.get_scheduler()

    # Load the checkpoint
    self.checkpoint = self.resume(checkpoint)
    # Get the logger
    self.logger = self.get_logger()
    # We log only once per batch, so this dict accumulates the extra values to include in the next log entry
    self.current_log = {}
def backward(self, loss)

Do the backward pass over the network.

There is a separate method for this because each experiment could do different things during the backward pass, for example:

loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)

The base implementation only computes the backward pass of the loss.

Arguments

loss : torch.Tensor
The loss value computed during the forward pass.
Source code
def backward(self, loss):
    """Do the backward pass over the network.

    There is a separate method for this because each experiment could do different things during the backward pass, for example:
    ```python
    loss.backward()
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
    ```

    The base implementation only computes the backward pass of the loss.

    Arguments:
        loss (torch.Tensor): The loss value computed during the forward pass.
    """
    loss.backward()
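A subclass that needs gradient clipping would override backward as the docstring suggests. The sketch below writes the body of such an override as a standalone function so that it runs on its own; the function name, the toy model and data are assumptions, and max_norm=0.1 is simply the value from the docstring example.

```python
import torch


def backward_with_clipping(loss, model, max_norm=0.1):
    """Hypothetical body of an overridden `backward`: backpropagate, then clip the gradients."""
    loss.backward()
    # Rescale all the gradients so that their global L2 norm is at most `max_norm`
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)

loss = torch.nn.functional.mse_loss(model(inputs), targets)
backward_with_clipping(loss, model)

# Global L2 norm of the (clipped) gradients
grad_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(float(grad_norm))
```

Clipping happens between the backward pass and optimizer.step(), which is exactly where train() calls self.backward(loss).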
def batch_callback(self, batch, epoch)

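Method called after a batch has been processed.

The base implementation does nothing; override it to add custom behavior.

Arguments

batch : int
The index of the batch in the epoch, starting at 0.
epoch : int
The current epoch.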

Source code
def batch_callback(self, batch, epoch):
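    """Method called after a batch has been processed.

    The base implementation does nothing; override it to add custom behavior.

    Arguments:
        batch (int): The index of the batch in the epoch, starting at 0.
        epoch (int): The current epoch.
    """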
def epoch_callback(self, epoch)

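Method called after an epoch has been processed, before the checkpoint is saved and the model is validated.

The base implementation does nothing; override it to add custom behavior.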

Source code
def epoch_callback(self, epoch):
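    """Method called after an epoch has been processed, before the checkpoint is saved and the model is validated.

    The base implementation does nothing; override it to add custom behavior.
    """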
def eval(self)

Set the model into evaluation mode.

Override this method to provide a custom eval() call.

Source code
def eval(self):
    """Set the model into evaluation mode.

    Override this method to provide a custom eval() call.
    """
    self.model.eval()
def forward(self, *args)

Do a forward pass over the model and get the loss value.

Arguments

*args
The elements of a batch generated by the dataloader, unpacked: the trainer calls self.forward(*data) for each batch.

Returns

torch.Tensor: The loss value of the forward pass.

Source code
def forward(self, *args):
    """Do a forward pass over the model and get the loss value.

    Arguments:
        *args: The elements of a batch generated by the dataloader, unpacked: the trainer calls
            self.forward(*data) for each batch.

    Returns:
        torch.Tensor: The loss value of the forward pass.
    """
    raise NotImplementedError('You must implement the forward pass over the model.')
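Because train() and validate() call self.forward(*data), each element of a batch arrives as a separate argument. A minimal sketch of a forward implementation for a dataloader that yields (inputs, targets) pairs could look like this; ToyForward is a hypothetical stand-in that only mimics the attributes a trainer sets up (model, criterion, device) so that it runs without torchsight, and the model and data are made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyForward:
    """Hypothetical stand-in with the attributes a trainer sets up in __init__."""

    def __init__(self):
        self.device = 'cpu'
        self.model = torch.nn.Linear(3, 2).to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        # The (inputs, targets) batch arrives unpacked as two arguments
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        logits = self.model(inputs)
        # Return the loss tensor so that backward() can be called on it
        return self.criterion(logits, targets)


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
dataloader = DataLoader(dataset, batch_size=4)

trainer = ToyForward()
for data in dataloader:
    loss = trainer.forward(*data)  # the same call that train() and validate() make
    print(float(loss))
```

If the dataloader yielded a third element (for example, image ids), forward would need a third parameter to receive it.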
def get_checkpoint_name(self, epoch)

Get the name of the checkpoint file.

If all the checkpoints are saved, the name includes the epoch; otherwise, the same name is always used, so each new best checkpoint overwrites the previous one.

Source code
def get_checkpoint_name(self, epoch):
    """Get the name of the checkpoint file.

    If all the checkpoints are saved, the name includes the epoch; otherwise,
    the same name is always used, so each new best checkpoint overwrites the previous one.
    """
    if self.save_only_best:
        return 'checkpoint.pth.tar'

    return 'checkpoint_epoch_{}.pth.tar'.format(epoch)
def get_criterion(self)

Get the criterion to use to train the model.

Returns

torch.nn.Module: The criterion to use in the training. It can be None if the model computes the loss itself.

Source code
def get_criterion(self):
    """Get the criterion to use to train the model.

    Returns:
        torch.nn.Module: The criterion to use in the training. It can be None if the model
            computes the loss itself.
    """
    raise NotImplementedError('You must provide the criterion to train the model.')
def get_dataloaders(self)

Get the dataloaders for training and validation.

Returns

tuple
The tuple with the torch.utils.data.DataLoader for training and validation.
Source code
def get_dataloaders(self):
    """Get the dataloaders for training and validation.

    Returns:
        tuple: The tuple with the torch.utils.data.DataLoader for training and validation.
    """
    raise NotImplementedError('You must provide the dataloaders for the datasets.')
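Following the module's advice of storing the getters' parameters in the hyperparameters dict, a get_dataloaders implementation could forward a sub-dict to torch.utils.data.DataLoader. In this sketch the 'dataloaders' key and its contents, the function form of the getter and the toy datasets are all assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical hyperparameters: the keys are passed as-is to DataLoader
hyperparameters = {'dataloaders': {'batch_size': 4, 'num_workers': 0}}


def get_dataloaders(dataset, valid_dataset, hyperparameters):
    """Hypothetical body of a `get_dataloaders` override, written as a function."""
    params = hyperparameters['dataloaders']
    # Shuffle only the training data; keep the validation order fixed
    dataloader = DataLoader(dataset, shuffle=True, **params)
    valid_dataloader = DataLoader(valid_dataset, shuffle=False, **params)
    return dataloader, valid_dataloader


dataset = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
valid_dataset = TensorDataset(torch.randn(6, 3), torch.randint(0, 2, (6,)))

dataloader, valid_dataloader = get_dataloaders(dataset, valid_dataset, hyperparameters)
print(len(dataloader), len(valid_dataloader))  # 3 2
```

With this layout, changing the batch size only requires a new value under the (assumed) 'dataloaders' key, for example through new_params in from_checkpoint.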
def get_datasets(self)

Get the training and validation datasets.

Returns

tuple
A tuple with the torch.utils.data.Dataset instances for training and validation.
Source code
def get_datasets(self):
    """Get the training and validation datasets.

    Returns:
        tuple: A tuple with the torch.utils.data.Dataset instances for training and validation.
    """
    raise NotImplementedError('You must provide your own datasets.')
def get_logger(self)

Get the (optional) logger to use during the training to show the information about the process.

This base implementation uses the PrintLogger that will print the log to the console.

Returns

pymatch.loggers.Logger: A Logger to use during the training.

Source code
def get_logger(self):
    """Get the (optional) logger to use during the training to show the information about the process.

    This base implementation uses the PrintLogger that will print the log to the console.

    Returns:
        pymatch.loggers.Logger: A Logger to use during the training.
    """
    description = 'Hyperparameters:\n{}'.format(json.dumps(self.hyperparameters, indent=2))
    return PrintLogger(description, self.hyperparameters['logger']['dir'])
def get_model(self)

Get the model to train.

Returns

torch.nn.Module: The model to train.

Source code
def get_model(self):
    """Get the model to train.

    Returns:
        torch.nn.Module: The model to train.
    """
    raise NotImplementedError('You must provide the model to train.')
def get_optimizer(self)

Get the optimizer of the model.

Returns

torch.optim.Optimizer: The optimizer of the model's parameters.

Source code
def get_optimizer(self):
    """Get the optimizer of the model.

    Returns:
        torch.optim.Optimizer: The optimizer of the model's parameters.
    """
    raise NotImplementedError('You must provide the optimizer to use during the training.')
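In the same spirit, get_optimizer can build the optimizer from the hyperparameters dict. The 'optimizer' key and its entries are assumptions, and SGD is just one possible choice.

```python
import torch

# Hypothetical hyperparameters for the optimizer
hyperparameters = {'optimizer': {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 1e-4}}


def get_optimizer(model, hyperparameters):
    """Hypothetical body of a `get_optimizer` override, written as a function."""
    params = hyperparameters['optimizer']
    # Optimize only the parameters that require gradients (e.g. skip frozen layers)
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.SGD(trainable, **params)


model = torch.nn.Linear(3, 2)
optimizer = get_optimizer(model, hyperparameters)
print(optimizer.param_groups[0]['lr'])  # 0.01
```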
def get_scheduler(self)

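Get the (optional) scheduler for the learning rate of the optimizer.

The base implementation returns None (no scheduler). As train() calls scheduler.step() with the validation loss, the scheduler should accept a metric in its step() method, like torch.optim.lr_scheduler.ReduceLROnPlateau.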

Returns

torch.optim.lr_scheduler._LRScheduler: A scheduler for the learning rate.

Source code
def get_scheduler(self):
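    """Get the (optional) scheduler for the learning rate of the optimizer.

    The base implementation returns None (no scheduler). As train() calls
    scheduler.step() with the validation loss, the scheduler should accept a metric
    in its step() method, like torch.optim.lr_scheduler.ReduceLROnPlateau.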

    Returns:
        torch.optim.lr_scheduler._LRScheduler: A scheduler for the learning rate.
    """
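Since train() calls self.scheduler.step(loss) with the validation loss, torch.optim.lr_scheduler.ReduceLROnPlateau fits this trainer. The sketch below uses made-up parameters (factor=0.1, patience=1) and a constant validation loss to show when the learning rate drops; it is not the scheduler any particular trainer uses.

```python
import torch

model = torch.nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical body of a `get_scheduler` override: reduce the LR when the validation loss stalls
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1)

# Simulate three epochs whose validation loss does not improve
for epoch, valid_loss in enumerate([1.0, 1.0, 1.0], start=1):
    scheduler.step(valid_loss)  # the same call train() makes after validate()
    print(epoch, optimizer.param_groups[0]['lr'])
```

The first epoch sets the best loss, the second counts one bad epoch (not more than patience), and the third exceeds the patience, so the learning rate is multiplied by the factor.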
def resume(self, checkpoint)

Resume the training based on a last checkpoint and get the checkpoint dict.

This method only returns the epoch in the dict to avoid keeping the state_dicts in memory once they are loaded into the modules. You can customize your own trainer to return more values.

Arguments

checkpoint : str
The path to the checkpoint file.

Returns

dict
A dict with only the epoch, or None if no checkpoint is given. The state_dicts are not returned so they are not kept in memory.
Source code
def resume(self, checkpoint):
    """Resume the training based on a last checkpoint and get the checkpoint dict.

    This method only returns the epoch in the dict to avoid keeping the state_dicts in memory
    once they are loaded into the modules.
    You can customize your own trainer to return more values.

    Arguments:
        checkpoint (str): The path to the checkpoint file.

    Returns:
        dict: A dict with only the epoch, or None if no checkpoint is given. The state_dicts
            are not returned so they are not kept in memory.
    """
    if checkpoint is None:
        return None

    verbose = self.hyperparameters['checkpoint']['verbose']

    if verbose:
        print('Loading checkpoint from {}'.format(checkpoint))

    checkpoint_path = checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=self.device)

    self.model.load_state_dict(checkpoint['model'])
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    for state in self.optimizer.state.values():
        for k, val in state.items():
            if torch.is_tensor(val):
                state[k] = val.to(self.device)

    if 'scheduler' in checkpoint and self.scheduler is not None:
        # Mapping the scheduler state to the GPU raises errors,
        # so reload the checkpoint without map_location to restore it
        checkpoint = torch.load(checkpoint_path)
        self.scheduler.load_state_dict(checkpoint['scheduler'])

    self.best_loss = checkpoint.get('best_loss', 1e10)

    return {'epoch': checkpoint['epoch']}
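The core of resume(), restoring the model and optimizer state_dicts and moving the optimizer state to the trainer's device, can be reproduced standalone. The sketch writes a toy checkpoint with some of the keys that save() uses into an in-memory buffer instead of a file; the model, Adam, the epoch and the best loss are assumptions.

```python
import io

import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Train a toy model for one step so that Adam has some state to save
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 3)).sum().backward()
optimizer.step()

buffer = io.BytesIO()
torch.save({'epoch': 3, 'best_loss': 0.5,
            'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, buffer)
buffer.seek(0)

# Restore into fresh modules, as resume() does
new_model = torch.nn.Linear(3, 1).to(device)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3)
checkpoint = torch.load(buffer, map_location=device)

new_model.load_state_dict(checkpoint['model'])
new_optimizer.load_state_dict(checkpoint['optimizer'])
for state in new_optimizer.state.values():
    for k, val in state.items():
        if torch.is_tensor(val):
            # Move the optimizer buffers (e.g. Adam's moment estimates) to the trainer's device
            state[k] = val.to(device)

# train() continues from the epoch after the saved one
start_epoch = checkpoint['epoch'] + 1
print(start_epoch)  # 4
```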
def save(self, epoch, current_loss=None)

Save the checkpoint of the trainer.

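The checkpoint is a dict like: {'epoch': int, 'best_loss': loss, 'model': state_dict, 'optimizer': state_dict, 'hyperparameters': dict, 'scheduler': state_dict} where the scheduler is optional.

Arguments

epoch : int
The epoch that has finished.
current_loss : float or torch.Tensor, optional
The validation loss of the epoch. It's required when save_only_best is True: in that case the checkpoint is only saved if this loss is not worse than the best one so far.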
Source code
def save(self, epoch, current_loss=None):
    """Save the checkpoint of the trainer.

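    The checkpoint is a dict like:
    {'epoch': int, 'best_loss': loss, 'model': state_dict, 'optimizer': state_dict,
     'hyperparameters': dict, 'scheduler': state_dict}
    where the scheduler is optional.

    Arguments:
        epoch (int): The epoch that has finished.
        current_loss (float or torch.Tensor, optional): The validation loss of the epoch.
            It's required when save_only_best is True: in that case the checkpoint is only
            saved if this loss is not worse than the best one so far.
    """
    if self.save_only_best:
        if current_loss is None:
            raise ValueError('A current_loss is required to save only the best checkpoint.')
        if current_loss > self.best_loss:
            return

    # Track the best loss seen so far (there is no loss when saving without validation)
    if current_loss is not None and current_loss < self.best_loss:
        self.best_loss = current_loss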

    params = self.hyperparameters['checkpoint']
    path = os.path.join(params['dir'], self.get_checkpoint_name(epoch))

    if params['verbose']:
        print('[Epoch {}] Saving checkpoint to: {}'.format(epoch, path))

    checkpoint = {'epoch': epoch,
                  'best_loss': self.best_loss,
                  'model': self.model.state_dict(),
                  'optimizer': self.optimizer.state_dict(),
                  'hyperparameters': self.hyperparameters}

    if self.scheduler is not None:
        checkpoint['scheduler'] = self.scheduler.state_dict()

    torch.save(checkpoint, path)
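To see which epochs end up on disk when save_only_best is True, the decision at the top of save() can be simulated with plain Python. The function below is a hypothetical restatement of that rule (save when the validation loss is not worse than the best one so far, then remember it), and the losses are made up.

```python
def simulate_best_checkpoints(losses, best_loss=1e10):
    """Return the epochs (1-based) whose checkpoint save() would write with save_only_best=True."""
    saved_epochs = []
    for epoch, loss in enumerate(losses, start=1):
        if loss > best_loss:
            continue  # worse than the best checkpoint: save() returns without writing
        best_loss = loss
        # With save_only_best, get_checkpoint_name() always returns 'checkpoint.pth.tar',
        # so each saved epoch overwrites the previous best checkpoint
        saved_epochs.append(epoch)
    return saved_epochs


print(simulate_best_checkpoints([0.9, 0.7, 0.8, 0.6, 0.6]))  # [1, 2, 4, 5]
```

Since the comparison is a strict `>`, an epoch that ties the best loss (epoch 5 here) also overwrites the checkpoint.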
def train(self, epochs=100, validate=True)

Train the model for the given number of epochs.

Arguments

epochs : int, optional
The number of epochs to train. When resuming from a checkpoint, the training continues from the epoch after the one stored in it.
validate : bool, optional
If True, the trainer validates the model after each epoch using the validate() method and, if there is a scheduler, gives it the validation loss to adjust the learning rate. It must be True when save_only_best is True.
Source code
def train(self, epochs=100, validate=True):
    """Train the model for the given number of epochs.

    Arguments:
        epochs (int, optional): The number of epochs to train. When resuming from a checkpoint,
            the training continues from the epoch after the one stored in it.
        validate (bool, optional): If True, the trainer validates the model after each epoch
            using the validate() method and, if there is a scheduler, gives it the validation
            loss to adjust the learning rate. It must be True when save_only_best is True.
    """
    if not validate and self.save_only_best:
        raise ValueError('You cannot disable validation and save only the best model. '
                         'The model must be validated to know if it\'s the best one.')

    self.model.to(self.device)

    # The criterion could be inside the model for example and in that case it could be None
    if self.criterion is not None:
        self.criterion.to(self.device)

    # The number of batches in the training dataloader
    n_batches = len(self.dataloader)

    # The start time of the training and the last batch's end time
    start_time = time.time()
    last_endtime = start_time

    # We start from the next epoch of the checkpoint (if there is any)
    start_epoch = 1 if self.checkpoint is None else self.checkpoint['epoch'] + 1

    for epoch in range(start_epoch, start_epoch + epochs):
        # Indicate to the model that we are in training mode, useful for batch normalization or dropouts modules.
        # For more info see:
        # https://discuss.pytorch.org/t/trying-to-understand-the-meaning-of-model-train-and-model-eval/20158
        self.model.train()

        for batch, data in enumerate(self.dataloader):
            # Optimize
            self.optimizer.zero_grad()
            loss = self.forward(*data)
            self.backward(loss)
            self.optimizer.step()

            # Log the batch
            learning_rates = [str(param_group['lr']) for param_group in self.optimizer.param_groups]

            total_time = time.time() - start_time
            batch_time = time.time() - last_endtime
            last_endtime = time.time()

            self.logger.log(merge_dicts({
                'Training': None,
                'Epoch': epoch,
                'Batch': '{}/{}'.format(batch + 1, n_batches),
                'LR': ' '.join(learning_rates),
                'Loss': '{:.7f}'.format(float(loss)),
                'Time': '{:.3f} s'.format(batch_time),
                'Total': '{:.1f} s'.format(total_time)
            }, self.current_log))
            self.current_log = {}  # Restart the log dict for the next batch

            # Call the callback for the batch
            self.batch_callback(batch, epoch)

        # Call the callback for the epoch
        self.epoch_callback(epoch)

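        if validate:
            # Validate the model, save the checkpoint (only if it's the best one when
            # save_only_best is True) and adjust the learning rate
            loss = self.validate(epoch)
            self.save(epoch, loss)
            if self.scheduler is not None:
                self.scheduler.step(loss)
        else:
            # Without validation there is no loss to compare, so save every epoch's checkpoint
            self.save(epoch)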
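Without logging, callbacks and checkpoints, each epoch of train() performs the classic optimization step: zero the gradients, compute the loss with forward, backpropagate and step the optimizer. A self-contained sketch of that loop on a toy regression problem follows; the linear model, the MSE loss, SGD with lr=0.1, the number of epochs and the synthetic data are all assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic data for y = 2x + 1
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
dataloader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def forward(inputs, targets):
    # Plays the role of Trainer.forward: return the loss tensor
    return criterion(model(inputs), targets)


def dataset_loss():
    # Loss over the whole dataset without tracking gradients
    with torch.no_grad():
        return float(forward(x, y))


initial_loss = dataset_loss()
for epoch in range(1, 21):
    model.train()
    for data in dataloader:
        optimizer.zero_grad()
        loss = forward(*data)
        loss.backward()  # what the base Trainer.backward does
        optimizer.step()
final_loss = dataset_loss()
print(initial_loss, final_loss)
```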
def validate(self, epoch)

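Run the model over the validation dataset and return the mean loss over it.

Arguments

epoch : int
The current epoch, used only for logging.

Returns

torch.Tensor: The mean loss over the validation batches.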

Source code
def validate(self, epoch):
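    """Run the model over the validation dataset and return the mean loss over it.

    Arguments:
        epoch (int): The current epoch, used only for logging.

    Returns:
        torch.Tensor: The mean loss over the validation batches.
    """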
    self.model.to(self.device)
    self.eval()

    start_time = time.time()
    last_endtime = start_time

    n_batches = len(self.valid_dataloader)

    losses = []

    with torch.no_grad():
        for batch, data in enumerate(self.valid_dataloader):
            loss = float(self.forward(*data))

            batch_time = time.time() - last_endtime
            last_endtime = time.time()
            total_time = time.time() - start_time

            self.logger.log(merge_dicts({
                'Validating': None,
                'Epoch': epoch,
                'Batch': '{}/{}'.format(batch + 1, n_batches),
                'Loss': '{:.7f}'.format(float(loss)),
                'Time': '{:.3f} s'.format(batch_time),
                'Total': '{:.1f} s'.format(total_time)
            }, self.current_log))
            self.current_log = {}  # Restart the log dict for the next batch

            losses.append(loss)

    return torch.Tensor(losses).mean()