torchsight.cli.stats.dldenet module
Some stats about the DLDENet.
Source code
"""Some stats about the DLDENet."""
import click
@click.group()
def dldenet():
"""Show stats about the DLDENet."""
@dldenet.command()
@click.option('-c', '--checkpoint', required=True, type=click.Path(exists=True), help='A checkpoint generated by a trainer.')
@click.option('--device', help='The device to use to load the model and make the computations.')
def weights(checkpoint, device):
"""Visualize the norm of the classification weights of each class and generate a
similarity matrix between the each class weights.
"""
    import torch
    from matplotlib import pyplot as plt

    from torchsight.models import DLDENet

    device = device if device is not None else 'cuda:0' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(checkpoint, map_location=device)
    hp = checkpoint['hyperparameters']
    dataset = hp['datasets']['use']
    weights = DLDENet.from_checkpoint(checkpoint).classification.weights
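
    # If the model was trained on Flickr32, print the label -> class name mapping
    # so the class indices shown in the plots can be interpreted.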
    if dataset == 'flickr32':
        from torchsight.datasets import Flickr32Dataset

        label_to_class = Flickr32Dataset(root=hp['datasets']['flickr32']['root']).label_to_class
        for label, name in label_to_class.items():
            print('{} - {}'.format(str(label).rjust(2), name))

    with torch.no_grad():
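        # Each column of the weight matrix is the classification vector of one class,
        # so taking the norm over dim=0 gives one value per class.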
        norms = weights.norm(dim=0)

        plt.subplot(1, 2, 1)
        plt.bar(x=[i for i in range(norms.shape[0])],
                height=norms.cpu().numpy())
        plt.title('The norm of the classification vector for each class in the DLDENet')
        # weights has shape (embedding size, num classes)
        similarity = torch.matmul(weights.permute(1, 0), weights)
        # similarity is a matrix with shape (classes, classes) but is not normalized
        for i in range(similarity.shape[0]):
            for j in range(similarity.shape[1]):
                similarity[i, j] /= norms[i] * norms[j]
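
        # After dividing by the norms, similarity[i, j] is the cosine similarity
        # between the classification vectors of classes i and j.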
        plt.subplot(1, 2, 2)
        plt.imshow(similarity.cpu().numpy())
        plt.colorbar()
        plt.title('Similarity matrix between the classification vectors in the DLDENet')

    plt.show()


@dldenet.command()
@click.option('-l', '--log-file', required=True, type=click.Path(exists=True), help='The logs generated by a PrintLogger.')
@click.option('--loss', default='Loss', show_default=True,
              help='The key in the log to search for the total loss.')
@click.option('--classification', default='Class.', show_default=True,
              help='The key in the log to search for the classification loss.')
@click.option('--regression', default='Regr.', show_default=True,
              help='The key in the log to search for the regression loss.')
@click.option('--similarity', default='Simil.', show_default=True,
              help='The key in the log to search for the similarity loss.')
@click.option('--epoch-key', default='Epoch', show_default=True,
              help='The key in the log to search for the epoch number.')
def losses(log_file, loss, classification, regression, similarity, epoch_key):
"""Visualize the losses in time."""
import torch
import numpy as np
from matplotlib import pyplot as plt
from torchsight.loggers import PrintLogger
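
    # Helper that accumulates the values logged during one epoch for one dataset
    # (training or validation) and exposes their means.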
    class Stats:
        def __init__(self, epoch, dataset):
            self.dataset = dataset
            self.epoch = epoch
            self.losses = []
            self.classifications = []
            self.regressions = []
            self.similarities = []

        def _mean(self, array):
            return float(torch.Tensor([float(value) for value in array]).mean())

        @property
        def loss(self):
            return self._mean(self.losses)

        @property
        def classification(self):
            return self._mean(self.classifications)

        @property
        def regression(self):
            return self._mean(self.regressions)

        @property
        def similarity(self):
            return self._mean(self.similarities)
    train_logs = PrintLogger.read(filepath=log_file, keep=lambda x: x[:10] == '[Training]')
    valid_logs = PrintLogger.read(filepath=log_file, keep=lambda x: x[:12] == '[Validating]')
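
    # Group the parsed log lines by (epoch, dataset) and accumulate their loss values.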
    stats = []
    for logs, dataset in [(train_logs, 'training'), (valid_logs, 'validation')]:
        for log in logs:
            epoch = log[epoch_key]
            stat = next((stat for stat in stats if stat.epoch == epoch and stat.dataset == dataset), None)
            if stat is None:
                stat = Stats(epoch, dataset)
                stats.append(stat)
            stat.losses.append(log[loss])
            stat.classifications.append(log[classification])
            stat.regressions.append(log[regression])
            stat.similarities.append(log[similarity])
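
    # Sort the epochs numerically and split the accumulated stats per dataset.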
    epochs = list({int(stat.epoch) for stat in stats})
    epochs.sort()

    train_stats = [stat for stat in stats if stat.dataset == 'training']
    valid_stats = [stat for stat in stats if stat.dataset == 'validation']
    # Plot the total loss per epoch (training in blue, validation in green)
    plt.subplot(2, 2, 1)
    print(np.array([stat.loss for stat in valid_stats]))
    plt.plot(np.array(epochs), np.array([stat.loss for stat in train_stats]), color='blue')
    plt.plot(np.array(epochs), np.array([stat.loss for stat in valid_stats]), color='green')
    plt.title('Total loss per epoch')

    plt.show()
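
Example usage

A minimal sketch of invoking these commands programmatically with click's CliRunner, assuming the package is importable as torchsight.cli.stats.dldenet; the checkpoint and log file paths are hypothetical placeholders.

from click.testing import CliRunner

from torchsight.cli.stats.dldenet import dldenet

runner = CliRunner()

# Plot the norms and the similarity matrix of the classification weights
# stored in a trainer checkpoint (hypothetical path).
runner.invoke(dldenet, ['weights', '--checkpoint', 'checkpoint.pth.tar'])

# Plot the logged losses per epoch from a PrintLogger log file (hypothetical path).
runner.invoke(dldenet, ['losses', '--log-file', 'logs.txt'])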