Logging context and hyper parameters to JSON file

parent 147ca6bde1
commit 2892feb0ac

DeepSpeech.ipynb | 126 changed lines

--- a/DeepSpeech.ipynb
+++ b/DeepSpeech.ipynb
@@ -81,6 +81,9 @@
 "import tempfile\n",
 "import numpy as np\n",
 "import tensorflow as tf\n",
+"import json\n",
+"import subprocess\n",
+"import datetime\n",
 "from util.gpu import get_available_gpus\n",
 "from util.text import sparse_tensor_value_to_text, wers\n",
 "from tensorflow.python.ops import ctc_ops\n",
@@ -1025,6 +1028,30 @@
 " log_variable(variable, gradient=gradient)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Finally we define the log directory plus some helpers."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": true
+},
+"outputs": [],
+"source": [
+"log_dir = '%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\"))\n",
+"\n",
+"def get_git_revision_hash():\n",
+" return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()\n",
+"\n",
+"def get_git_branch():\n",
+" return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip()"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
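A note on the two git helpers added above: under Python 2 subprocess.check_output returns a str, so the bare .strip() is enough; under Python 3 it returns bytes, which the later json.dump would reject. A minimal Python 3 sketch of the same helpers, assuming git is on PATH and the notebook runs inside the repository checkout:

    import subprocess

    def get_git_revision_hash():
        # 'git rev-parse HEAD' prints the full 40-character commit hash
        return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()

    def get_git_branch():
        # '--abbrev-ref' resolves HEAD to the current branch name
        return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode('utf-8').strip()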
@@ -1177,11 +1204,15 @@
 "\n",
 " # Prepare tensor board logging\n",
 " merged = tf.merge_all_summaries()\n",
-" writer = tf.train.SummaryWriter('%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\")), session.graph)\n",
+" writer = tf.train.SummaryWriter(log_dir, session.graph)\n",
 "\n",
 " # Init all variables in session\n",
 " session.run(tf.initialize_all_variables())\n",
 " \n",
+" # Init recent word error rate levels\n",
+" last_train_wer = 0.0\n",
+" last_validation_wer = 0.0\n",
+" \n",
 " # Loop over the data set for training_epochs epochs\n",
 " for epoch in range(training_iters):\n",
 " # Define total accuracy for the epoch\n",
@@ -1189,7 +1220,7 @@
 " \n",
 " # Validation step to determine the best point in time to stop\n",
 " if epoch % validation_step == 0:\n",
-" _, wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
+" _, last_validation_wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
 " # TODO: Determine on base of WER, if model starts overfitting\n",
 " print\n",
 "\n",
@@ -1210,7 +1241,7 @@
 " # Print progress message\n",
 " if epoch % display_step == 0:\n",
 " print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n",
-" print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
+" _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
 " print\n",
 "\n",
 " # Checkpoint the model\n",
@@ -1221,7 +1252,8 @@
 " print\n",
 " \n",
 " # Indicate optimization has concluded\n",
-" print \"Optimization Finished!\""
+" print \"Optimization Finished!\"\n",
+" return last_train_wer, last_validation_wer"
 ]
 },
 {
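The overfitting TODO is left open by this commit, but with train() now returning the last train and validation WERs, a patience check over the collected validation WERs is one way such a stopping rule could look. A sketch only; should_stop and patience are hypothetical names, not part of this commit:

    def should_stop(validation_wers, patience=3):
        # Stop once the best validation WER lies at least `patience`
        # validation steps in the past, i.e. no recent improvement.
        best_index = validation_wers.index(min(validation_wers))
        return len(validation_wers) - 1 - best_index >= patience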
@@ -1245,8 +1277,18 @@
 "# Obtain ted lium data\n",
 "ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n",
 "\n",
+"# Take start time for time measurement\n",
+"time_started = datetime.datetime.utcnow()\n",
+"\n",
 "# Train the network\n",
-"train(session, ted_lium)"
+"last_train_wer, last_validation_wer = train(session, ted_lium)\n",
+"\n",
+"# Take final time for time measurement\n",
+"time_finished = datetime.datetime.utcnow()\n",
+"\n",
+"# Calculate duration in seconds\n",
+"duration = time_finished - time_started\n",
+"duration = duration.days * 86400 + duration.seconds"
 ]
 },
 {
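The two-line duration computation folds the timedelta into whole seconds by hand; for non-negative deltas this matches int(delta.total_seconds()), with microseconds truncated either way. A quick self-contained check:

    import datetime

    started = datetime.datetime.utcnow()
    finished = started + datetime.timedelta(days=1, seconds=90)
    delta = finished - started
    # 1 day + 90 s = 86490 s by both routes
    assert delta.days * 86400 + delta.seconds == int(delta.total_seconds())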
@@ -1265,8 +1307,80 @@
 "outputs": [],
 "source": [
 "# Test network\n",
-"print_wer_report(session, \"Test\", forward(session, data_sets.test))"
+"test_decodings, test_labels = forward(session, ted_lium.test)\n",
+"_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# Logging Hyper Parameters and Results"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Now that training and testing are done, we persist the results along with the involved hyper parameters for further reporting."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:\n",
+" json.dump({ \\\n",
+" 'context': { \\\n",
+" 'time_started': time_started.isoformat(), \\\n",
+" 'time_finished': time_finished.isoformat(), \\\n",
+" 'git_hash': get_git_revision_hash(), \\\n",
+" 'git_branch': get_git_branch() \\\n",
+" }, \\\n",
+" 'parameters': { \\\n",
+" 'learning_rate': learning_rate, \\\n",
+" 'beta1': beta1, \\\n",
+" 'beta2': beta2, \\\n",
+" 'epsilon': epsilon, \\\n",
+" 'training_iters': training_iters, \\\n",
+" 'batch_size': batch_size, \\\n",
+" 'validation_step': validation_step, \\\n",
+" 'dropout_rate': dropout_rate, \\\n",
+" 'relu_clip': relu_clip, \\\n",
+" 'n_input': n_input, \\\n",
+" 'n_context': n_context, \\\n",
+" 'n_hidden_1': n_hidden_1, \\\n",
+" 'n_hidden_2': n_hidden_2, \\\n",
+" 'n_hidden_3': n_hidden_3, \\\n",
+" 'n_hidden_5': n_hidden_5, \\\n",
+" 'n_hidden_6': n_hidden_6, \\\n",
+" 'n_cell_dim': n_cell_dim, \\\n",
+" 'n_character': n_character, \\\n",
+" 'num_examples_train': ted_lium.train.num_examples, \\\n",
+" 'num_examples_validation': ted_lium.validation.num_examples, \\\n",
+" 'num_examples_test': ted_lium.test.num_examples \\\n",
+" }, \\\n",
+" 'results': { \\\n",
+" 'duration': duration, \\\n",
+" 'last_train_wer': last_train_wer, \\\n",
+" 'last_validation_wer': last_validation_wer, \\\n",
+" 'test_wer': test_wer \\\n",
+" } \\\n",
+" }, dump_file, sort_keys=True, indent = 4)\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": true
+},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {
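The backslash continuations in the dump are redundant inside the braces, though harmless. For the further reporting the markdown cell mentions, the file reads straight back with json.load; a minimal sketch, where the timestamped run directory is hypothetical:

    import json

    # log_dir is stamped per run, e.g. logs/20161125-103000 (hypothetical)
    with open('logs/20161125-103000/hyper.json') as f:
        run = json.load(f)
    print(run['context']['git_hash'], run['results']['test_wer'])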