Logging context and hyperparameters to a JSON file

Tilman Kamp 2016-10-10 15:17:50 +02:00
parent 147ca6bde1
commit 2892feb0ac


@@ -81,6 +81,9 @@
"import tempfile\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import json\n",
"import subprocess\n",
"import datetime\n",
"from util.gpu import get_available_gpus\n",
"from util.text import sparse_tensor_value_to_text, wers\n",
"from tensorflow.python.ops import ctc_ops\n",
@@ -1025,6 +1028,30 @@
" log_variable(variable, gradient=gradient)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we define the log directory plus some helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"log_dir = '%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\"))\n",
"\n",
"def get_git_revision_hash():\n",
" return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()\n",
"\n",
"def get_git_branch():\n",
" return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip()"
]
},
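Side note: subprocess.check_output raises CalledProcessError if the git command fails (e.g. when the notebook runs outside a git checkout) and OSError if git is not installed. A more forgiving variant could look like the following sketch; the function name and the 'unknown' fallback are illustrative, not part of this commit:

    import subprocess

    def get_git_revision_hash_safe():
        # Fall back to a placeholder when the hash cannot be determined;
        # 'unknown' is an illustrative choice, not from the commit.
        try:
            return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()
        except (subprocess.CalledProcessError, OSError):
            return 'unknown'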
{
"cell_type": "markdown",
"metadata": {},
@@ -1177,11 +1204,15 @@
"\n",
" # Prepare tensor board logging\n",
" merged = tf.merge_all_summaries()\n",
" writer = tf.train.SummaryWriter('%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\")), session.graph)\n",
" writer = tf.train.SummaryWriter(log_dir, session.graph)\n",
"\n",
" # Init all variables in session\n",
" session.run(tf.initialize_all_variables())\n",
" \n",
" # Init recent word error rate levels\n",
" last_train_wer = 0.0\n",
" last_validation_wer = 0.0\n",
" \n",
" # Loop over the data set for training_epochs epochs\n",
" for epoch in range(training_iters):\n",
" # Define total accuracy for the epoch\n",
@@ -1189,7 +1220,7 @@
" \n",
" # Validation step to determine the best point in time to stop\n",
" if epoch % validation_step == 0:\n",
" _, wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
" _, last_validation_wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
" # TODO: Determine on base of WER, if model starts overfitting\n",
" print\n",
"\n",
@@ -1210,7 +1241,7 @@
" # Print progress message\n",
" if epoch % display_step == 0:\n",
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n",
" print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
" _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
" print\n",
"\n",
" # Checkpoint the model\n",
@@ -1221,7 +1252,8 @@
" print\n",
" \n",
" # Indicate optimization has concluded\n",
" print \"Optimization Finished!\""
" print \"Optimization Finished!\"\n",
" return last_train_wer, last_validation_wer"
]
},
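The TODO above already hints at using the tracked validation WER for early stopping. A minimal sketch of one possible patience rule, hypothetical and not part of this commit (the names patience, bad_epochs and best_wer are illustrative):

    # Hypothetical patience-based stopping criterion on validation WER:
    # stop after 'patience' consecutive epochs without improvement.
    patience, bad_epochs, best_wer = 5, 0, float('inf')

    def should_stop(validation_wer):
        global bad_epochs, best_wer
        if validation_wer < best_wer:
            best_wer, bad_epochs = validation_wer, 0
        else:
            bad_epochs += 1
        return bad_epochs >= patience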
{
@@ -1245,8 +1277,18 @@
"# Obtain ted lium data\n",
"ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n",
"\n",
"# Take start time for time measurement\n",
"time_started = datetime.datetime.utcnow()\n",
"\n",
"# Train the network\n",
"train(session, ted_lium)"
"last_train_wer, last_validation_wer = train(session, ted_lium)\n",
"\n",
"# Take final time for time measurement\n",
"time_finished = datetime.datetime.utcnow()\n",
"\n",
"# Calculate duration in seconds\n",
"duration = time_finished - time_started\n",
"duration = duration.days * 86400 + duration.seconds"
]
},
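For positive durations, the manual days * 86400 + seconds arithmetic above is equivalent to truncating timedelta.total_seconds(), available since Python 2.7:

    import datetime

    # Both expressions yield 86520 for this example duration;
    # int() truncates the sub-second remainder, matching days/seconds math.
    d = datetime.timedelta(days=1, seconds=120, microseconds=500000)
    assert d.days * 86400 + d.seconds == int(d.total_seconds())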
{
@@ -1265,8 +1307,80 @@
"outputs": [],
"source": [
"# Test network\n",
"print_wer_report(session, \"Test\", forward(session, data_sets.test))"
"test_decodings, test_labels = forward(session, ted_lium.test)\n",
"_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Logging Hyper Parameters and Results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now as training and test are done, we persist the results alongside with the involved hyper parameters for further reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:\n",
" json.dump({ \\\n",
" 'context': { \\\n",
" 'time_started': time_started.isoformat(), \\\n",
" 'time_finished': time_finished.isoformat(), \\\n",
" 'git_hash': get_git_revision_hash(), \\\n",
" 'git_branch': get_git_branch() \\\n",
" }, \\\n",
" 'parameters': { \\\n",
" 'learning_rate': learning_rate, \\\n",
" 'beta1': beta1, \\\n",
" 'beta2': beta2, \\\n",
" 'epsilon': epsilon, \\\n",
" 'training_iters': training_iters, \\\n",
" 'batch_size': batch_size, \\\n",
" 'validation_step': validation_step, \\\n",
" 'dropout_rate': dropout_rate, \\\n",
" 'relu_clip': relu_clip, \\\n",
" 'n_input': n_input, \\\n",
" 'n_context': n_context, \\\n",
" 'n_hidden_1': n_hidden_1, \\\n",
" 'n_hidden_2': n_hidden_2, \\\n",
" 'n_hidden_3': n_hidden_3, \\\n",
" 'n_hidden_5': n_hidden_5, \\\n",
" 'n_hidden_6': n_hidden_6, \\\n",
" 'n_cell_dim': n_cell_dim, \\\n",
" 'n_character': n_character, \\\n",
" 'num_examples_train': ted_lium.train.num_examples, \\\n",
" 'num_examples_validation': ted_lium.validation.num_examples, \\\n",
" 'num_examples_test': ted_lium.test.num_examples \\\n",
" }, \\\n",
" 'results': { \\\n",
" 'duration': duration, \\\n",
" 'last_train_wer': last_train_wer, \\\n",
" 'last_validation_wer': last_validation_wer, \\\n",
" 'test_wer': test_wer \\\n",
" } \\\n",
" }, dump_file, sort_keys=True, indent = 4)\n"
]
},
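For the reporting mentioned above, the per-run hyper.json files can later be collected and compared. A minimal sketch, assuming the logs/<timestamp>/hyper.json layout created in this commit (Python 2, matching the notebook):

    import glob
    import json

    # Gather all persisted runs and rank them by test WER (best first).
    runs = []
    for path in glob.glob('logs/*/hyper.json'):
        with open(path) as dump_file:
            runs.append((path, json.load(dump_file)))
    runs.sort(key=lambda item: item[1]['results']['test_wer'])

    for path, run in runs:
        print path, run['results']['test_wer']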
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {