torchsight.cli.evaluate.dldenet module

Evaluate the DLDENet models.

Source code
"""Evaluate the DLDENet models."""
import click


# Click entry point that evaluates a DLDENet checkpoint on COCO or Flickr32.
# NOTE: click renders the function docstring as the `--help` text, so
# maintainer-facing notes are kept in comments rather than in the docstring.
@click.command()
@click.option('-c', '--checkpoint', required=True, type=click.Path(exists=True))
@click.option('-d', '--dataset', default='coco', type=click.Choice(['coco', 'flickr32']),
              help='The dataset to use to validate.', show_default=True)
@click.option('-dr', '--dataset-root', help='The root directory where is the data of the dataset.')
@click.option('--results-dir', default='./evaluations/dldenet/{dataset}', show_default=True,
              help='The directory where to store the results.')
@click.option('--coco-dataset', default='val2017', help='The coco dataset to use.', show_default=True)
@click.option('--classes', help='The name of the classes to detect. Default: Get from checkpoint.')
@click.option('--batch-size', default=8, show_default=True)
@click.option('--num-workers', default=8, show_default=True)
# '--width-tracked-means' was a typo of '--with-tracked-means'. The misspelled flag is kept as an
# alias and the Python parameter name is pinned explicitly so existing invocations keep working.
@click.option('--with-tracked-means', '--width-tracked-means', 'width_tracked_means', is_flag=True,
              help='Use the version with tracked means.')
@click.option('--device', help='The device where to run the evaluation. Default to cuda:0 if cuda is available.')
@click.option('--threshold', default=0.5, help='The detection threshold.', show_default=True)
@click.option('--iou-threshold', default=0.5, help='The IoU threshold for the NMS.', show_default=True)
def dldenet(checkpoint, dataset, dataset_root, results_dir, coco_dataset, classes,
            batch_size, num_workers, width_tracked_means, device, threshold, iou_threshold):
    """Evaluate the DLDENet with the indicated dataset that contains its data in DATASET-ROOT with the
    model saved at CHECKPOINT."""
    # Heavy imports are done lazily so that importing this module stays cheap.
    import torch

    from torchsight.evaluators import (DLDENetCOCOEvaluator,
                                       DLDENetFlickr32Evaluator)

    # Fall back to the first CUDA device when available, otherwise to the CPU.
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Expand the '{dataset}' placeholder for every dataset. Previously only the COCO branch did it,
    # so the default results dir produced a literal '{dataset}' directory for Flickr32.
    results_dir = results_dir.format(dataset=dataset)

    if dataset == 'coco':
        # Without an explicit --classes value the class names are taken from the checkpoint.
        class_names_from_checkpoint = classes is None
        # --classes is a single whitespace-separated string of class names.
        class_names = () if classes is None else classes.split()
        results_file = '{}.json'.format(coco_dataset)

        DLDENetCOCOEvaluator(
            checkpoint,
            params={'results': {'dir': results_dir, 'file': results_file},
                    'dataset': {'root': dataset_root,
                                'validation': coco_dataset,
                                'class_names': class_names,
                                'class_names_from_checkpoint': class_names_from_checkpoint},
                    'dataloader': {'batch_size': batch_size, 'num_workers': num_workers},
                    'model': {'with_tracked_means': width_tracked_means,
                              'evaluation': {'threshold': threshold, 'iou_threshold': iou_threshold}}},
            device=device
        ).evaluate()
    elif dataset == 'flickr32':
        DLDENetFlickr32Evaluator(
            checkpoint,
            params={
                'root': dataset_root,
                'file': '{}/flickr32_predictions.csv'.format(results_dir),
                'dataloader': {
                    'num_workers': num_workers,
                    'batch_size': batch_size
                },
                'thresholds': {
                    'detection': threshold,
                    'iou': iou_threshold
                }
            },
            device=device
        ).evaluate()