torchsight.cli.train.retinanet
module
CLI to train the RetinaNet model.
Source code
"""CLI to train the RetinaNet model."""
import click
@click.command()
@click.option('--config', type=click.Path(exists=True), help='An optional configuration file with the hyperparameters. '
'It will block all the other options, only --device and --checkpoint will work.')
@click.option('-d', '--dataset', default='coco', show_default=True, type=click.Choice(['coco', 'flickr32', 'logo32plus']))
@click.option('-dr', '--dataset-root', type=click.Path(exists=True))
@click.option('-b', '--batch-size', default=8, show_default=True)
@click.option('--resnet', default=50, show_default=True, help='The resnet backbone that the model must use.')
@click.option('--logs-dir', default='.', type=click.Path(), show_default=True,
help='Where to store the checkpoints and descriptions.')
@click.option('-c', '--checkpoint', type=click.Path(exists=True), help='A checkpoint to resume the training from it.')
@click.option('--classes', default=None,
help='Indicate which classes (identified by its string label) must be used for the training. '
'If no class is provided the trainer will use all the classes. Example: --classes "bear sheep airplane"')
@click.option('--device', default=None, help='The device that the model must use.')
def retinanet(config, dataset, dataset_root, batch_size, resnet, logs_dir, checkpoint, classes, device):
"""Train a RetinaNet instance with the indicated dataset that contains its data in the
DATASET_ROOT directory."""
from torchsight.trainers import RetinaNetTrainer
if config is None:
classes = classes.split() if classes is not None else None
hyperparameters = {
'datasets': {
'use': dataset,
'coco': {
'root': dataset_root,
'class_names': classes,
},
'logo32plus': {
'root': dataset_root,
'classes': classes,
},
'flickr32': {
'root': dataset_root,
'classes': classes,
}
},
'dataloaders': {
'batch_size': batch_size
},
'model': {
'resnet': resnet,
},
'logger': {
'dir': logs_dir
},
'checkpoint': {
'dir': logs_dir
}
}
else:
import json
with open(config, 'r') as file:
hyperparameters = json.loads(file.read())
RetinaNetTrainer(
hyperparameters=hyperparameters,
checkpoint=checkpoint,
device=device
).train()