torchsight.cli.stats.dldenet module

Some stats about the DLDENet.

Source code
"""Some stats about the DLDENet."""
import click


@click.group()
def dldenet():
    # Click command group that the DLDENet stats subcommands (e.g. `weights`, `losses`) attach to
    # via `@dldenet.command()`. Click uses the docstring below as this group's --help text, so it
    # is kept user-facing.
    """Show stats about the DLDENet."""


@dldenet.command()
@click.option('-c', '--checkpoint', required=True, type=click.Path(exists=True), help='A checkpoint generated by a trainer.')
@click.option('--device', help='The device to use to load the model and make the computations.')
def weights(checkpoint, device):
    """Visualize the norm of the classification weights of each class and generate a
    similarity matrix between the weights of each class.
    """
    import torch
    from matplotlib import pyplot as plt
    from torchsight.models import DLDENet

    # Default to the first GPU when one is available, otherwise to the CPU.
    device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(checkpoint, map_location=device)
    hp = checkpoint['hyperparameters']
    dataset = hp['datasets']['use']
    weights = DLDENet.from_checkpoint(checkpoint).classification.weights

    # Print the label -> class name mapping so the bar/matrix indices can be read as class names.
    if dataset == 'flickr32':
        from torchsight.datasets import Flickr32Dataset
        label_to_class = Flickr32Dataset(root=hp['datasets']['flickr32']['root']).label_to_class
        for label, name in label_to_class.items():
            print('{} - {}'.format(str(label).rjust(2), name))

    with torch.no_grad():
        # The weights may live on a CUDA device (the default when CUDA is available). `.numpy()`
        # only works on CPU tensors, so bring them to the CPU before plotting.
        weights = weights.detach().cpu()

        # weights has shape (embedding size, num classes): one classification vector per column.
        norms = weights.norm(dim=0)

        plt.subplot(1, 2, 1)
        plt.bar(x=list(range(norms.shape[0])), height=norms.numpy())
        plt.title('The norm of the classification vector for each class in the DLDENet')

        # Dot products between every pair of class vectors, shape (classes, classes), divided by the
        # outer product of the norms to get the cosine similarity in a single broadcast operation.
        similarity = torch.matmul(weights.permute(1, 0), weights)
        similarity = similarity / (norms.unsqueeze(1) * norms.unsqueeze(0))

        plt.subplot(1, 2, 2)
        plt.imshow(similarity.numpy())
        plt.colorbar()
        plt.title('Similarity matrix between the classification vectors in the DLDENet')
        plt.show()


def _mean_losses_per_epoch(logs, epoch_key, keys):
    """Average the requested loss values of some log entries, grouped by epoch.

    Arguments:
        logs (iterable): The parsed log entries; each one is indexed by key
            (presumably the dicts yielded by `PrintLogger.read` — verify).
        epoch_key (str): The key that holds the epoch number. Its value must be convertible with `int()`.
        keys (sequence of str): The keys of the loss values to average.

    Returns:
        list of int: The epochs present in the logs, sorted in ascending order.
        dict: Maps each key to the list of per-epoch mean values, aligned with the returned epochs.
    """
    # epoch -> loss key -> every value logged for that loss during the epoch
    grouped = {}
    for log in logs:
        # Normalize with int() so the grouping key and the sort key are the same value.
        epoch = int(log[epoch_key])
        if epoch not in grouped:
            grouped[epoch] = {key: [] for key in keys}
        for key in keys:
            grouped[epoch][key].append(float(log[key]))

    epochs = sorted(grouped)
    means = {key: [sum(grouped[epoch][key]) / len(grouped[epoch][key]) for epoch in epochs]
             for key in keys}

    return epochs, means


@dldenet.command()
@click.option('-l', '--log-file', required=True, type=click.Path(exists=True), help='The logs generated by a PrintLogger.')
@click.option('--loss', default='Loss', show_default=True,
              help='The key in the log to search for the total loss.')
@click.option('--classification', default='Class.', show_default=True,
              help='The key in the log to search for the classification loss')
@click.option('--regression', default='Regr.', show_default=True,
              help='The key in the log to search for the regression loss.')
@click.option('--similarity', default='Simil.', show_default=True,
              help='The key in the log to search for the similarity loss.')
@click.option('--epoch-key', default='Epoch', show_default=True,
              help='The key in the log to search for the epoch number.')
def losses(log_file, loss, classification, regression, similarity, epoch_key):
    """Visualize the losses in time.

    For each epoch, plot the mean total, classification, regression and similarity losses
    of the training and validation logs.
    """
    from matplotlib import pyplot as plt
    from torchsight.loggers import PrintLogger

    # One subplot per loss: (key in the log, title of the subplot).
    panels = [
        (loss, 'Total loss per epoch'),
        (classification, 'Classification loss per epoch'),
        (regression, 'Regression loss per epoch'),
        (similarity, 'Similarity loss per epoch'),
    ]
    keys = [key for key, _ in panels]

    train_logs = PrintLogger.read(filepath=log_file, keep=lambda line: line.startswith('[Training]'))
    valid_logs = PrintLogger.read(filepath=log_file, keep=lambda line: line.startswith('[Validating]'))

    # Each dataset keeps its own sorted epoch axis. Validation may cover a different set of epochs
    # than training (e.g. a run stopped mid-epoch), and a shared axis would misalign the points or
    # fail on a length mismatch.
    train_epochs, train_means = _mean_losses_per_epoch(train_logs, epoch_key, keys)
    valid_epochs, valid_means = _mean_losses_per_epoch(valid_logs, epoch_key, keys)

    for position, (key, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(train_epochs, train_means[key], color='blue', label='training')
        plt.plot(valid_epochs, valid_means[key], color='green', label='validation')
        plt.title(title)
        plt.legend()

    plt.show()