torchsight.cli.visualize.dldenet
module
Visualize the predictions of the DLDENet model.
Source code
"""Visualize the predictions of the dldenet."""
import random
import click
@click.command()
@click.option('-c', '--checkpoint', type=click.Path(exists=True), required=True,
              help='The path to the checkpoint generated by a trainer.')
@click.option('-d', '--dataset', default='coco', show_default=True,
              type=click.Choice(['coco', 'logo32plus', 'flickr32']))
@click.option('-dr', '--dataset-root', type=click.Path(exists=True), required=True,
              help='The path to the directory where is the data of the dataset.')
@click.option('--training-set', is_flag=True, help='Show the images of the training set instead of validation.')
@click.option('--no-shuffle', is_flag=True, help='Show the images in dataset order instead of a random order.')
@click.option('--device', help='The device to use to run the model. Default to cuda:0 if cuda is available.')
@click.option('--threshold', default=0.5, show_default=True, help='The confidence threshold for the predictions.')
@click.option('--iou-threshold', default=0.3, show_default=True, help='The threshold for Non-Maximum Suppression.')
@click.option('--only-logos', is_flag=True, help='Show only images with logos in the Flickr32 dataset.')
@click.option('--tracked-means', is_flag=True, help='Load the checkpoint as a DLDENetWithTrackedMeans model.')
def dldenet(checkpoint, dataset_root, dataset, training_set, no_shuffle, device, threshold, iou_threshold, tracked_means,
            only_logos):
    """Visualize the predictions of the DLDENet model loaded from CHECKPOINT with the indicated
    dataset that contains its data in DATASET-ROOT."""
    # Heavy dependencies are imported lazily so that `--help` and the rest of the CLI stay fast.
    import torch
    from torchsight.datasets import CocoDataset, Logo32plusDataset, Flickr32Dataset
    from torchsight.models import DLDENet, DLDENetWithTrackedMeans
    from torchsight.transforms.augmentation import AugmentDetection
    from torchsight.utils import visualize_boxes

    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    checkpoint = torch.load(checkpoint, map_location=device)
    model_class = DLDENetWithTrackedMeans if tracked_means else DLDENet
    model = model_class.from_checkpoint(checkpoint, device=device)

    hyperparameters = checkpoint['hyperparameters']
    # Two pipelines over the same data: `transform` feeds the model, while `transform_visible`
    # skips normalization so the image can be shown to a human.
    transform = AugmentDetection(params=hyperparameters['transform'], evaluation=True)
    transform_visible = AugmentDetection(params=hyperparameters['transform'], evaluation=True, normalize=False)

    params = {'root': dataset_root}
    if dataset == 'coco':
        try:
            coco_params = hyperparameters['datasets']['coco']
        except KeyError:
            # NOTE(review): presumably older checkpoints stored the COCO params directly
            # under 'datasets' — confirm against the trainer's checkpoint format.
            coco_params = hyperparameters['datasets']
        params['classes_names'] = coco_params['class_names']
        params['dataset'] = 'train2017' if training_set else 'val2017'
        dataset = CocoDataset(**params, transform=transform)
        dataset_human = CocoDataset(**params, transform=transform_visible)
        label_to_name = dataset.classes['names']
    elif dataset == 'logo32plus':
        params['dataset'] = 'training' if training_set else 'validation'
        params['classes'] = hyperparameters['datasets']['logo32plus']['classes']
        dataset = Logo32plusDataset(**params, transform=transform)
        dataset_human = Logo32plusDataset(**params, transform=transform_visible)
        label_to_name = dataset.label_to_class
    elif dataset == 'flickr32':
        params['dataset'] = 'trainval' if training_set else 'test'
        try:
            params['brands'] = hyperparameters['datasets']['flickr32']['classes']
        except KeyError:
            print('WARN: Model was not trained using flickr32 dataset.')
            # Fall back on the same keyword the success path uses; setting 'classes' here
            # passed an argument Flickr32Dataset is otherwise never given.
            # NOTE(review): presumably brands=None means "all brands" — confirm.
            params['brands'] = None
        # Both datasets must apply the same filtering, otherwise index i would refer to
        # different images in `dataset` and `dataset_human`.
        dataset = Flickr32Dataset(**params, transform=transform, only_boxes=only_logos)
        dataset_human = Flickr32Dataset(**params, transform=transform_visible, only_boxes=only_logos)
        label_to_name = dataset.label_to_class
    else:
        # Unreachable through click.Choice; kept as a guard for direct calls.
        raise ValueError('There is no dataset named "{}"'.format(dataset))

    indexes = list(range(len(dataset)))
    if not no_shuffle:
        random.shuffle(indexes)

    model.eval(threshold, iou_threshold)
    model.to(device)

    for i in indexes:
        image, *_ = dataset[i]
        image_visible, *_ = dataset_human[i]
        # Add the batch dimension and move the input to the model's device.
        image = image.unsqueeze(dim=0).type(torch.float).to(device)
        # Inference only: do not build an autograd graph for every visualized image.
        with torch.no_grad():
            detections = model(image)[0]
        visualize_boxes(image_visible, detections, label_to_name)