Logging context and hyper parameters to JSON file
This commit is contained in:
parent
147ca6bde1
commit
2892feb0ac
126
DeepSpeech.ipynb
126
DeepSpeech.ipynb
@ -81,6 +81,9 @@
|
||||
"import tempfile\n",
|
||||
"import numpy as np\n",
|
||||
"import tensorflow as tf\n",
|
||||
"import json\n",
|
||||
"import subprocess\n",
|
||||
"import datetime\n",
|
||||
"from util.gpu import get_available_gpus\n",
|
||||
"from util.text import sparse_tensor_value_to_text, wers\n",
|
||||
"from tensorflow.python.ops import ctc_ops\n",
|
||||
@ -1025,6 +1028,30 @@
|
||||
" log_variable(variable, gradient=gradient)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally we define the log directory plus some helpers."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"log_dir = '%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\"))\n",
|
||||
"\n",
|
||||
"def get_git_revision_hash():\n",
|
||||
" return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()\n",
|
||||
"\n",
|
||||
"def get_git_branch():\n",
|
||||
" return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@ -1177,11 +1204,15 @@
|
||||
"\n",
|
||||
" # Prepare tensor board logging\n",
|
||||
" merged = tf.merge_all_summaries()\n",
|
||||
" writer = tf.train.SummaryWriter('%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\")), session.graph)\n",
|
||||
" writer = tf.train.SummaryWriter(log_dir, session.graph)\n",
|
||||
"\n",
|
||||
" # Init all variables in session\n",
|
||||
" session.run(tf.initialize_all_variables())\n",
|
||||
" \n",
|
||||
" # Init recent word error rate levels\n",
|
||||
" last_train_wer = 0.0\n",
|
||||
" last_validation_wer = 0.0\n",
|
||||
" \n",
|
||||
" # Loop over the data set for training_epochs epochs\n",
|
||||
" for epoch in range(training_iters):\n",
|
||||
" # Define total accuracy for the epoch\n",
|
||||
@ -1189,7 +1220,7 @@
|
||||
" \n",
|
||||
" # Validation step to determine the best point in time to stop\n",
|
||||
" if epoch % validation_step == 0:\n",
|
||||
" _, wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
|
||||
" _, last_validation_wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
|
||||
" # TODO: Determine on base of WER, if model starts overfitting\n",
|
||||
" print\n",
|
||||
"\n",
|
||||
@ -1210,7 +1241,7 @@
|
||||
" # Print progress message\n",
|
||||
" if epoch % display_step == 0:\n",
|
||||
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n",
|
||||
" print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
|
||||
" _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
|
||||
" print\n",
|
||||
"\n",
|
||||
" # Checkpoint the model\n",
|
||||
@ -1221,7 +1252,8 @@
|
||||
" print\n",
|
||||
" \n",
|
||||
" # Indicate optimization has concluded\n",
|
||||
" print \"Optimization Finished!\""
|
||||
" print \"Optimization Finished!\"\n",
|
||||
" return last_train_wer, last_validation_wer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1245,8 +1277,18 @@
|
||||
"# Obtain ted lium data\n",
|
||||
"ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n",
|
||||
"\n",
|
||||
"# Take start time for time measurement\n",
|
||||
"time_started = datetime.datetime.utcnow()\n",
|
||||
"\n",
|
||||
"# Train the network\n",
|
||||
"train(session, ted_lium)"
|
||||
"last_train_wer, last_validation_wer = train(session, ted_lium)\n",
|
||||
"\n",
|
||||
"# Take final time for time measurement\n",
|
||||
"time_finished = datetime.datetime.utcnow()\n",
|
||||
"\n",
|
||||
"# Calculate duration in seconds\n",
|
||||
"duration = time_finished - time_started\n",
|
||||
"duration = duration.days * 86400 + duration.seconds"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1265,8 +1307,80 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test network\n",
|
||||
"print_wer_report(session, \"Test\", forward(session, data_sets.test))"
|
||||
"test_decodings, test_labels = forward(session, ted_lium.test)\n",
|
||||
"_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Logging Hyper Parameters and Results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now as training and test are done, we persist the results alongside with the involved hyper parameters for further reporting."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:\n",
|
||||
" json.dump({ \\\n",
|
||||
" 'context': { \\\n",
|
||||
" 'time_started': time_started.isoformat(), \\\n",
|
||||
" 'time_finished': time_finished.isoformat(), \\\n",
|
||||
" 'git_hash': get_git_revision_hash(), \\\n",
|
||||
" 'git_branch': get_git_branch() \\\n",
|
||||
" }, \\\n",
|
||||
" 'parameters': { \\\n",
|
||||
" 'learning_rate': learning_rate, \\\n",
|
||||
" 'beta1': beta1, \\\n",
|
||||
" 'beta2': beta2, \\\n",
|
||||
" 'epsilon': epsilon, \\\n",
|
||||
" 'training_iters': training_iters, \\\n",
|
||||
" 'batch_size': batch_size, \\\n",
|
||||
" 'validation_step': validation_step, \\\n",
|
||||
" 'dropout_rate': dropout_rate, \\\n",
|
||||
" 'relu_clip': relu_clip, \\\n",
|
||||
" 'n_input': n_input, \\\n",
|
||||
" 'n_context': n_context, \\\n",
|
||||
" 'n_hidden_1': n_hidden_1, \\\n",
|
||||
" 'n_hidden_2': n_hidden_2, \\\n",
|
||||
" 'n_hidden_3': n_hidden_3, \\\n",
|
||||
" 'n_hidden_5': n_hidden_5, \\\n",
|
||||
" 'n_hidden_6': n_hidden_6, \\\n",
|
||||
" 'n_cell_dim': n_cell_dim, \\\n",
|
||||
" 'n_character': n_character, \\\n",
|
||||
" 'num_examples_train': ted_lium.train.num_examples, \\\n",
|
||||
" 'num_examples_validation': ted_lium.validation.num_examples, \\\n",
|
||||
" 'num_examples_test': ted_lium.test.num_examples \\\n",
|
||||
" }, \\\n",
|
||||
" 'results': { \\\n",
|
||||
" 'duration': duration, \\\n",
|
||||
" 'last_train_wer': last_train_wer, \\\n",
|
||||
" 'last_validation_wer': last_validation_wer, \\\n",
|
||||
" 'test_wer': test_wer \\\n",
|
||||
" } \\\n",
|
||||
" }, dump_file, sort_keys=True, indent = 4)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
Loading…
x
Reference in New Issue
Block a user