From 2892feb0acc266daa0c8e2ff1a21ef8592ca9149 Mon Sep 17 00:00:00 2001 From: Tilman Kamp Date: Mon, 10 Oct 2016 15:17:50 +0200 Subject: [PATCH] Logging context and hyper parameters to JSON file --- DeepSpeech.ipynb | 126 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 6 deletions(-) diff --git a/DeepSpeech.ipynb b/DeepSpeech.ipynb index 406ec640..ee16703e 100644 --- a/DeepSpeech.ipynb +++ b/DeepSpeech.ipynb @@ -81,6 +81,9 @@ "import tempfile\n", "import numpy as np\n", "import tensorflow as tf\n", + "import json\n", + "import subprocess\n", + "import datetime\n", "from util.gpu import get_available_gpus\n", "from util.text import sparse_tensor_value_to_text, wers\n", "from tensorflow.python.ops import ctc_ops\n", @@ -1025,6 +1028,30 @@ " log_variable(variable, gradient=gradient)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we define the log directory plus some helpers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "log_dir = '%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\"))\n", + "\n", + "def get_git_revision_hash():\n", + " return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()\n", + "\n", + "def get_git_branch():\n", + " return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1177,11 +1204,15 @@ "\n", " # Prepare tensor board logging\n", " merged = tf.merge_all_summaries()\n", - " writer = tf.train.SummaryWriter('%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\")), session.graph)\n", + " writer = tf.train.SummaryWriter(log_dir, session.graph)\n", "\n", " # Init all variables in session\n", " session.run(tf.initialize_all_variables())\n", " \n", + " # Init recent word error rate levels\n", + " last_train_wer = 0.0\n", + " last_validation_wer = 0.0\n", + " \n", " # Loop over the data set for 
training_epochs epochs\n", " for epoch in range(training_iters):\n", " # Define total accuracy for the epoch\n", @@ -1189,7 +1220,7 @@ " \n", " # Validation step to determine the best point in time to stop\n", " if epoch % validation_step == 0:\n", - " _, wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n", + " _, last_validation_wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n", " # TODO: Determine on base of WER, if model starts overfitting\n", " print\n", "\n", @@ -1210,7 +1241,7 @@ " # Print progress message\n", " if epoch % display_step == 0:\n", " print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n", - " print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n", + " _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n", " print\n", "\n", " # Checkpoint the model\n", @@ -1221,7 +1252,8 @@ " print\n", " \n", " # Indicate optimization has concluded\n", - " print \"Optimization Finished!\"" + " print \"Optimization Finished!\"\n", + " return last_train_wer, last_validation_wer" ] }, { @@ -1245,8 +1277,18 @@ "# Obtain ted lium data\n", "ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n", "\n", + "# Take start time for time measurement\n", + "time_started = datetime.datetime.utcnow()\n", + "\n", "# Train the network\n", - "train(session, ted_lium)" + "last_train_wer, last_validation_wer = train(session, ted_lium)\n", + "\n", + "# Take final time for time measurement\n", + "time_finished = datetime.datetime.utcnow()\n", + "\n", + "# Calculate duration in seconds\n", + "duration = time_finished - time_started\n", + "duration = duration.days * 86400 + duration.seconds" ] }, { @@ -1265,8 +1307,80 @@ "outputs": [], "source": [ "# Test network\n", - "print_wer_report(session, \"Test\", forward(session, data_sets.test))" + "test_decodings, 
test_labels = forward(session, ted_lium.test)\n", + "_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Logging Hyper Parameters and Results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that training and testing are done, we persist the results together with the hyper parameters involved for further reporting." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:\n", + " json.dump({ \\\n", + " 'context': { \\\n", + " 'time_started': time_started.isoformat(), \\\n", + " 'time_finished': time_finished.isoformat(), \\\n", + " 'git_hash': get_git_revision_hash(), \\\n", + " 'git_branch': get_git_branch() \\\n", + " }, \\\n", + " 'parameters': { \\\n", + " 'learning_rate': learning_rate, \\\n", + " 'beta1': beta1, \\\n", + " 'beta2': beta2, \\\n", + " 'epsilon': epsilon, \\\n", + " 'training_iters': training_iters, \\\n", + " 'batch_size': batch_size, \\\n", + " 'validation_step': validation_step, \\\n", + " 'dropout_rate': dropout_rate, \\\n", + " 'relu_clip': relu_clip, \\\n", + " 'n_input': n_input, \\\n", + " 'n_context': n_context, \\\n", + " 'n_hidden_1': n_hidden_1, \\\n", + " 'n_hidden_2': n_hidden_2, \\\n", + " 'n_hidden_3': n_hidden_3, \\\n", + " 'n_hidden_5': n_hidden_5, \\\n", + " 'n_hidden_6': n_hidden_6, \\\n", + " 'n_cell_dim': n_cell_dim, \\\n", + " 'n_character': n_character, \\\n", + " 'num_examples_train': ted_lium.train.num_examples, \\\n", + " 'num_examples_validation': ted_lium.validation.num_examples, \\\n", + " 'num_examples_test': ted_lium.test.num_examples \\\n", + " }, \\\n", + " 'results': { \\\n", + " 'duration': duration, \\\n", + " 'last_train_wer': last_train_wer, \\\n", + " 'last_validation_wer': last_validation_wer, \\\n", + " 'test_wer': test_wer \\\n", + 
" } \\\n", + " }, dump_file, sort_keys=True, indent = 4)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] } ], "metadata": {