Logging context and hyperparameters to JSON file
commit 2892feb0ac
parent 147ca6bde1

DeepSpeech.ipynb (126 lines changed)
@@ -81,6 +81,9 @@
     "import tempfile\n",
     "import numpy as np\n",
     "import tensorflow as tf\n",
+    "import json\n",
+    "import subprocess\n",
+    "import datetime\n",
     "from util.gpu import get_available_gpus\n",
     "from util.text import sparse_tensor_value_to_text, wers\n",
     "from tensorflow.python.ops import ctc_ops\n",
@@ -1025,6 +1028,30 @@
     "        log_variable(variable, gradient=gradient)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, we define the log directory plus some helpers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "log_dir = '%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\"))\n",
+    "\n",
+    "def get_git_revision_hash():\n",
+    "    return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()\n",
+    "\n",
+    "def get_git_branch():\n",
+    "    return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
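The two helpers above shell out to git via the standard library. For reference, a standalone sketch of the same idea, hardened slightly; the helper name is hypothetical, check_output raises CalledProcessError on a non-zero exit status (e.g. outside a git checkout), and on Python 3 it returns bytes, hence the decode (the notebook itself targets Python 2, where strip() already yields str):

    import subprocess

    def git_output(args):
        # Hypothetical helper: run a git command and return its stdout as
        # text, or None when git fails or is not installed.
        try:
            out = subprocess.check_output(['git'] + args)
        except (subprocess.CalledProcessError, OSError):
            return None
        return out.strip().decode('utf-8')

    git_hash = git_output(['rev-parse', 'HEAD'])                    # full 40-char SHA
    git_branch = git_output(['rev-parse', '--abbrev-ref', 'HEAD'])  # e.g. 'master'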
@@ -1177,11 +1204,15 @@
     "\n",
     "    # Prepare tensor board logging\n",
     "    merged = tf.merge_all_summaries()\n",
-    "    writer = tf.train.SummaryWriter('%s/%s' % (\"logs\", time.strftime(\"%Y%m%d-%H%M%S\")), session.graph)\n",
+    "    writer = tf.train.SummaryWriter(log_dir, session.graph)\n",
     "\n",
     "    # Init all variables in session\n",
     "    session.run(tf.initialize_all_variables())\n",
     "    \n",
+    "    # Init recent word error rate levels\n",
+    "    last_train_wer = 0.0\n",
+    "    last_validation_wer = 0.0\n",
+    "    \n",
     "    # Loop over the data set for training_epochs epochs\n",
     "    for epoch in range(training_iters):\n",
     "        # Define total accuracy for the epoch\n",
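Note that the JSON dump at the end of the notebook opens a file inside log_dir and relies on tf.train.SummaryWriter (the TF API in use above) having created that directory as a side effect. A small defensive sketch that avoids that assumption, using only the standard library:

    import os

    # Create the run's log directory explicitly so that writing hyper.json
    # does not depend on the summary writer having made it first.
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)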
@@ -1189,7 +1220,7 @@
     "        \n",
     "        # Validation step to determine the best point in time to stop\n",
     "        if epoch % validation_step == 0:\n",
-    "            _, wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
+    "            _, last_validation_wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
     "            # TODO: Determine on base of WER, if model starts overfitting\n",
     "            print\n",
     "\n",
@@ -1210,7 +1241,7 @@
     "        # Print progress message\n",
     "        if epoch % display_step == 0:\n",
     "            print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n",
-    "            print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
+    "            _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
     "            print\n",
     "\n",
     "        # Checkpoint the model\n",
@@ -1221,7 +1252,8 @@
     "            print\n",
     "        \n",
     "    # Indicate optimization has concluded\n",
-    "    print \"Optimization Finished!\""
+    "    print \"Optimization Finished!\"\n",
+    "    return last_train_wer, last_validation_wer"
    ]
   },
   {
@@ -1245,8 +1277,18 @@
     "# Obtain ted lium data\n",
     "ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n",
     "\n",
+    "# Take start time for time measurement\n",
+    "time_started = datetime.datetime.utcnow()\n",
+    "\n",
     "# Train the network\n",
-    "train(session, ted_lium)"
+    "last_train_wer, last_validation_wer = train(session, ted_lium)\n",
+    "\n",
+    "# Take final time for time measurement\n",
+    "time_finished = datetime.datetime.utcnow()\n",
+    "\n",
+    "# Calculate duration in seconds\n",
+    "duration = time_finished - time_started\n",
+    "duration = duration.days * 86400 + duration.seconds"
    ]
   },
   {
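The two-step duration arithmetic above truncates sub-second precision. A minimal equivalent sketch using timedelta.total_seconds() (available since Python 2.7), which yields the same integer for non-negative deltas and may read more clearly:

    import datetime

    time_started = datetime.datetime.utcnow()
    # ... long-running work, e.g. training ...
    time_finished = datetime.datetime.utcnow()

    delta = time_finished - time_started
    # Same as delta.days * 86400 + delta.seconds for non-negative deltas,
    # with microseconds truncated by the int() conversion.
    duration = int(delta.total_seconds())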
@@ -1265,8 +1307,80 @@
    "outputs": [],
    "source": [
     "# Test network\n",
-    "print_wer_report(session, \"Test\", forward(session, data_sets.test))"
+    "test_decodings, test_labels = forward(session, ted_lium.test)\n",
+    "_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Logging Hyperparameters and Results"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now that training and testing are done, we persist the results along with the involved hyperparameters for further reporting."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:\n",
+    "    json.dump({\n",
+    "        'context': {\n",
+    "            'time_started': time_started.isoformat(),\n",
+    "            'time_finished': time_finished.isoformat(),\n",
+    "            'git_hash': get_git_revision_hash(),\n",
+    "            'git_branch': get_git_branch()\n",
+    "        },\n",
+    "        'parameters': {\n",
+    "            'learning_rate': learning_rate,\n",
+    "            'beta1': beta1,\n",
+    "            'beta2': beta2,\n",
+    "            'epsilon': epsilon,\n",
+    "            'training_iters': training_iters,\n",
+    "            'batch_size': batch_size,\n",
+    "            'validation_step': validation_step,\n",
+    "            'dropout_rate': dropout_rate,\n",
+    "            'relu_clip': relu_clip,\n",
+    "            'n_input': n_input,\n",
+    "            'n_context': n_context,\n",
+    "            'n_hidden_1': n_hidden_1,\n",
+    "            'n_hidden_2': n_hidden_2,\n",
+    "            'n_hidden_3': n_hidden_3,\n",
+    "            'n_hidden_5': n_hidden_5,\n",
+    "            'n_hidden_6': n_hidden_6,\n",
+    "            'n_cell_dim': n_cell_dim,\n",
+    "            'n_character': n_character,\n",
+    "            'num_examples_train': ted_lium.train.num_examples,\n",
+    "            'num_examples_validation': ted_lium.validation.num_examples,\n",
+    "            'num_examples_test': ted_lium.test.num_examples\n",
+    "        },\n",
+    "        'results': {\n",
+    "            'duration': duration,\n",
+    "            'last_train_wer': last_train_wer,\n",
+    "            'last_validation_wer': last_validation_wer,\n",
+    "            'test_wer': test_wer\n",
+    "        }\n",
+    "    }, dump_file, sort_keys=True, indent=4)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
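Since the dump is intended for further reporting, a minimal consumer sketch; the path is illustrative (real runs live under logs/<YYYYmmdd-HHMMSS>/), and the keys follow the dump above:

    import json

    # Illustrative run directory; substitute an actual timestamped path.
    with open('logs/20160101-120000/hyper.json') as dump_file:
        run = json.load(dump_file)

    # One summary line per run: branch@short-hash, test WER, wall-clock time.
    print('%s@%s: test WER %.4f after %d s' % (
        run['context']['git_branch'],
        run['context']['git_hash'][:7],
        run['results']['test_wer'],
        run['results']['duration']))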