Remove current re-scoring of decoder output and switch to custom op
This commit is contained in:
  parent fc91e3d7b8
  commit 2cccd33452

.gitattributes (vendored)
@@ -1,2 +1,3 @@
 *.binary filter=lfs diff=lfs merge=lfs -crlf
+data/lm/trie filter=lfs diff=lfs merge=lfs -crlf
@@ -24,7 +24,6 @@ from threading import Thread, Lock
 from util.feeding import DataSet, ModelFeeder
 from util.gpu import get_available_gpus
 from util.shared_lib import check_cupti
-from util.spell import correction
 from util.text import sparse_tensor_value_to_texts, wer, Alphabet
 from xdg import BaseDirectory as xdg
 import numpy as np
@@ -140,6 +139,9 @@ tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps t
 tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
 tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
 
+# Decoder
+
+tf.app.flags.DEFINE_string ('decoder_library_path', 'native_client/libctc_decoder_with_kenlm.so', 'path to the libctc_decoder_with_kenlm.so library containing the decoder implementation.')
 tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
 
 for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
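For context (not part of the diff): the new --decoder_library_path flag is the path handed to tf.load_op_library further down, which is how the KenLM-aware decoder op becomes callable from Python. A minimal standalone sketch of that mechanism; the library path and op name are taken from this diff, everything else is assumed for illustration:

import tensorflow as tf

tf.app.flags.DEFINE_string('decoder_library_path',
                           'native_client/libctc_decoder_with_kenlm.so',
                           'path to the shared library containing the decoder op')
FLAGS = tf.app.flags.FLAGS

def main(_):
    # Loading the shared library returns a Python module exposing the custom
    # ops compiled into it, here the KenLM-aware CTC beam search decoder.
    custom_op_module = tf.load_op_library(FLAGS.decoder_library_path)
    print(custom_op_module.ctc_beam_search_decoder_with_lm)

if __name__ == '__main__':
    tf.app.run()  # parses --decoder_library_path from the command line, then calls main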
@@ -452,6 +454,58 @@ def BiRNN(batch_x, seq_length, dropout):
     # Output shape: [n_steps, batch_size, n_hidden_6]
     return layer_6
 
+custom_op_module = tf.load_op_library(FLAGS.decoder_library_path)
+
+def decode_with_lm(inputs, sequence_length, beam_width=100,
+                   top_paths=1, merge_repeated=True):
+  """Performs beam search decoding on the logits given in input.
+
+  **Note** The `ctc_greedy_decoder` is a special case of the
+  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
+  that decoder is faster for this special case).
+
+  If `merge_repeated` is `True`, merge repeated classes in the output beams.
+  This means that if consecutive entries in a beam are the same,
+  only the first of these is emitted. That is, when the top path
+  is `A B B B B`, the return value is:
+
+    * `A B` if `merge_repeated = True`.
+    * `A B B B B` if `merge_repeated = False`.
+
+  Args:
+    inputs: 3-D `float` `Tensor`, size
+      `[max_time x batch_size x num_classes]`. The logits.
+    sequence_length: 1-D `int32` vector containing sequence lengths,
+      having size `[batch_size]`.
+    beam_width: An int scalar >= 0 (beam search beam width).
+    top_paths: An int scalar >= 0, <= beam_width (controls output size).
+    merge_repeated: Boolean. Default: True.
+
+  Returns:
+    A tuple `(decoded, log_probabilities)` where
+    decoded: A list of length top_paths, where `decoded[j]`
+      is a `SparseTensor` containing the decoded outputs:
+      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
+        The rows store: [batch, time].
+      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
+        The vector stores the decoded classes for beam j.
+      `decoded[j].shape`: Shape vector, size `(2)`.
+        The shape values are: `[batch_size, max_decoded_length[j]]`.
+    log_probability: A `float` matrix `(batch_size x top_paths)` containing
+        sequence log-probabilities.
+  """
+
+  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
+      custom_op_module.ctc_beam_search_decoder_with_lm(
+          inputs, sequence_length, model_path="data/lm/lm.binary", trie_path="data/lm/trie", alphabet_path="data/alphabet.txt",
+          beam_width=beam_width, top_paths=top_paths, merge_repeated=merge_repeated))
+
+  return (
+      [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
+       in zip(decoded_ixs, decoded_vals, decoded_shapes)],
+      log_probabilities)
+
+
 
 # Accuracy and Loss
 # =================
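A rough usage sketch of the new decode_with_lm wrapper, assuming it and the flags above are in scope; the placeholder shapes and the 29-class count are illustrative assumptions, not taken from the commit:

import tensorflow as tf

# Logits as produced by the acoustic model: [max_time, batch_size, num_classes].
logits = tf.placeholder(tf.float32, [None, None, 29])
seq_len = tf.placeholder(tf.int32, [None])

# Same call shape as in calculate_mean_edit_distance_and_loss below.
decoded, log_probs = decode_with_lm(logits, seq_len,
                                    merge_repeated=False, beam_width=1024)
# decoded[0] is a SparseTensor holding the best transcription per batch item;
# log_probs has shape [batch_size, top_paths] with the per-path scores.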
@@ -485,7 +539,7 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout):
     avg_loss = tf.reduce_mean(total_loss)
 
     # Beam search decode the batch
-    decoded, _ = tf.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
+    decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=1024)
 
     # Compute the edit (Levenshtein) distance
     distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), batch_y)
@@ -718,9 +772,8 @@ def calculate_report(results_tuple):
     items = list(zip(*results_tuple))
     mean_wer = 0.0
     for label, decoding, distance, loss in items:
-        corrected = correction(decoding, alphabet)
-        sample_wer = wer(label, corrected)
-        sample = Sample(label, corrected, loss, distance, sample_wer)
+        sample_wer = wer(label, decoding)
+        sample = Sample(label, decoding, loss, distance, sample_wer)
         samples.append(sample)
        mean_wer += sample_wer
 
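With the Python re-scoring gone, the WER is computed directly on the decoder output. As a reminder of what that metric measures, here is an illustrative word-level Levenshtein sketch (not the actual util.text.wer implementation):

def word_error_rate(reference, hypothesis):
    ref, hyp = reference.split(), hypothesis.split()
    # DP table of edit distances between word prefixes.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + cost) # substitution
    return d[len(ref)][len(hyp)] / float(max(len(ref), 1))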
data/lm/trie (new file)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55da7b52ddb19f46301a31d56aff35ed1508fd9bd1e04d907114d89771892219
+size 43550329
@@ -12,4 +12,3 @@ python_speech_features
 pyxdg
 bs4
 six
-https://github.com/kpu/kenlm/archive/master.zip
@@ -1,59 +0,0 @@
-from __future__ import absolute_import
-import re
-import kenlm
-from heapq import heapify
-from six.moves import range
-
-# Define beam with for alt sentence search
-BEAM_WIDTH = 1024
-MODEL = None
-
-# Lazy-load language model (TED corpus, Kneser-Ney, 4-gram, 30k word LM)
-def get_model():
-    global MODEL
-    if MODEL is None:
-        MODEL = kenlm.Model('./data/lm/lm.binary')
-    return MODEL
-
-def words(text):
-    "List of words in text."
-    return re.findall(r'\w+', text.lower())
-
-# Load known word set
-with open('./data/spell/words.txt') as f:
-    WORDS = set(words(f.read()))
-
-def log_probability(sentence):
-    "Log base 10 probability of `sentence`, a list of words"
-    return get_model().score(' '.join(sentence), bos = False, eos = False)
-
-def correction(sentence, alphabet):
-    "Most probable spelling correction for sentence."
-    layer = [(0,[])]
-    for word in words(sentence):
-        layer = [(-log_probability(node + [cword]), node + [cword]) for cword in candidate_words(word, alphabet) for priority, node in layer]
-        heapify(layer)
-        layer = layer[:BEAM_WIDTH]
-    return ' '.join(layer[0][1])
-
-def candidate_words(word, alphabet):
-    "Generate possible spelling corrections for word."
-    return (known_words([word]) or known_words(edits1(word, alphabet)) or known_words(edits2(word, alphabet)) or [word])
-
-def known_words(words):
-    "The subset of `words` that appear in the dictionary of WORDS."
-    return set(w for w in words if w in WORDS)
-
-def edits1(word, alphabet):
-    "All edits that are one edit away from `word`."
-    letters = [alphabet.string_from_label(i) for i in range(alphabet.size())]
-    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
-    deletes = [L + R[1:] for L, R in splits if R]
-    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
-    replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
-    inserts = [L + c + R for L, R in splits for c in letters]
-    return set(deletes + transposes + replaces + inserts)
-
-def edits2(word, alphabet):
-    "All edits that are two edits away from `word`."
-    return (e2 for e1 in edits1(word, alphabet) for e2 in edits1(e1, alphabet))
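The deleted module (imported above as util.spell) did Norvig-style candidate generation plus KenLM re-scoring in Python after decoding, which this commit replaces with LM scoring inside the decoder op. A small self-contained demo of the candidate-generation step it relied on; the hard-coded LETTERS string stands in for the repo's Alphabet class and is an assumption for illustration:

LETTERS = "abcdefghijklmnopqrstuvwxyz '"

def edits1(word):
    "All strings one edit (delete/transpose/replace/insert) away from `word`."
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [L + R[1:] for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
    replaces = [L + c + R[1:] for L, R in splits if R for c in LETTERS]
    inserts = [L + c + R for L, R in splits for c in LETTERS]
    return set(deletes + transposes + replaces + inserts)

print(len(edits1("helo")))  # a few hundred candidates for a four-letter word

Generating and LM-scoring that many candidates per word in Python is what made the old re-scoring pass comparatively slow.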