Remove current re-scoring of decoder output and switch to custom op

Reuben Morais 2017-09-01 11:48:10 +02:00
parent fc91e3d7b8
commit 2cccd33452
5 changed files with 62 additions and 65 deletions

.gitattributes vendored

@@ -1,2 +1,3 @@
*.binary filter=lfs diff=lfs merge=lfs -crlf
data/lm/trie filter=lfs diff=lfs merge=lfs -crlf

DeepSpeech.py

@@ -24,7 +24,6 @@ from threading import Thread, Lock
from util.feeding import DataSet, ModelFeeder
from util.gpu import get_available_gpus
from util.shared_lib import check_cupti
from util.spell import correction
from util.text import sparse_tensor_value_to_texts, wer, Alphabet
from xdg import BaseDirectory as xdg
import numpy as np
@@ -140,6 +139,9 @@ tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps t
tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
# Decoder
tf.app.flags.DEFINE_string ('decoder_library_path', 'native_client/libctc_decoder_with_kenlm.so', 'path to the libctc_decoder_with_kenlm.so library containing the decoder implementation.')
tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
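Both new flags are consumed inside DeepSpeech.py: `decoder_library_path` feeds the `tf.load_op_library` call added in the next hunk, and `alphabet_config_path` points at the character set used for decoding. As a hedged sketch of that wiring only, where the `Alphabet(...)` construction from util.text is an assumption not shown in this diff:

import os
import tensorflow as tf
from util.text import Alphabet

FLAGS = tf.app.flags.FLAGS

# Load the custom decoder op library and the network's character set
custom_op_module = tf.load_op_library(FLAGS.decoder_library_path)
alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))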
@@ -452,6 +454,58 @@ def BiRNN(batch_x, seq_length, dropout):
    # Output shape: [n_steps, batch_size, n_hidden_6]
    return layer_6

custom_op_module = tf.load_op_library(FLAGS.decoder_library_path)

def decode_with_lm(inputs, sequence_length, beam_width=100,
                   top_paths=1, merge_repeated=True):
    """Performs beam search decoding on the logits given in input.

    **Note** The `ctc_greedy_decoder` is a special case of the
    `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
    that decoder is faster for this special case).

    If `merge_repeated` is `True`, merge repeated classes in the output beams.
    This means that if consecutive entries in a beam are the same,
    only the first of these is emitted. That is, when the top path
    is `A B B B B`, the return value is:

      * `A B` if `merge_repeated = True`.
      * `A B B B B` if `merge_repeated = False`.

    Args:
      inputs: 3-D `float` `Tensor`, size
        `[max_time x batch_size x num_classes]`. The logits.
      sequence_length: 1-D `int32` vector containing sequence lengths,
        having size `[batch_size]`.
      beam_width: An int scalar >= 0 (beam search beam width).
      top_paths: An int scalar >= 0, <= beam_width (controls output size).
      merge_repeated: Boolean. Default: True.

    Returns:
      A tuple `(decoded, log_probabilities)` where

      decoded: A list of length top_paths, where `decoded[j]`
        is a `SparseTensor` containing the decoded outputs:
        `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
          The rows store: [batch, time].
        `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
          The vector stores the decoded classes for beam j.
        `decoded[j].shape`: Shape vector, size `(2)`.
          The shape values are: `[batch_size, max_decoded_length[j]]`.

      log_probability: A `float` matrix `(batch_size x top_paths)` containing
        sequence log-probabilities.
    """
    decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
        custom_op_module.ctc_beam_search_decoder_with_lm(
            inputs, sequence_length,
            model_path="data/lm/lm.binary",
            trie_path="data/lm/trie",
            alphabet_path="data/alphabet.txt",
            beam_width=beam_width,
            top_paths=top_paths,
            merge_repeated=merge_repeated))

    return (
        [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
         in zip(decoded_ixs, decoded_vals, decoded_shapes)],
        log_probabilities)
# Accuracy and Loss
# =================
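For orientation only (not part of this commit), a minimal sketch of how the new wrapper could be exercised on dummy inputs. The random time-major logits, the 29-class output size, and the `tf.sparse_tensor_to_dense` conversion are illustrative assumptions, and the custom op additionally needs the shared library, lm.binary, trie, and alphabet files referenced above to be present:

import numpy as np
import tensorflow as tf

# Dummy time-major logits: [max_time, batch_size, num_classes]
# (29 classes assumes 28 alphabet symbols plus the CTC blank)
max_time, batch_size, num_classes = 50, 1, 29
logits = tf.constant(np.random.rand(max_time, batch_size, num_classes).astype(np.float32))
seq_len = tf.constant([max_time], dtype=tf.int32)

decoded, log_probs = decode_with_lm(logits, seq_len, merge_repeated=False, beam_width=1024)

with tf.Session() as session:
    # decoded[0] is a SparseTensor holding the best path's label indices
    best = session.run(tf.sparse_tensor_to_dense(decoded[0], default_value=-1))
    print(best, session.run(log_probs))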
@@ -485,7 +539,7 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout):
    avg_loss = tf.reduce_mean(total_loss)

    # Beam search decode the batch
    decoded, _ = tf.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
    decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=1024)

    # Compute the edit (Levenshtein) distance
    distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), batch_y)
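Illustrative only (not part of this commit): how `tf.edit_distance` scores sparse label sequences like the decoder output above, assuming TensorFlow 1.x:

import tensorflow as tf

# One-element batch: hypothesis labels [0, 1] vs. reference labels [0, 1, 2]
hypothesis = tf.SparseTensor(indices=[[0, 0], [0, 1]],
                             values=[0, 1],
                             dense_shape=[1, 2])
truth = tf.SparseTensor(indices=[[0, 0], [0, 1], [0, 2]],
                        values=[0, 1, 2],
                        dense_shape=[1, 3])

# normalize=True (the default) divides by the reference length: 1 edit / 3 labels
distance = tf.edit_distance(hypothesis, truth)

with tf.Session() as session:
    print(session.run(distance))  # -> [0.33333334]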
@@ -718,9 +772,8 @@ def calculate_report(results_tuple):
    items = list(zip(*results_tuple))

    mean_wer = 0.0
    for label, decoding, distance, loss in items:
        corrected = correction(decoding, alphabet)
        sample_wer = wer(label, corrected)
        sample = Sample(label, corrected, loss, distance, sample_wer)
        sample_wer = wer(label, decoding)
        sample = Sample(label, decoding, loss, distance, sample_wer)
        samples.append(sample)
        mean_wer += sample_wer
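The `wer` helper comes from util.text and is not shown in this diff; purely as a reference for what the loop above accumulates, a minimal word error rate sketch under the usual definition (word-level Levenshtein distance divided by the number of reference words):

def word_error_rate(reference, hypothesis):
    "Word-level edit distance divided by the reference length."
    ref, hyp = reference.split(), hypothesis.split()
    # Classic dynamic-programming Levenshtein distance over words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    return d[len(ref)][len(hyp)] / float(len(ref))

# word_error_rate("the cat sat", "the cat sit") -> 0.333...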

data/lm/trie Normal file

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:55da7b52ddb19f46301a31d56aff35ed1508fd9bd1e04d907114d89771892219
size 43550329

requirements.txt

@@ -12,4 +12,3 @@ python_speech_features
pyxdg
bs4
six
https://github.com/kpu/kenlm/archive/master.zip

util/spell.py

@@ -1,59 +0,0 @@
from __future__ import absolute_import

import re
import kenlm

from heapq import heapify
from six.moves import range

# Define beam with for alt sentence search
BEAM_WIDTH = 1024
MODEL = None

# Lazy-load language model (TED corpus, Kneser-Ney, 4-gram, 30k word LM)
def get_model():
    global MODEL
    if MODEL is None:
        MODEL = kenlm.Model('./data/lm/lm.binary')
    return MODEL

def words(text):
    "List of words in text."
    return re.findall(r'\w+', text.lower())

# Load known word set
with open('./data/spell/words.txt') as f:
    WORDS = set(words(f.read()))

def log_probability(sentence):
    "Log base 10 probability of `sentence`, a list of words"
    return get_model().score(' '.join(sentence), bos = False, eos = False)

def correction(sentence, alphabet):
    "Most probable spelling correction for sentence."
    layer = [(0,[])]
    for word in words(sentence):
        layer = [(-log_probability(node + [cword]), node + [cword]) for cword in candidate_words(word, alphabet) for priority, node in layer]
        heapify(layer)
        layer = layer[:BEAM_WIDTH]
    return ' '.join(layer[0][1])

def candidate_words(word, alphabet):
    "Generate possible spelling corrections for word."
    return (known_words([word]) or known_words(edits1(word, alphabet)) or known_words(edits2(word, alphabet)) or [word])

def known_words(words):
    "The subset of `words` that appear in the dictionary of WORDS."
    return set(w for w in words if w in WORDS)

def edits1(word, alphabet):
    "All edits that are one edit away from `word`."
    letters = [alphabet.string_from_label(i) for i in range(alphabet.size())]
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [L + R[1:] for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
    replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
    inserts = [L + c + R for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def edits2(word, alphabet):
    "All edits that are two edits away from `word`."
    return (e2 for e1 in edits1(word, alphabet) for e2 in edits1(e1, alphabet))
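For context on what is being removed: before this commit, DeepSpeech.py re-scored every decoded transcript through this module via the dropped `corrected = correction(decoding, alphabet)` call shown earlier. A minimal sketch of that old path; the `Alphabet('data/alphabet.txt')` construction and the sample transcript are assumptions for illustration, and the module also expects lm.binary and data/spell/words.txt on disk:

from util.spell import correction
from util.text import Alphabet

# Character set used by the network (constructor argument assumed here)
alphabet = Alphabet('data/alphabet.txt')

decoding = "their was a problm with the decoded output"  # raw decoder transcript
corrected = correction(decoding, alphabet)               # KenLM-rescored correction
print(corrected)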