Fix #12; integration of WER in training, validation and test

Tilman Kamp 2016-10-07 17:23:20 +02:00
parent 122c27c6a2
commit 147ca6bde1
2 changed files with 168 additions and 153 deletions

View File

@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"collapsed": false
},
@ -82,7 +82,7 @@
"import numpy as np\n",
"import tensorflow as tf\n",
"from util.gpu import get_available_gpus\n",
"from util.text import sparse_tensor_value_to_text\n",
"from util.text import sparse_tensor_value_to_text, wers\n",
"from tensorflow.python.ops import ctc_ops\n",
"from util.importers.ted_lium import read_data_sets"
]
@ -109,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -119,10 +119,11 @@
"beta1 = 0.9 # TODO: Determine a reasonable value for this\n",
"beta2 = 0.999 # TODO: Determine a reasonable value for this\n",
"epsilon = 1e-8 # TODO: Determine a reasonable value for this\n",
"training_iters = 50 # TODO: Determine a reasonable value for this\n",
"training_iters = 1250 # TODO: Determine a reasonable value for this\n",
"batch_size = 1 # TODO: Determine a reasonable value for this\n",
"display_step = 1 # TODO: Determine a reasonable value for this\n",
"checkpoint_step = 50 # TODO: Determine a reasonable value for this\n",
"display_step = 10 # TODO: Determine a reasonable value for this\n",
"validation_step = 50\n",
"checkpoint_step = 1000 # TODO: Determine a reasonable value for this\n",
"checkpoint_dir = tempfile.gettempdir() # TODO: Determine a reasonable value for this"
]
},
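This commit introduces validation_step alongside the retuned display_step and checkpoint_step. A minimal sketch (illustrative only, mirroring how these parameters are used in the train() loop later in this diff) of how such step parameters gate periodic work:

# Illustrative sketch only - mirrors the epoch loop further down in this diff.
for epoch in range(training_iters):
    if epoch % validation_step == 0:
        pass  # run a validation pass and report its WER
    if epoch % display_step == 0:
        pass  # print a training progress message
    if (epoch % checkpoint_step == 0) or (epoch == training_iters - 1):
        pass  # save a checkpoint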
@ -137,13 +138,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"dropout_rate = 0.00 # TODO: Validate this is a reasonable value"
"dropout_rate = 0.01 # TODO: Validate this is a reasonable value"
]
},
{
@ -155,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -189,7 +190,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -207,7 +208,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -225,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -282,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -300,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -318,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -336,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -363,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {
"collapsed": false
},
@ -390,7 +391,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -592,7 +593,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -778,7 +779,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {
"collapsed": false
},
@ -819,7 +820,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -842,7 +843,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -905,7 +906,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -952,7 +953,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"metadata": {
"collapsed": false
},
@ -981,7 +982,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -1013,7 +1014,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {
"collapsed": true
},
@ -1024,6 +1025,102 @@
" log_variable(variable, gradient=gradient)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test and Validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we need a helper method to create a normal forward calculation without optimization, dropouts and special reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def forward(session, data_set):\n",
" # Set n_steps parameter\n",
" n_steps = data_set.max_batch_seq_len\n",
"\n",
" # Calculate the total number of batches\n",
" total_batch = int(data_set.num_examples/batch_size)\n",
"\n",
" # Require that we have at least as many batches as devices\n",
" assert total_batch >= len(available_devices)\n",
" \n",
" # Get gradients for each tower (Runs across all GPU's)\n",
" tower_decodings, tower_labels, _, _, _ = get_tower_results(n_steps, data_set)\n",
" return tower_decodings, tower_labels\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To report progress and to get an idea of the current state of the model, we create a method that calculates the word error rate (WER) out of (tower) decodings and their respective (original) labels. This is done for each and every entry and as a mean value of all WERs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def calculate_wer(session, tower_decodings, tower_labels):\n",
" originals = []\n",
" results = []\n",
" \n",
" # Normalization\n",
" tower_decodings = [j for i in tower_decodings for j in i]\n",
" \n",
" # Iterating over the towers\n",
" for i in range(len(tower_decodings)):\n",
" decoded, labels = session.run([tower_decodings[i], tower_labels[i]])\n",
" originals.extend(sparse_tensor_value_to_text(labels))\n",
" results.extend(sparse_tensor_value_to_text(decoded))\n",
" \n",
" # Pairwise calculation of all rates\n",
" rates, mean = wers(originals, results)\n",
" return zip(originals, results, rates), mean"
]
},
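To make the numbers concrete, here is a hypothetical example (made-up strings, not taken from the data set) of what the pairwise computation yields:

# Hypothetical example of the pairwise WER computation:
#   original: "she had your dark suit"  (5 words)
#   result:   "she has your dark suit"  (one substitution)
# Word-level edit distance = 1, so WER = 1 / 5 = 0.2
rates, mean = wers(["she had your dark suit"],
                   ["she has your dark suit"])
# rates == [0.2] and mean == 0.2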
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plus a convenience method to calculate and report the WER bundle all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def print_wer_report(session, caption, tower_decodings, tower_labels, show_example=True):\n",
" items, mean = calculate_wer(session, tower_decodings, tower_labels)\n",
" print \"%s WER: %f09\" % (caption, mean)\n",
" if len(items) > 0 and show_example:\n",
" print \"Example (WER = %f09)\" % items[0][2]\n",
" print \" - source: \\\"%s\\\"\" % items[0][0]\n",
" print \" - result: \\\"%s\\\"\" % items[0][1] \n",
" return items, mean"
]
},
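With the format strings above, a single-entry report would look like this (the source sentence is taken from the test output below; the result line is hypothetical):

Test WER: 0.090909
Example (WER = 0.090909)
 - source: "she had your dark suit in greasy wash water all year"
 - result: "she had your dark suit in greasy wash water all"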
{
"cell_type": "markdown",
"metadata": {},
@ -1040,7 +1137,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": null,
"metadata": {
"collapsed": false
},
@ -1060,7 +1157,11 @@
" optimizer = create_optimizer()\n",
"\n",
" # Get gradients for each tower (Runs across all GPU's)\n",
" _, _, tower_gradients, tower_loss, accuracy = get_tower_results(n_steps, data_sets.train, optimizer)\n",
" tower_decodings, tower_labels, tower_gradients, tower_loss, accuracy = \\\n",
" get_tower_results(n_steps, data_sets.train, optimizer)\n",
" \n",
" # Validation step preparation\n",
" validation_tower_decodings, validation_tower_labels = forward(session, data_sets.validation)\n",
"\n",
" # Average tower gradients\n",
" avg_tower_gradients = average_gradients(tower_gradients)\n",
@ -1085,6 +1186,12 @@
" for epoch in range(training_iters):\n",
" # Define total accuracy for the epoch\n",
" total_accuracy = 0\n",
" \n",
" # Validation step to determine the best point in time to stop\n",
" if epoch % validation_step == 0:\n",
" _, wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
" # TODO: Determine on base of WER, if model starts overfitting\n",
" print\n",
"\n",
" # Loop over the batches\n",
" for batch in range(total_batch/len(available_devices)):\n",
@ -1099,130 +1206,38 @@
" summary_str = session.run(merged)\n",
" writer.add_summary(summary_str, step)\n",
" writer.flush()\n",
"\n",
" \n",
" # Print progress message\n",
" if epoch % display_step == 0:\n",
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n",
" print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
" print\n",
"\n",
" # Checkpoint the model\n",
" if (epoch % checkpoint_step == 0) or (epoch == training_iters - 1):\n",
" checkpoint_path = os.path.join(checkpoint_dir, 'model.ckpt')\n",
" print \"Checkpointing in directory\", \"%s\" % checkpoint_dir\n",
" saver.save(session, checkpoint_path, global_step=epoch)\n",
" print\n",
" \n",
" # Indicate optimization has concluded\n",
" print \"Optimization Finished!\""
]
},
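The TODO in the validation step above leaves open how to act on the validation WER. A minimal early-stopping sketch, assuming a running best_wer and a hypothetical patience hyperparameter (neither is part of this commit), could look like:

# Illustrative early-stopping sketch - not part of the commit.
# best_wer and stale_epochs would be initialized before the epoch loop;
# patience is a hypothetical hyperparameter.
if wer < best_wer:
    best_wer = wer
    stale_epochs = 0
else:
    stale_epochs += 1
    if stale_epochs >= patience:
        print "Validation WER stopped improving - stopping early."
        # break out of the epoch loop here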
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [],
"cell_type": "markdown",
"metadata": {},
"source": [
"def validate(session, data_sets):\n",
" # Set n_steps parameter\n",
" n_steps = data_sets.validation.max_batch_seq_len\n",
"\n",
" # Calculate the total number of batches\n",
" total_batch = int(data_sets.validation.num_examples/batch_size)\n",
"\n",
" # Require that we have at least as many batches as devices\n",
" assert total_batch >= len(available_devices)\n",
" \n",
" # Get gradients for each tower (Runs across all GPU's)\n",
" tower_decodings, tower_labels, _, _, _ = get_tower_results(n_steps, data_sets.validation)\n",
" tower_decodings = [j for i in tower_decodings for j in i]\n",
"\n",
" # Loop over the batches\n",
" for i in range(len(tower_decodings)):\n",
" decoded, labels = session.run([tower_decodings[i], tower_labels[i]])\n",
" str_decoded = sparse_tensor_value_to_text(decoded)\n",
" print('Decoded:\\n%s' % str_decoded)\n",
" str_labels = sparse_tensor_value_to_text(labels)\n",
" print('Labels:\\n%s' % str_labels)\n",
"\n",
" # Indicate verification has concluded\n",
" print \"Verification Finished!\""
"As everything is prepared, we are now able to do the training."
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.BasicLSTMCell object at 0x1138a1850>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.\n",
"WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.BasicLSTMCell object at 0x11399ced0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0001 avg_cer= 3.750000000\n",
"Checkpointing in directory /var/folders/rx/vsxg55s54vgd7s9v5_5k7q7h0000gn/T\n",
"Epoch: 0002 avg_cer= 3.557692289\n",
"Epoch: 0003 avg_cer= 3.750000000\n",
"Epoch: 0004 avg_cer= 3.750000000\n",
"Epoch: 0005 avg_cer= 3.846153736\n",
"Epoch: 0006 avg_cer= 3.557692289\n",
"Epoch: 0007 avg_cer= 3.057692289\n",
"Epoch: 0008 avg_cer= 2.961538553\n",
"Epoch: 0009 avg_cer= 3.750000000\n",
"Epoch: 0010 avg_cer= 4.288461685\n",
"Epoch: 0011 avg_cer= 4.230769157\n",
"Epoch: 0012 avg_cer= 4.153846264\n",
"Epoch: 0013 avg_cer= 4.192307472\n",
"Epoch: 0014 avg_cer= 4.153846264\n",
"Epoch: 0015 avg_cer= 4.230769157\n",
"Epoch: 0016 avg_cer= 3.980769157\n",
"Epoch: 0017 avg_cer= 3.750000000\n",
"Epoch: 0018 avg_cer= 3.923076868\n",
"Epoch: 0019 avg_cer= 3.923076868\n",
"Epoch: 0020 avg_cer= 4.096153736\n",
"Epoch: 0021 avg_cer= 3.384615421\n",
"Epoch: 0022 avg_cer= 3.096153736\n",
"Epoch: 0023 avg_cer= 2.903846264\n",
"Epoch: 0024 avg_cer= 3.057692289\n",
"Epoch: 0025 avg_cer= 3.153846264\n",
"Epoch: 0026 avg_cer= 3.288461447\n",
"Epoch: 0027 avg_cer= 3.076923132\n",
"Epoch: 0028 avg_cer= 3.153846264\n",
"Epoch: 0029 avg_cer= 3.115384579\n",
"Epoch: 0030 avg_cer= 2.576923132\n",
"Epoch: 0031 avg_cer= 2.711538553\n",
"Epoch: 0032 avg_cer= 2.288461447\n",
"Epoch: 0033 avg_cer= 2.269230843\n",
"Epoch: 0034 avg_cer= 2.692307711\n",
"Epoch: 0035 avg_cer= 2.750000000\n",
"Epoch: 0036 avg_cer= 2.711538553\n",
"Epoch: 0037 avg_cer= 2.211538553\n",
"Epoch: 0038 avg_cer= 2.423076868\n",
"Epoch: 0039 avg_cer= 1.980769277\n",
"Epoch: 0040 avg_cer= 2.461538553\n",
"Epoch: 0041 avg_cer= 2.653846264\n",
"Epoch: 0042 avg_cer= 2.615384579\n",
"Epoch: 0043 avg_cer= 2.615384579\n",
"Epoch: 0044 avg_cer= 2.769230843\n",
"Epoch: 0045 avg_cer= 2.807692289\n",
"Epoch: 0046 avg_cer= 2.153846264\n",
"Epoch: 0047 avg_cer= 2.307692289\n",
"Epoch: 0048 avg_cer= 2.519230843\n",
"Epoch: 0049 avg_cer= 2.750000000\n",
"Epoch: 0050 avg_cer= 2.730769157\n",
"Checkpointing in directory /var/folders/rx/vsxg55s54vgd7s9v5_5k7q7h0000gn/T\n",
"Optimization Finished!\n"
]
}
],
"outputs": [],
"source": [
"# Create session in which to execute\n",
"session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))\n",
@ -1234,36 +1249,23 @@
"train(session, ted_lium)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally the trained model is tested using an unbiased test set."
]
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.BasicLSTMCell object at 0x12a5c4350>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.\n",
"WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.BasicLSTMCell object at 0x13bed3d90>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Decoded:\n",
"['h jixorxrorororororororororororororororororororororororororororororxrorororororororororororororororororororororororororororororororororororxrororxzs ']\n",
"Labels:\n",
"['she had your dark suit in greasy wash water all year']\n",
"Verification Finished!\n"
]
}
],
"outputs": [],
"source": [
"# Validate network\n",
"validate(session, ted_lium)"
"# Test network\n",
"print_wer_report(session, \"Test\", forward(session, data_sets.test))"
]
}
],

View File

@ -70,6 +70,19 @@ def sparse_tuple_to_text(tuple):
    # List of strings
    return results
def wer(original, result):
    # The WER is the word-level edit distance between result and original,
    # normalized by the number of words in the original - so we compare
    # word sequences, not raw character strings.
    original = original.split()
    result = result.split()
    return levenshtein(original, result) / float(len(original))

def wers(originals, results):
    count = len(originals)
    assert count == len(results)
    rates = []
    mean = 0.0
    for i in range(count):
        rate = wer(originals[i], results[i])
        mean = mean + rate
        rates.append(rate)
    return rates, mean / float(count)
# The following code is from: http://hetland.org/coding/python/levenshtein.py
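For reference, here is a self-contained, runnable sketch of the new WER helpers (Python 2, to match the notebook), with a minimal Levenshtein implementation standing in for the hetland.org version referenced above:

# Self-contained sketch of the WER helpers added by this commit.
# levenshtein() below is a minimal stand-in for the hetland.org version;
# it computes the edit distance between any two sequences.
def levenshtein(a, b):
    n, m = len(a), len(b)
    current = range(m + 1)
    for i in range(1, n + 1):
        previous, current = current, [i] + [0] * m
        for j in range(1, m + 1):
            add, delete = previous[j] + 1, current[j - 1] + 1
            change = previous[j - 1] + (a[i - 1] != b[j - 1])
            current[j] = min(add, delete, change)
    return current[m]

def wer(original, result):
    # Word error rate: word-level edit distance over the original's word count
    original = original.split()
    result = result.split()
    return levenshtein(original, result) / float(len(original))

def wers(originals, results):
    assert len(originals) == len(results)
    rates = [wer(o, r) for o, r in zip(originals, results)]
    return rates, sum(rates) / float(len(rates))

print wers(["she had your dark suit"], ["she has your dark suit"])
# -> ([0.2], 0.2) - one substitution out of five words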