From 8fffaf7946018f8364ab974cdf3620c79119b7a7 Mon Sep 17 00:00:00 2001
From: Kristof Van Engeland
Date: Fri, 27 Oct 2017 16:41:41 +0200
Subject: [PATCH] Added Plotly backend to training loss script.

Also parse --weights-step as an int: a value given on the command line was
kept as a string, so `v.iteration % args.weights_step` raised TypeError.
Restrict --backend to "mpl"/"ply" so an unknown value is rejected by
argparse instead of plot() silently drawing nothing.
---
 .gitignore                    |  2 ++
 scripts/show_training_loss.py | 42 ++++++++++++++++++++++++++++-------
 setup.py                      |  2 +-
 3 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/.gitignore b/.gitignore
index 94da407..a87e479 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@ dist
 build
 brambox.egg-info
 __pycache__
+.idea/
+*.html
diff --git a/scripts/show_training_loss.py b/scripts/show_training_loss.py
index ff7ac0f..9e56ed7 100644
--- a/scripts/show_training_loss.py
+++ b/scripts/show_training_loss.py
@@ -1,7 +1,6 @@
-#!python
+#!/usr/bin/env python
 import re
 import argparse
-import matplotlib.pyplot as plt
 
 
 class Batch:
@@ -18,10 +17,14 @@ class Batch:
 
 
 def main():
-
-    parser = argparse.ArgumentParser(description='Parse darknet stdout, plot the loss and indicate weigts file with lowest avg precision and total precision')
-    parser.add_argument('input', help='Input text file containing darknet stdout')
-    parser.add_argument('--weights-step', default=100, help=('Multiple of iterations a new weigts file is saved. This is used to point to the most interesting weights file'))
+    parser = argparse.ArgumentParser(
+        description='Parse DarkNet stdout, plot the loss and indicate weights file '
+                    'with lowest avg precision and total precision.')
+    parser.add_argument('input', help='Input text file containing darknet stdout.')
+    parser.add_argument('--weights-step', default=100, type=int, help=
+                        'Multiple of iterations a new weights file is saved. '
+                        'This is used to point to the most interesting weights file.')
+    parser.add_argument('--backend', default='mpl', choices=['mpl', 'ply'], help='Set the rendering engine of the plot to "mpl" or "ply".')
     args = parser.parse_args()
 
     with open(args.input) as f:
@@ -45,17 +48,28 @@ def main():
     subsampled_values = [v for v in values if v.iteration % args.weights_step == 0]
     sorted_subsampled_values = sorted(subsampled_values, key=lambda v: v.avg_loss)
     for i in range(10):
-        print("Candidate", i+1, "=", sorted_subsampled_values[i])
+        print("Candidate", i + 1, "=", sorted_subsampled_values[i])
 
     # plot loss
     total_losses = [v.total_loss for v in values]
     avg_losses = [v.avg_loss for v in values]
     iterations = [v.iteration for v in values]
+    plot(avg_losses, iterations, total_losses, backend=args.backend)
+
+
+def plot(avg_losses, iterations, total_losses, backend='mpl'):
+    if not backend or backend == 'mpl':
+        plot_mpl(avg_losses, iterations, total_losses)
+    elif backend == 'ply':
+        plot_ply(avg_losses, iterations, total_losses)
+
+
+def plot_mpl(avg_losses, iterations, total_losses):
+    import matplotlib.pyplot as plt
 
     plt.figure(figsize=(10, 8))
     plt.plot(iterations, total_losses, label="total loss", linewidth=1)
     plt.plot(iterations, avg_losses, label="avg loss", linewidth=1)
-
     plt.gcf().suptitle('Training loss', weight='bold')
     plt.gca().set_ylabel('Loss')
     plt.gca().set_xlabel('Iteration')
@@ -64,5 +78,17 @@ def main():
     plt.show()
 
 
+def plot_ply(avg_losses, iterations, total_losses):
+    import plotly.offline as po
+    import plotly.graph_objs as go
+
+    plots = list(map(lambda loss: go.Scatter(x=iterations, y=loss, mode='lines'), [total_losses, avg_losses]))
+    plots[0].name = "total loss"
+    plots[1].name = "average loss"
+    layout = go.Layout(title='Training loss', xaxis=dict(title='Iteration'), yaxis=dict(title='Loss', range=[0, 10]))
+    fig = go.Figure(data=plots, layout=layout)
+    po.plot(fig)
+
+
 if __name__ == '__main__':
     main()
diff --git a/setup.py b/setup.py
index 431b7cf..66c32f5 100644
--- a/setup.py
+++ b/setup.py
@@ -14,4 +14,4 @@ setup(name='brambox',
                'scripts/sparse_link.py',
                'scripts/replace_image_channel.py',
                'scripts/show_training_loss.py'],
-      test_suite='tests')
+      test_suite='tests', install_requires=['matplotlib', 'plotly'])
-- 
GitLab