Merge pull request #1791 from mozilla/fix-alphabet-handling

Fix handling of Alphabet around evaluate.py
This commit is contained in:
Reuben Morais 2018-12-26 19:20:54 +00:00 committed by GitHub
commit ce551f5385
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 10 deletions

View File

@@ -655,8 +655,7 @@ def test():
hdf5_cache_path=FLAGS.test_cached_features_path)
graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
evaluate.evaluate(test_data, graph, Config.alphabet)
evaluate.evaluate(test_data, graph)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):

View File

@@ -82,7 +82,7 @@ def calculate_report(labels, decodings, distances, losses):
return samples_wer, samples
def evaluate(test_data, inference_graph, alphabet):
def evaluate(test_data, inference_graph):
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
Config.alphabet)
@@ -175,10 +175,10 @@ def evaluate(test_data, inference_graph, alphabet):
# Second pass, decode logits and compute WER and edit distance metrics
for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
seq_lengths = batch['features_len'].values.astype(np.int32)
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer)
ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript'])
predictions.extend(d[0][1] for d in decoded)
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
@@ -211,14 +211,11 @@ def main(_):
'the --test_files flag.')
exit(1)
global alphabet
alphabet = Alphabet(FLAGS.alphabet_config_path)
# sort examples by length, improves packing of batches and timesteps
test_data = preprocess(
FLAGS.test_files.split(','),
FLAGS.test_batch_size,
alphabet=alphabet,
alphabet=Config.alphabet,
numcep=Config.n_input,
numcontext=Config.n_context,
hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
@@ -228,7 +225,7 @@ def main(_):
from DeepSpeech import create_inference_graph
graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
samples = evaluate(test_data, graph, alphabet)
samples = evaluate(test_data, graph)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats