torchsight.evaluators.evaluator module

Base Evaluator class.

Source code
"""Base Evaluator class."""
import time

import torch

from torchsight.loggers import PrintLogger
from torchsight.utils import PrintMixin, merge_dicts


class Evaluator(PrintMixin):
    """An evaluator base class.

    The evaluator is an interface for evaluating different models. You can override the methods that provide
    the dataset, the dataloader and the model, and override the forward method, where you can compute any
    metric that you want to track.

    This class is intended to avoid boilerplate code so that you can focus only on the metric or
    evaluation that you want to compute.
    """

    def __init__(self, checkpoint, params=None, device=None):
        """Initialize the evaluator, get the dataset and the model to evaluate.

        Arguments:
            checkpoint (str): The path to the checkpoint generated by the trainer, where we can find the model state.
            params (dict): A dict to replace the base params of the evaluator.
            device (str, optional): The device to use for the evaluation.
        """
        self.params = merge_dicts(self.get_base_params(), params, verbose=True)
        self.device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.print('Using device "{}"'.format(self.device))

        self.checkpoint_path = checkpoint
        self.print('Loading checkpoint ...')
        self.checkpoint = torch.load(checkpoint, map_location=self.device)

        self.print('Loading dataset ...')
        self.dataset = self.get_dataset()
        self.print('Loading dataloader ...')
        self.dataloader = self.get_dataloader()
        self.print('Loading model ...')
        self.model = self.get_model()
        self.logger = self.get_logger()

        # Keep a dict with the key-value pairs to log after each batch
        self.current_log = {}

        # Avoid memory leaks by removing the loaded checkpoint but keep the path to it
        self.checkpoint = None

    @staticmethod
    def get_base_params():
        """Get the base parameters of the evaluator."""
        return {}

    ###############################
    ###         GETTERS         ###
    ###############################

    def get_dataset(self):
        """Get the dataset for the evaluation.

        Returns:
            torch.utils.data.Dataset: The dataset to use for the evaluation.
        """
        raise NotImplementedError('You must provide your own dataset')

    def get_dataloader(self):
        """Get the dataloader to use for the evaluation.

        Returns:
            torch.utils.data.DataLoader: The dataloader to use for the evaluation.
        """
        raise NotImplementedError('You must provide your own dataloader')

    def get_model(self):
        """Get the model to use to make the predictions.

        Returns:
            torch.nn.Module: The model to use to make the predictions over the data.
        """
        raise NotImplementedError('You must provide your own model')

    def get_logger(self):
        """Get the logger to use during the training to show the information about the process.

        This base implementation uses the PrintLogger that will print the log to the console.

        Returns:
            pymatch.loggers.Logger: A Logger to use during the evaluation.
        """
        return PrintLogger(description=None)

    ###############################
    ###         METHODS         ###
    ###############################

    def eval_mode(self):
        """Put the model in evaluation mode.

        You can override this method to pass arguments to the eval() method of the model
        or do anything else you want.
        """
        self.model.eval()

    def evaluate(self):
        """Run the model over the entire dataset and compute the evaluation metric."""
        self.print('Starting evaluation ...')
        self.model.to(self.device)
        self.eval_mode()

        # The number of batches that the dataset has
        n_batches = len(self.dataloader)

        # The start time of the evaluation and the last batch's end time
        start_time = time.time()
        last_endtime = start_time

        with torch.no_grad():
            for batch, data in enumerate(self.dataloader):
                self.forward(*data)

                total_time = time.time() - start_time
                batch_time = time.time() - last_endtime
                last_endtime = time.time()

                self.logger.log(merge_dicts({
                    self.__class__.__name__: None,
                    'Batch': '{}/{}'.format(batch + 1, n_batches),
                    'Time': '{:.3f} s'.format(batch_time),
                    'Total': '{:.1f} s'.format(total_time)
                }, self.current_log))
                self.current_log = {}

                # A callback after each batch
                self.batch_callback(batch)

        # A final callback
        self.evaluate_callback()

    def forward(self, *args):
        """Forward pass through the model, make the predictions and store them.

        Inside this method you should store the predictions made by the model, compute
        your metrics and perform the evaluation.
        """
        raise NotImplementedError('You should implement your own forward to make the predictions')

    def batch_callback(self, batch):
        """A callback that is called after is batch.

        Arguments:
            batch (int): The number of the batch that was run before this callback.
        """

    def evaluate_callback(self):
        """A callback that is called at the end of the evaluation."""

Classes

class Evaluator (ancestors: PrintMixin)

An evaluator base class.

The evaluator is an interface for evaluating different models. You can override the methods that provide the dataset, the dataloader and the model, and override the forward method, where you can compute any metric that you want to track.

This class is intended to avoid boilerplate code so that you can focus only on the metric or evaluation that you want to compute.
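
For reference, a minimal subclass could look like the sketch below. The dataset and model classes (MyDataset, MyModel), the checkpoint key and the parameter names are hypothetical placeholders, not part of torchsight:

from torch.utils.data import DataLoader

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    """A hypothetical evaluator that tracks a simple accuracy metric."""

    @staticmethod
    def get_base_params():
        # Defaults; keys passed as `params` to the constructor override them
        return {'batch_size': 8, 'num_workers': 2}

    def get_dataset(self):
        # Any torch.utils.data.Dataset works here; MyDataset is a placeholder
        return MyDataset(train=False)

    def get_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.params['batch_size'],
                          num_workers=self.params['num_workers'],
                          shuffle=False)

    def get_model(self):
        model = MyModel()  # placeholder network
        # The state dict key depends on the trainer that wrote the checkpoint; 'model' is an assumption
        model.load_state_dict(self.checkpoint['model'])
        return model

    def forward(self, images, labels):
        outputs = self.model(images.to(self.device))
        accuracy = (outputs.argmax(dim=1) == labels.to(self.device)).float().mean()
        # Keys stored here are appended to the per-batch log line printed by evaluate()
        self.current_log['Accuracy'] = '{:.3f}'.format(accuracy.item())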

Source code
class Evaluator(PrintMixin):
    """An evaluator base class.

    The evaluator is an interface for evaluating different models. You can override the methods that provide
    the dataset, the dataloader and the model, and override the forward method, where you can compute any
    metric that you want to track.

    This class is intended to avoid boilerplate code so that you can focus only on the metric or
    evaluation that you want to compute.
    """

    def __init__(self, checkpoint, params=None, device=None):
        """Initialize the evaluator, get the dataset and the model to evaluate.

        Arguments:
            checkpoint (str): The path to the checkpoint generated by the trainer, where we can find the model state.
            params (dict): A dict to replace the base params of the evaluator.
            device (str, optional): The device to use for the evaluation.
        """
        self.params = merge_dicts(self.get_base_params(), params, verbose=True)
        self.device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.print('Using device "{}"'.format(self.device))

        self.checkpoint_path = checkpoint
        self.print('Loading checkpoint ...')
        self.checkpoint = torch.load(checkpoint, map_location=self.device)

        self.print('Loading dataset ...')
        self.dataset = self.get_dataset()
        self.print('Loading dataloader ...')
        self.dataloader = self.get_dataloader()
        self.print('Loading model ...')
        self.model = self.get_model()
        self.logger = self.get_logger()

        # Keep a dict with the key-value pairs to log after each batch
        self.current_log = {}

        # Avoid memory leaks by removing the loaded checkpoint but keep the path to it
        self.checkpoint = None

    @staticmethod
    def get_base_params():
        """Get the base parameters of the evaluator."""
        return {}

    ###############################
    ###         GETTERS         ###
    ###############################

    def get_dataset(self):
        """Get the dataset for the evaluation.

        Returns:
            torch.utils.data.Dataset: The dataset to use for the evaluation.
        """
        raise NotImplementedError('You must provide your own dataset')

    def get_dataloader(self):
        """Get the dataloader to use for the evaluation.

        Returns:
            torch.utils.data.DataLoader: The dataloader to use for the evaluation.
        """
        raise NotImplementedError('You must provide your own dataloader')

    def get_model(self):
        """Get the model to use to make the predictions.

        Returns:
            torch.nn.Module: The model to use to make the predictions over the data.
        """
        raise NotImplementedError('You must provide your own model')

    def get_logger(self):
        """Get the logger to use during the training to show the information about the process.

        This base implementation uses the PrintLogger that will print the log to the console.

        Returns:
            pymatch.loggers.Logger: A Logger to use during the evaluation.
        """
        return PrintLogger(description=None)

    ###############################
    ###         METHODS         ###
    ###############################

    def eval_mode(self):
        """Put the model in evaluation mode.

        You can override this method to pass arguments to the eval() method of the model
        or do anything else you want.
        """
        self.model.eval()

    def evaluate(self):
        """Run the model over the entire dataset and compute the evaluation metric."""
        self.print('Starting evaluation ...')
        self.model.to(self.device)
        self.eval_mode()

        # The number of batches that the dataset has
        n_batches = len(self.dataloader)

        # The start time of the evaluation and the last batch's end time
        start_time = time.time()
        last_endtime = start_time

        with torch.no_grad():
            for batch, data in enumerate(self.dataloader):
                self.forward(*data)

                total_time = time.time() - start_time
                batch_time = time.time() - last_endtime
                last_endtime = time.time()

                self.logger.log(merge_dicts({
                    self.__class__.__name__: None,
                    'Batch': '{}/{}'.format(batch + 1, n_batches),
                    'Time': '{:.3f} s'.format(batch_time),
                    'Total': '{:.1f} s'.format(total_time)
                }, self.current_log))
                self.current_log = {}

                # A callback after each batch
                self.batch_callback(batch)

        # A final callback
        self.evaluate_callback()

    def forward(self, *args):
        """Forward pass through the model, make the predictions and store them.

        Inside this method you should store the predictions made by the model, compute
        your metrics and perform the evaluation.
        """
        raise NotImplementedError('You should implement your own forward to make the predictions')

    def batch_callback(self, batch):
        """A callback that is called after is batch.

        Arguments:
            batch (int): The number of the batch that was run before this callback.
        """

    def evaluate_callback(self):
        """A callback that is called at the end of the evaluation."""

Static methods

def get_base_params()

Get the base parameters of the evaluator.

Source code
@staticmethod
def get_base_params():
    """Get the base parameters of the evaluator."""
    return {}

Methods

def __init__(self, checkpoint, params=None, device=None)

Initialize the evaluator, get the dataset and the model to evaluate.

Arguments

checkpoint : str
The path to the checkpoint generated by the trainer, where we can find the model state.
params : dict
A dict to replace the base params of the evaluator.
device : str, optional
The device to use for the evaluation.
Source code
def __init__(self, checkpoint, params=None, device=None):
    """Initialize the evaluator, get the dataset and the model to evaluate.

    Arguments:
        checkpoint (str): The path to the checkpoint generated by the trainer, where we can find the model state.
        params (dict): A dict to replace the base params of the evaluator.
        device (str, optional): The device to use for the evaluation.
    """
    self.params = merge_dicts(self.get_base_params(), params, verbose=True)
    self.device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
    self.print('Using device "{}"'.format(self.device))

    self.checkpoint_path = checkpoint
    self.print('Loading checkpoint ...')
    self.checkpoint = torch.load(checkpoint, map_location=self.device)

    self.print('Loading dataset ...')
    self.dataset = self.get_dataset()
    self.print('Loading dataloader ...')
    self.dataloader = self.get_dataloader()
    self.print('Loading model ...')
    self.model = self.get_model()
    self.logger = self.get_logger()

    # Keep a dict with the key-value pairs to log after each batch
    self.current_log = {}

    # Avoid memory leaks by removing the loaded checkpoint but keep the path to it
    self.checkpoint = None
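
Assuming a subclass like the hypothetical MyEvaluator sketched above, construction only needs the checkpoint path; params and device are optional (the path below is illustrative):

evaluator = MyEvaluator(
    checkpoint='./checkpoints/last.pth',  # hypothetical path written by the trainer
    params={'batch_size': 16},            # merged over get_base_params() by merge_dicts
    device='cuda:0'                       # defaults to 'cuda:0' if available, else 'cpu'
)
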
def batch_callback(self, batch)

A callback that is called after each batch.

Arguments

batch : int
The number of the batch that was run before this callback.
Source code
def batch_callback(self, batch):
    """A callback that is called after is batch.

    Arguments:
        batch (int): The number of the batch that was run before this callback.
    """
def eval_mode(self)

Put the model in evaluation mode.

You can override this method to pass arguments to the eval() method of the model or do anything else you want.

Source code
def eval_mode(self):
    """Put the model in evaluation mode.

    You can override this method to pass arguments to the eval() method of the model
    or do anything else you want.
    """
    self.model.eval()
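
A sketch of an override that does some extra work before the evaluation loop; releasing cached GPU memory here is only an illustration:

def eval_mode(self):
    self.model.eval()
    # Free cached GPU memory before iterating over the dataset
    if self.device.startswith('cuda'):
        torch.cuda.empty_cache()
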
def evaluate(self)

Run the model over the entire dataset and compute the evaluation metric.

Source code
def evaluate(self):
    """Run the model over the entire dataset and compute the evaluation metric."""
    self.print('Starting evaluation ...')
    self.model.to(self.device)
    self.eval_mode()

    # The number of batches that the dataset has
    n_batches = len(self.dataloader)

    # The start time of the evaluation and the last batch's end time
    start_time = time.time()
    last_endtime = start_time

    with torch.no_grad():
        for batch, data in enumerate(self.dataloader):
            self.forward(*data)

            total_time = time.time() - start_time
            batch_time = time.time() - last_endtime
            last_endtime = time.time()

            self.logger.log(merge_dicts({
                self.__class__.__name__: None,
                'Batch': '{}/{}'.format(batch + 1, n_batches),
                'Time': '{:.3f} s'.format(batch_time),
                'Total': '{:.1f} s'.format(total_time)
            }, self.current_log))
            self.current_log = {}

            # A callback after each batch
            self.batch_callback(batch)

    # A final callback
    self.evaluate_callback()
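
A typical run, using the hypothetical subclass and checkpoint path from the examples above, is simply:

evaluator = MyEvaluator('./checkpoints/last.pth')
evaluator.evaluate()  # logs Batch, Time, Total and the contents of current_log for every batch
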
def evaluate_callback(self)

A callback that is called at the end of the evaluation.

Source code
def evaluate_callback(self):
    """A callback that is called at the end of the evaluation."""
def forward(self, *args)

Forward pass through the model, make the predictions and store them.

Inside this method you should store the predictions made by the model, compute your metrics and perform the evaluation.

Source code
def forward(self, *args):
    """Forward pass through the model, make the predictions and store them.

    Inside this method you should store the predictions made by the model, compute
    your metrics and perform the evaluation.
    """
    raise NotImplementedError('You should implement your own forward to make the predictions')
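
A sketch of a forward implementation for a classification-style evaluation in a subclass; the unpacking of the batch must match what your dataloader yields, and the accuracy counters are hypothetical attributes the subclass would initialize itself:

def forward(self, images, labels):
    images = images.to(self.device)
    labels = labels.to(self.device)

    predictions = self.model(images).argmax(dim=1)

    # Accumulate counts for a final metric reported in evaluate_callback() (hypothetical attributes)
    self.correct += (predictions == labels).sum().item()
    self.total += labels.shape[0]

    # Keys stored in current_log are appended to this batch's log line by evaluate()
    self.current_log['Accuracy'] = '{:.3f}'.format(self.correct / self.total)
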
def get_dataloader(self)

Get the dataloader to use for the evaluation.

Returns

torch.utils.data.DataLoader: The dataloader to use for the evaluation.

Source code
def get_dataloader(self):
    """Get the dataloader to use for the evaluation.

    Returns:
        torch.utils.data.DataLoader: The dataloader to use for the evaluation.
    """
    raise NotImplementedError('You must provide your own dataloader')
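
An override in a subclass typically wraps self.dataset (already built by the constructor) in a standard torch DataLoader; the parameter names read from self.params are illustrative:

def get_dataloader(self):
    return torch.utils.data.DataLoader(
        self.dataset,
        batch_size=self.params.get('batch_size', 8),
        num_workers=self.params.get('num_workers', 2),
        shuffle=False  # order does not matter for evaluation
    )
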
def get_dataset(self)

Get the dataset for the evaluation.

Returns

torch.utils.data.Dataset: The dataset to use for the evaluation.

Source code
def get_dataset(self):
    """Get the dataset for the evaluation.

    Returns:
        torch.utils.data.Dataset: The dataset to use for the evaluation.
    """
    raise NotImplementedError('You must provide your own dataset')
def get_logger(self)

Get the logger to use during the evaluation to show information about the process.

This base implementation uses the PrintLogger that will print the log to the console.

Returns

pymatch.loggers.Logger: A Logger to use during the evaluation.

Source code
def get_logger(self):
    """Get the logger to use during the training to show the information about the process.

    This base implementation uses the PrintLogger that will print the log to the console.

    Returns:
        pymatch.loggers.Logger: A Logger to use during the evaluation.
    """
    return PrintLogger(description=None)
def get_model(self)

Get the model to use to make the predictions.

Returns

torch.nn.Module: The model to use to make the predictions over the data.

Source code
def get_model(self):
    """Get the model to use to make the predictions.

    Returns:
        torch.nn.Module: The model to use to make the predictions over the data.
    """
    raise NotImplementedError('You must provide your own model')
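
An override usually builds the network and restores its weights from self.checkpoint, which is still available at this point of the constructor (it is set to None afterwards). The network class and the 'model' key are assumptions about your trainer's checkpoint layout:

def get_model(self):
    model = MyModel()  # hypothetical network
    model.load_state_dict(self.checkpoint['model'])  # assumed checkpoint key
    return model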
