diff --git a/.gitignore b/.gitignore index 94da40767d391d4310690d82de9e9e04b181004c..a87e4795f68e861bb7e465a8694ab334815d1bde 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ dist build brambox.egg-info __pycache__ +.idea/ +*.html diff --git a/scripts/show_training_loss.py b/scripts/show_training_loss.py index ff7ac0f045b349f9bdd89d009c539859dff47442..9e56ed7cfbebc9ab93acdf6348a10f08030037ae 100644 --- a/scripts/show_training_loss.py +++ b/scripts/show_training_loss.py @@ -1,7 +1,6 @@ -#!python +#!/usr/bin/env python import re import argparse -import matplotlib.pyplot as plt class Batch: @@ -18,10 +17,14 @@ class Batch: def main(): - - parser = argparse.ArgumentParser(description='Parse darknet stdout, plot the loss and indicate weigts file with lowest avg precision and total precision') - parser.add_argument('input', help='Input text file containing darknet stdout') - parser.add_argument('--weights-step', default=100, help=('Multiple of iterations a new weigts file is saved. This is used to point to the most interesting weights file')) + parser = argparse.ArgumentParser( + description='Parse DarkNet stdout, plot the loss and indicate weights file ' + 'with lowest avg precision and total precision.') + parser.add_argument('input', help='Input text file containing darknet stdout.') + parser.add_argument('--weights-step', default=100, help= + 'Multiple of iterations a new weights file is saved. ' + 'This is used to point to the most interesting weights file.') + parser.add_argument('--backend', default='mpl', help='Set the rendering engine of the plot to "mpl" or "ply".') args = parser.parse_args() with open(args.input) as f: @@ -45,17 +48,28 @@ def main(): subsampled_values = [v for v in values if v.iteration % args.weights_step == 0] sorted_subsampled_values = sorted(subsampled_values, key=lambda v: v.avg_loss) for i in range(10): - print("Candidate", i+1, "=", sorted_subsampled_values[i]) + print("Candidate", i + 1, "=", sorted_subsampled_values[i]) # plot loss total_losses = [v.total_loss for v in values] avg_losses = [v.avg_loss for v in values] iterations = [v.iteration for v in values] + plot(avg_losses, iterations, total_losses, backend=args.backend) + + +def plot(avg_losses, iterations, total_losses, backend='mpl'): + if not backend or backend == 'mpl': + plot_mpl(avg_losses, iterations, total_losses) + elif backend == 'ply': + plot_ply(avg_losses, iterations, total_losses) + + +def plot_mpl(avg_losses, iterations, total_losses): + import matplotlib.pyplot as plt plt.figure(figsize=(10, 8)) plt.plot(iterations, total_losses, label="total loss", linewidth=1) plt.plot(iterations, avg_losses, label="avg loss", linewidth=1) - plt.gcf().suptitle('Training loss', weight='bold') plt.gca().set_ylabel('Loss') plt.gca().set_xlabel('Iteration') @@ -64,5 +78,17 @@ def main(): plt.show() +def plot_ply(avg_losses, iterations, total_losses): + import plotly.offline as po + import plotly.graph_objs as go + + plots = list(map(lambda loss: go.Scatter(x=iterations, y=loss, mode='lines'), [total_losses, avg_losses])) + plots[0].name = "total loss" + plots[1].name = "average loss" + layout = go.Layout(title='Training loss', xaxis=dict(title='Iteration'), yaxis=dict(title='Loss', range=[0, 10])) + fig = go.Figure(data=plots, layout=layout) + po.plot(fig) + + if __name__ == '__main__': main() diff --git a/setup.py b/setup.py index 431b7cf440a91f3fa322e64065656b9312f836b0..66c32f5d40accaf91c2fd9d2c779f69919efd686 100644 --- a/setup.py +++ b/setup.py @@ -14,4 +14,4 @@ setup(name='brambox', 'scripts/sparse_link.py', 'scripts/replace_image_channel.py', 'scripts/show_training_loss.py'], - test_suite='tests') + test_suite='tests', install_requires=['matplotlib', 'plotly'])