torchsight.cli.stats.print
module
Extract stats from the PrintLogger.
Source code
"""Extract stats from the PrintLogger."""
import click
from torchsight.loggers import PrintLogger
@click.command()
@click.argument('log_file')
@click.option('-k', '--keys', default='Loss Class. Pos Neg Regr. Simil. w-norm LR', show_default=True,
              help='Whitespace separated keys of the logs to average per epoch.')
@click.option('-nv', '--no-valid', default='LR w-norm', show_default=True,
              help='Whitespace separated keys that are shown for training only (no validation value).')
@click.option('-ek', '--epoch-key', default='Epoch', help='The key in the logs that indicates the epoch.')
@click.option('--just', default=11, show_default=True, help='Width used to justify each printed value.')
def printlogger(log_file, keys, no_valid, epoch_key, just):
    """Print the mean value per epoch of the given keys for training and validation.

    These stats are computed from the lines that the PrintLogger printed during a
    training and that were saved in a file.

    For example:

    $ python cli.py train dldenet ~/dataset/coco --logs-dir logs/ > logs/logs.txt

    With the '>' all the prints of the training are written to the logs/logs.txt file.

    Now you can use:

    $ python cli.py stats printlogger logs/logs.txt

    The output is a table with one row per epoch found in the '[Training]' lines.
    The first column is the epoch and there is one column per key. Each column shows
    the training mean followed by the validation mean, except for the keys given in
    --no-valid that show only the training mean.

    A `---` means that there are no values yet for that key in that epoch.

    Arguments:
        log_file (str): Path to the file that contains the output of the PrintLogger.
        keys (str): Whitespace separated names of the keys to average.
        no_valid (str): Whitespace separated names of the keys that must not show a
            validation value.
        epoch_key (str): The key in each log that holds the epoch.
        just (int): The width to which each value is right justified.
    """
    logger = PrintLogger()
    keys = keys.split()
    # Split into exact names. Testing membership against the raw option string would
    # be a substring match, so a key like 'norm' would wrongly match 'w-norm'.
    no_valid = set(no_valid.split())

    def accumulate(logs):
        """Group the logs by epoch keeping the sum and the count of the values of each key."""
        epochs = {}
        for log in logs:
            stats = epochs.setdefault(log[epoch_key], {})
            for k in keys:
                entry = stats.setdefault(k, {'sum': 0, 'count': 0})
                value = log.get(k)
                # A log line does not necessarily contain every key.
                if value is not None:
                    entry['sum'] += float(value)
                    entry['count'] += 1
        return epochs

    def render(entry):
        """Format the mean of an accumulated entry, or `---` when there is nothing to average."""
        if entry is None or entry['count'] == 0:
            return '---'.rjust(just)
        mean = entry['sum'] / entry['count']
        # Big values get fewer decimals so they still fit in the column.
        text = '{:.3f}'.format(mean) if mean > 10 else '{:.7f}'.format(mean)
        return text.rjust(just)

    train = accumulate(logger.read(log_file, keep=lambda line: line.startswith('[Training]')))
    valid = accumulate(logger.read(log_file, keep=lambda line: line.startswith('[Validating]')))

    headers = ['Epoch']
    for key in keys:
        # Keys with a validation value span two values and the space between them.
        width = just if key in no_valid else just * 2 + 1
        headers.append(key.center(width))
    print(' | '.join(headers))

    for epoch, train_stats in train.items():
        values = [str(epoch).center(5)]
        # The epoch could have no validation logs yet.
        valid_stats = valid.get(epoch) or {}
        for k in keys:
            train_value = render(train_stats[k])
            if k in no_valid:
                values.append(train_value)
            else:
                values.append('{} {}'.format(train_value, render(valid_stats.get(k))))
        print(' | '.join(values))