torchsight.cli.train.dldenet

CLI to train the DLDENet.
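For a quick smoke test, the command can be driven in-process with click's testing utilities; a minimal sketch, assuming the module is importable as torchsight.cli.train.dldenet and using a hypothetical dataset path:

    from click.testing import CliRunner

    from torchsight.cli.train.dldenet import dldenet

    runner = CliRunner()
    # './datasets/coco' is a hypothetical path; invoking starts a real training run.
    result = runner.invoke(dldenet, [
        '--dataset', 'coco',
        '--dataset-root', './datasets/coco',
        '--batch-size', '4',
        '--epochs', '10',
    ])
    print(result.exit_code, result.output)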
Source code:
"""CLI to train the DLDENet."""
import click
from torchsight.trainers import DLDENetTrainer, DLDENetWithTrackedMeansTrainer


@click.command()
@click.option('--config', type=click.Path(exists=True), help='A JSON config file to load the configurations. '
              'If you provide this option, all the other options are ignored (only --device is still used).')
@click.option('--device', default=None, help='The device that the model must use.')
@click.option('-d', '--dataset', default='coco', show_default=True, type=click.Choice(['coco', 'logo32plus', 'flickr32']))
@click.option('-dr', '--dataset-root', type=click.Path(exists=True))
@click.option('-b', '--batch-size', default=8, show_default=True)
@click.option('--resnet', default=50, show_default=True, help='The resnet backbone that the model must use.')
@click.option('--fixed-bias', default=-0.5, show_default=True, help='The fixed bias for the classification module.')
@click.option('--logs-dir', default='./logs', type=click.Path(), show_default=True,
help='Where to store the checkpoints and descriptions.')
@click.option('--classes', default='',
              help='Indicate which classes (identified by their string labels) must be used for the training. '
              'If no class is provided the trainer will use all the classes. Example: --classes "bear sheep airplane"')
@click.option('--optimizer', default='adabound', type=click.Choice(['adabound', 'sgd']), show_default=True,
help='Set the optimizer that the trainer must use to train the model.')
@click.option('--adabound-lr', default=1e-3, show_default=True, help='The starting learning rate for AdaBound.')
@click.option('--adabound-final-lr', default=1, show_default=True,
              help='The final learning rate once AdaBound has transitioned to SGD.')
@click.option('--scheduler-factor', default=0.1, show_default=True,
help='The factor to scale the LR.')
@click.option('--scheduler-patience', default=5, show_default=True,
              help='How many epochs without relative improvement the scheduler must wait.')
@click.option('--scheduler-threshold', default=0.01, show_default=True,
help='The relative threshold that indicates an improvement for the scheduler.')
@click.option('--anchors-sizes', default='20 40 80 160 320', show_default=True,
              help='Space-separated sizes for the anchors of the model.')
@click.option('--not-normalize', is_flag=True,
help='Avoid normalization of the embeddings in the classification module. Only available without tracked means.')
@click.option('--tracked-means', is_flag=True, help='Use the version that tracks the means.')
@click.option('--soft-criterion', is_flag=True, help='Use soft assignment in the Loss.')
@click.option('--means-update', default='batch', type=click.Choice(['batch', 'manual']), show_default=True,
help='Update type for the means in the tracked version. See DirectionalClassification module for more info.')
@click.option('--means-lr', default=0.1, show_default=True, help='The learning rate for the "batch" means update method.')
@click.option('--num-workers', default=8, show_default=True)
@click.option('--epochs', default=100, show_default=True)
def dldenet(config, device, dataset_root, dataset, batch_size, resnet, fixed_bias, logs_dir, classes,
            optimizer, adabound_lr, adabound_final_lr, scheduler_factor, scheduler_patience,
            scheduler_threshold, anchors_sizes, not_normalize, tracked_means, soft_criterion,
            means_update, means_lr, num_workers, epochs):
"""Train the DLDENet with weighted classification vectors using the indicated dataset that
contains is data in DATASET_ROOT directory.
"""
if config is not None:
import json
        with open(config, 'r') as file:
            hyperparameters = json.load(file)
else:
classes = classes.split()
        if dataset_root is None:
            raise click.UsageError('Option "--dataset-root" is required when no --config is given.')
hyperparameters = {
'model': {
'resnet': resnet,
'normalize': not not_normalize,
'means_update': means_update,
'means_lr': means_lr,
'fixed_bias': fixed_bias,
'anchors': {
'sizes': [int(size) for size in anchors_sizes.split()],
},
},
'criterion': {
'soft': soft_criterion
},
'datasets': {
'use': dataset,
'coco': {'root': dataset_root, 'class_names': classes},
'logo32plus': {'root': dataset_root, 'classes': classes if classes else None},
'flickr32': {'root': dataset_root, 'classes': classes if classes else None}
},
'dataloaders': {
'batch_size': batch_size,
'num_workers': num_workers,
},
'logger': {'dir': logs_dir},
'checkpoint': {'dir': logs_dir},
'scheduler': {
'factor': scheduler_factor,
'patience': scheduler_patience,
'threshold': scheduler_threshold,
},
'optimizer': {
'use': optimizer,
'adabound': {
'lr': adabound_lr,
'final_lr': adabound_final_lr,
},
},
}
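        # Note: a file passed through --config must contain this same
        # structure serialized as JSON; the dict built above doubles as the
        # schema reference.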
# Set the params for the trainers
params = {'hyperparameters': hyperparameters, 'device': device}
if tracked_means:
DLDENetWithTrackedMeansTrainer(**params).train(epochs)
else:
DLDENetTrainer(**params).train(epochs)


@click.command()
@click.option('-c', '--checkpoint', type=click.Path(exists=True), required=True)
@click.option('-dr', '--dataset-root', type=click.Path(exists=True), required=True)
@click.option('-b', '--batch-size', default=None, type=click.INT,
              help='Override the batch size stored in the checkpoint.')
@click.option('--logs-dir', default='./logs', show_default=True, type=click.Path(),
              help='Where to store the checkpoints and descriptions.')
@click.option('--device', help='The device that the model must use.')
@click.option('--epochs', default=100, show_default=True)
@click.option('--tracked-means', is_flag=True, help='Use the tracked means version.')
def dldenet_from_checkpoint(dataset_root, checkpoint, batch_size, logs_dir, device, epochs, tracked_means):
"""Get an instance of the trainer from the checkpoint CHECKPOINT and resume the exact same training
with the dataset that contains its data in DATASET_ROOT.
You can only change things that will not affect the coherence of the training.
"""
new_params = {
'datasets': {
'coco': {'root': dataset_root},
'logo32plus': {'root': dataset_root},
'flickr32': {'root': dataset_root}
}
}
if batch_size is not None:
new_params['dataloaders'] = {'batch_size': batch_size}
if logs_dir is not None:
new_params['logger'] = {'dir': logs_dir}
new_params['checkpoint'] = {'dir': logs_dir}
if tracked_means:
DLDENetWithTrackedMeansTrainer.from_checkpoint(checkpoint, new_params, device).train(epochs)
else:
DLDENetTrainer.from_checkpoint(checkpoint, new_params, device).train(epochs)
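Resuming follows the same pattern; a minimal sketch with click's test runner, where the checkpoint file and dataset root are hypothetical paths:

    from click.testing import CliRunner

    from torchsight.cli.train.dldenet import dldenet_from_checkpoint

    runner = CliRunner()
    result = runner.invoke(dldenet_from_checkpoint, [
        '--checkpoint', './logs/checkpoint.pth.tar',  # hypothetical checkpoint
        '--dataset-root', './datasets/flickr32',      # hypothetical dataset root
        '--epochs', '50',
    ])
    print(result.exit_code, result.output)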