Module torchsight.evaluators.evaluator
Base Evaluator class.
Source code
"""Base Evaluator class."""
import time
import torch
from torchsight.loggers import PrintLogger
from torchsight.utils import PrintMixin, merge_dicts
class Evaluator(PrintMixin):
"""An evaluator base class.
The evaluator is an interface to evaluate different models. You can override the methods to get the dataset,
the dataloader and the model, and override the method to compute the forward pass where you can put any
metric that you want to track.
This class is intended to avoid boilerplate code and only focus on compute the metric or evaluation that you
want to have.
"""
def __init__(self, checkpoint, params=None, device=None):
"""Initialize the evaluator, get the dataset and the model to evaluate.
Arguments:
checkpoint (str): The checkpoint generated by the trainer where we can find the model state.
params (dict): A dict to replace the base params of the evaluator.
device (str, optional): The device to use for the evaluation.
"""
self.params = merge_dicts(self.get_base_params(), params, verbose=True)
self.device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
self.print('Using device "{}"'.format(self.device))
self.checkpoint_path = checkpoint
self.print('Loading checkpoint ...')
self.checkpoint = torch.load(checkpoint, map_location=self.device)
self.print('Loading dataset ...')
self.dataset = self.get_dataset()
self.print('Loading dataloader ...')
self.dataloader = self.get_dataloader()
self.print('Loading model ...')
self.model = self.get_model()
self.logger = self.get_logger()
# Keep a dict with the key-value pairs to log after each batch
self.current_log = {}
# Avoid memory leaks by removing the loaded checkpoint but keep the path to it
self.checkpoint = None
@staticmethod
def get_base_params():
"""Get the base parameters of the evaluator."""
return {}
###############################
### GETTERS ###
###############################
def get_dataset(self):
"""Get the dataset for the evaluation.
Returns:
torch.utils.data.Dataset: The dataset to use for the evaluation.
"""
raise NotImplementedError('You must provide your own dataset')
def get_dataloader(self):
"""Get the dataloader to use for the evaluation.
Returns:
torch.utils.data.Dataloader: The dataloader to use for the evaluation.
"""
raise NotImplementedError('You must provide your own dataloader')
def get_model(self):
"""Get the model to use to make the predictions.
Returns:
torch.nn.Module: The model to use to make the predictions over the data.
"""
raise NotImplementedError('You must provide your own model')
def get_logger(self):
"""Get the logger to use during the training to show the information about the process.
This base implementation uses the PrintLogger that will print the log to the console.
Returns:
pymatch.loggers.Logger: A Logger to use during the training.
"""
return PrintLogger(description=None)
###############################
### METHODS ###
###############################
def eval_mode(self):
"""Put the model in evaluation mode.
You can override this method to pass arguments to the eval() method of the model
or do anything what you want.
"""
self.model.eval()
def evaluate(self):
"""Run the model over the entire dataset and compute the evaluation metric."""
self.print('Starting evaluation ...')
self.model.to(self.device)
self.eval_mode()
# The number of batches that the dataset has
n_batches = len(self.dataloader)
# The start time of the evaluation and the last batch's end time
start_time = time.time()
last_endtime = start_time
with torch.no_grad():
for batch, data in enumerate(self.dataloader):
self.forward(*data)
total_time = time.time() - start_time
batch_time = time.time() - last_endtime
last_endtime = time.time()
self.logger.log(merge_dicts({
self.__class__.__name__: None,
'Batch': '{}/{}'.format(batch + 1, n_batches),
'Time': '{:.3f} s'.format(batch_time),
'Total': '{:.1f} s'.format(total_time)
}, self.current_log))
self.current_log = {}
# A callback after each batch
self.batch_callback(batch)
# A final callback
self.evaluate_callback()
def forward(self, *args):
"""Forward pass through the model, make the predictions and store them.
Inside this method you should store the predictions made by the model, perform
some metrics and do the evaluation.
"""
raise NotImplementedError('Yous should implement your own forward to make the predictions')
def batch_callback(self, batch):
"""A callback that is called after is batch.
Arguments:
batch (int): The number of the batch that was ran before this callback.
"""
def evaluate_callback(self):
"""A callback that is called at the end of the evaluation."""
Classes
class Evaluator (ancestors: PrintMixin)
-
An evaluator base class.
The evaluator is an interface to evaluate different models. You can override the methods that get the dataset, the dataloader and the model, and override the method that computes the forward pass, where you can track any metric that you want.
This class is intended to avoid boilerplate code so you can focus only on computing the metric or evaluation that you want to have.
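A minimal sketch of a concrete evaluator built only on the class shown above. The toy dataset, the linear model and the 'model' key used to read the state dict from the checkpoint are illustrative assumptions, not part of torchsight:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from torchsight.evaluators.evaluator import Evaluator


class AccuracyEvaluator(Evaluator):
    """Toy evaluator that tracks the accuracy of a small classifier."""

    @staticmethod
    def get_base_params():
        return {'batch_size': 8}

    def get_dataset(self):
        # Dummy data; replace with the real evaluation dataset
        return TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

    def get_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.params['batch_size'], shuffle=False)

    def get_model(self):
        model = nn.Linear(10, 2)
        # Assumes the trainer saved the state dict under the 'model' key of the checkpoint
        model.load_state_dict(self.checkpoint['model'])
        return model

    def forward(self, inputs, targets):
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        predicted = self.model(inputs).argmax(dim=1)
        self.correct += (predicted == targets).sum().item()
        self.total += targets.shape[0]
        self.current_log['Accuracy'] = '{:.3f}'.format(self.correct / self.total)

    def evaluate(self):
        # Running totals for the metric, reset on each run
        self.correct, self.total = 0, 0
        super().evaluate()

Calling AccuracyEvaluator with a checkpoint path and then evaluate() would print the running accuracy for each batch through the PrintLogger, provided the checkpoint actually contains a compatible state dict under the assumed 'model' key.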
Static methods
def get_base_params()
-
Get the base parameters of the evaluator.
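A hedged sketch of how a subclass could declare its defaults; the parameter names are illustrative:

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    @staticmethod
    def get_base_params():
        # Defaults that merge_dicts() combines with (and lets the user override
        # through) the `params` argument of __init__
        return {
            'batch_size': 8,
            'workers': 4,
            'iou_threshold': 0.5,
        }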
Methods
def __init__(self, checkpoint, params=None, device=None)
-
Initialize the evaluator, get the dataset and the model to evaluate.
Arguments
checkpoint
:str
- The checkpoint generated by the trainer where we can find the model state.
params
:dict
- A dict to replace the base params of the evaluator.
device
:str, optional
- The device to use for the evaluation.
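Instantiation could look like the following, assuming a concrete subclass such as the MyEvaluator sketched above; the checkpoint path and parameter names are illustrative:

# Illustrative usage; the subclass, checkpoint path and params are assumptions
evaluator = MyEvaluator('checkpoints/last.pth',
                        params={'batch_size': 16},  # overrides the base params
                        device='cuda:1')            # omit to auto-select cuda:0 or cpu
evaluator.evaluate()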
def batch_callback(self, batch)
-
A callback that is called after each batch.
Arguments
batch
:int
- The number of the batch that was run before this callback.
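One hedged use of this hook, assuming forward() accumulates a hypothetical self.predictions list:

import torch

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def batch_callback(self, batch):
        # Hypothetical: persist partial results every 100 batches so long
        # evaluations can be inspected while they run; self.predictions is
        # assumed to be filled by forward()
        if (batch + 1) % 100 == 0:
            torch.save(self.predictions, 'predictions_partial.pth')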
def eval_mode(self)
-
Put the model in evaluation mode.
You can override this method to pass arguments to the eval() method of the model or do anything you want.
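A sketch of such an override, assuming a model whose eval() accepts extra arguments (plain nn.Module.eval() takes none):

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def eval_mode(self):
        # Hypothetical model whose eval() accepts post-processing thresholds
        # taken from the evaluator's params
        self.model.eval(threshold=self.params.get('threshold', 0.5))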
def evaluate(self)
-
Run the model over the entire dataset and compute the evaluation metric.
def evaluate_callback(self)
-
A callback that is called at the end of the evaluation.
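A hedged sketch of a final callback, assuming forward() accumulated a hypothetical self.ious list:

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def evaluate_callback(self):
        # Hypothetical: report an aggregate metric accumulated by forward()
        self.print('Mean IoU: {:.4f}'.format(sum(self.ious) / len(self.ious)))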
def forward(self, *args)
-
Forward pass through the model, make the predictions and store them.
Inside this method you should store the predictions made by the model, compute any metrics and perform the evaluation.
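A sketch of an override; the signature mirrors whatever tuple the dataloader yields per batch, and the per-image detections returned by the model as well as self.predictions are assumptions for illustration:

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def forward(self, images, annotations):
        images = images.to(self.device)
        detections = self.model(images)  # hypothetical: a list of per-image detections
        self.predictions.append(detections)  # assumed to be initialized before evaluate()
        self.current_log['Images'] = images.shape[0]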
def get_dataloader(self)
-
Get the dataloader to use for the evaluation.
Returns
torch.utils.data.DataLoader: The dataloader to use for the evaluation.
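A minimal sketch of an override; the 'batch_size' and 'workers' parameter names are assumptions:

from torch.utils.data import DataLoader

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def get_dataloader(self):
        # self.dataset is already set because get_dataset() runs first in __init__
        return DataLoader(self.dataset,
                          batch_size=self.params.get('batch_size', 8),
                          shuffle=False,
                          num_workers=self.params.get('workers', 2))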
def get_dataset(self)
-
Get the dataset for the evaluation.
Returns
torch.utils.data.Dataset: The dataset to use for the evaluation.
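A minimal sketch of an override; random tensors stand in for a real dataset:

import torch
from torch.utils.data import TensorDataset

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def get_dataset(self):
        # Any torch.utils.data.Dataset works here
        return TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))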
def get_logger(self)
-
Get the logger to use during the evaluation to show information about the process.
This base implementation uses the PrintLogger that will print the log to the console.
Returns
torchsight.loggers.Logger: A Logger to use during the evaluation.
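A sketch of an override that keeps the PrintLogger but gives it a description:

from torchsight.evaluators.evaluator import Evaluator
from torchsight.loggers import PrintLogger


class MyEvaluator(Evaluator):
    def get_logger(self):
        # self.checkpoint_path is already set when __init__ reaches get_logger()
        return PrintLogger(description='Evaluating {}'.format(self.checkpoint_path))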
def get_model(self)
-
Get the model to use to make the predictions.
Returns
torch.nn.Module: The model to use to make the predictions over the data.
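A hedged sketch of an override; the architecture and the 'model' checkpoint key are assumptions about how the trainer saved its state:

from torch import nn

from torchsight.evaluators.evaluator import Evaluator


class MyEvaluator(Evaluator):
    def get_model(self):
        model = nn.Linear(10, 2)  # illustrative architecture
        # self.checkpoint still holds the loaded checkpoint at this point in __init__;
        # the 'model' key is an assumption about how the trainer saved the state dict
        model.load_state_dict(self.checkpoint['model'])
        return model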