torchsight.cli.visualize.dldenet module

Visualize the predictions of the DLDENet model.
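
A typical invocation could look like the following (a sketch: the "torchsight visualize dldenet" entry point and the paths are assumptions, adjust them to your installation):

    torchsight visualize dldenet -c ./dldenet_checkpoint.pth.tar -d coco -dr ./datasets/coco --threshold 0.6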

Source code
"""Visualize the predictions of the dldenet."""
import random

import click


@click.command()
@click.option('-c', '--checkpoint', type=click.Path(exists=True), required=True,
              help='The path to the checkpoint generated by a trainer.')
@click.option('-d', '--dataset', default='coco', show_default=True, type=click.Choice(['coco', 'logo32plus', 'flickr32']))
@click.option('-dr', '--dataset-root', type=click.Path(exists=True), required=True,
              help='The path to the directory that contains the data of the dataset.')
@click.option('--training-set', is_flag=True, help='Show the images of the training set instead of validation.')
@click.option('--no-shuffle', is_flag=True, help='Show the images in order instead of shuffling them.')
@click.option('--device', help='The device to use to run the model. Defaults to cuda:0 if CUDA is available.')
@click.option('--threshold', default=0.5, show_default=True, help='The confidence threshold for the predictions.')
@click.option('--iou-threshold', default=0.3, show_default=True, help='The threshold for Non-Maximum Suppression.')
@click.option('--only-logos', is_flag=True, help='Show only images with logos in the Flickr32 dataset.')
@click.option('--tracked-means', is_flag=True, help='Load the checkpoint as a DLDENetWithTrackedMeans model.')
def dldenet(checkpoint, dataset_root, dataset, training_set, no_shuffle, device, threshold, iou_threshold, tracked_means,
            only_logos):
    """Visualize the predictions of the DLDENet model loaded from CHECKPOINT with the indicated
    dataset that contains its data in DATASET-ROOT."""
    import torch

    from torchsight.datasets import CocoDataset, Logo32plusDataset, Flickr32Dataset
    from torchsight.models import DLDENet, DLDENetWithTrackedMeans
    from torchsight.transforms.augmentation import AugmentDetection
    from torchsight.utils import visualize_boxes

    device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
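    # Load the full training checkpoint (weights and hyperparameters) directly onto the device.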
    checkpoint = torch.load(checkpoint, map_location=device)

    if tracked_means:
        model = DLDENetWithTrackedMeans.from_checkpoint(checkpoint, device=device)
    else:
        model = DLDENet.from_checkpoint(checkpoint, device=device)

    hyperparameters = checkpoint['hyperparameters']

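    # Build two transform pipelines from the hyperparameters stored in the checkpoint: a
    # normalized one that feeds the model and an unnormalized one that keeps the images viewable.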
    transform = AugmentDetection(params=hyperparameters['transform'], evaluation=True)
    transform_visible = AugmentDetection(params=hyperparameters['transform'], evaluation=True, normalize=False)
    params = {'root': dataset_root}

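    # Each dataset branch instantiates the same split twice (once per transform) and exposes a
    # mapping from numeric labels to human readable class names for the visualization.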
    if dataset == 'coco':
        try:
            coco_params = hyperparameters['datasets']['coco']
        except KeyError:
            coco_params = hyperparameters['datasets']
        params['classes_names'] = coco_params['class_names']
        params['dataset'] = 'train2017' if training_set else 'val2017'
        dataset = CocoDataset(**params, transform=transform)
        dataset_human = CocoDataset(**params, transform=transform_visible)
        label_to_name = dataset.classes['names']
    elif dataset == 'logo32plus':
        params['dataset'] = 'training' if training_set else 'validation'
        params['classes'] = hyperparameters['datasets']['logo32plus']['classes']
        dataset = Logo32plusDataset(**params, transform=transform)
        dataset_human = Logo32plusDataset(**params, transform=transform_visible)
        label_to_name = dataset.label_to_class
    elif dataset == 'flickr32':
        params['dataset'] = 'trainval' if training_set else 'test'
        try:
            params['brands'] = hyperparameters['datasets']['flickr32']['classes']
        except KeyError:
            print('WARN: The model was not trained with the Flickr32 dataset.')
            params['brands'] = None
        dataset = Flickr32Dataset(**params, transform=transform, only_boxes=only_logos)
        dataset_human = Flickr32Dataset(**params, transform=transform_visible, only_boxes=only_logos)
        label_to_name = dataset.label_to_class
    else:
        raise ValueError('There is no dataset named "{}"'.format(dataset))

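    # Visit the images in a random order by default so that every run shows different samples.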
    indexes = list(range(len(dataset)))

    if not no_shuffle:
        random.shuffle(indexes)

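    # The torchsight models override eval() so it also receives the confidence threshold and
    # the IoU threshold used by Non-Maximum Suppression to filter the final detections.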
    model.eval(threshold, iou_threshold)
    model.to(device)

    for i in indexes:
        image, *_ = dataset[i]
        image_visible, *_ = dataset_human[i]
        image = image.unsqueeze(dim=0).type(torch.float).to(device)

        # No gradients are needed for inference, so skip building the autograd graph.
        with torch.no_grad():
            detections = model(image)[0]

        visualize_boxes(image_visible, detections, label_to_name)
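
The loop shows one image at a time with its predicted boxes drawn by visualize_boxes. If you would rather inspect the raw predictions, a minimal sketch could look like the following (the per-row layout of detections, four box coordinates followed by a label and a score, is an assumption inferred from how visualize_boxes consumes it; verify it against your torchsight version):

    # Hypothetical inspection of the detections of a single image.
    for detection in detections:
        *box, label, score = detection.tolist()  # assumed layout: x1, y1, x2, y2, label, score
        print('Label {} with confidence {:.2f} in box {}'.format(int(label), score, box))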