Merge pull request #1791 from mozilla/fix-alphabet-handling
Fix handling of Alphabet around evaluate.py
commit ce551f5385
DeepSpeech.py

@@ -655,8 +655,7 @@ def test():
                            hdf5_cache_path=FLAGS.test_cached_features_path)

     graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
-    evaluate.evaluate(test_data, graph, Config.alphabet)
-
+    evaluate.evaluate(test_data, graph)


 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
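The dropped argument is redundant because the alphabet already lives on the shared `Config` object that both DeepSpeech.py and evaluate.py use, populated once at startup by the project's global-initialization code. A standalone sketch of that pattern, using hypothetical stand-ins rather than the project's real util modules:

```python
# Standalone sketch of the shared-config pattern this change relies on.
# All names here are hypothetical stand-ins; the real project keeps this
# logic in its util package.

class Alphabet:
    """Stand-in: maps integer label indices to characters."""
    def __init__(self, config_path):
        with open(config_path, encoding='utf-8') as fin:
            self._labels = [line.rstrip('\n') for line in fin
                            if not line.startswith('#')]

    def decode(self, labels):
        return ''.join(self._labels[l] for l in labels)

class _Config:
    alphabet = None  # populated once at startup

Config = _Config()

def initialize_globals(alphabet_config_path):
    # Every module importing Config sees the same Alphabet instance,
    # so call sites no longer need an explicit alphabet parameter.
    Config.alphabet = Alphabet(alphabet_config_path)
```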
evaluate.py (13 changes)
@@ -82,7 +82,7 @@ def calculate_report(labels, decodings, distances, losses):
     return samples_wer, samples


-def evaluate(test_data, inference_graph, alphabet):
+def evaluate(test_data, inference_graph):
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                     FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                     Config.alphabet)
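The `Scorer` in the context lines is the external language-model scorer from ds_ctcdecoder; as the call above shows, it takes the LM hyperparameters, the KenLM binary and trie paths, and the alphabet. Sourcing the alphabet from `Config` keeps the scorer and the decoder on a single shared instance. A usage sketch with illustrative values (real runs read these from `FLAGS`):

```python
from ds_ctcdecoder import Scorer

# Hyperparameters and paths below are illustrative only.
scorer = Scorer(0.75,                 # lm_alpha: language-model weight
                1.85,                 # lm_beta: word-insertion weight
                'data/lm/lm.binary',  # KenLM binary path
                'data/lm/trie',       # vocabulary trie path
                Config.alphabet)      # same Alphabet the acoustic model uses
```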
@@ -175,10 +175,10 @@ def evaluate(test_data, inference_graph, alphabet):
     # Second pass, decode logits and compute WER and edit distance metrics
     for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
         seq_lengths = batch['features_len'].values.astype(np.int32)
-        decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
+        decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width,
                                                 num_processes=num_processes, scorer=scorer)

-        ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
+        ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript'])
         predictions.extend(d[0][1] for d in decoded)

     distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
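The `d[0][1]` indexing works because, as the surrounding code implies, the batch decoder returns one candidate list per utterance, each entry a `(score, transcript)` pair ordered best-first. A small sketch of consuming that structure:

```python
# `decoded` holds one list of (score, transcript) candidates per utterance,
# ordered from best to worst beam; d[0][1] above takes the top transcript.
for ground_truth, candidates in zip(ground_truths, decoded):
    score, transcript = candidates[0]
    print('expected: %s' % ground_truth)
    print('decoded : %s (score %.3f)' % (transcript, score))
```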
@@ -211,14 +211,11 @@ def main(_):
               'the --test_files flag.')
         exit(1)

-    global alphabet
-    alphabet = Alphabet(FLAGS.alphabet_config_path)
-
     # sort examples by length, improves packing of batches and timesteps
     test_data = preprocess(
         FLAGS.test_files.split(','),
         FLAGS.test_batch_size,
-        alphabet=alphabet,
+        alphabet=Config.alphabet,
         numcep=Config.n_input,
         numcontext=Config.n_context,
         hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
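With the module-level `global alphabet` removed, the only alphabet left is the one on `Config`, so evaluate.py can no longer decode with a differently constructed `Alphabet` than the rest of the pipeline. For illustration, the `decode` call used above maps stored integer labels back to text; reusing the hypothetical stand-in from the earlier sketch, with made-up label values:

```python
# Hypothetical labels; the actual indices depend on the alphabet file.
labels = [7, 4, 11, 11, 14]
text = Config.alphabet.decode(labels)  # e.g. 'hello' for a plain a-z alphabet
```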
@@ -228,7 +225,7 @@ def main(_):
     from DeepSpeech import create_inference_graph
     graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)

-    samples = evaluate(test_data, graph, alphabet)
+    samples = evaluate(test_data, graph)

     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
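The trailing comment refers to dumping the per-sample results as JSON; the conversion matters because `json` cannot serialize NumPy scalar types. A minimal sketch of that idea, assuming hypothetical sample fields `src`, `res`, `loss`, and `wer` (the real field names may differ):

```python
import json
import numpy as np

def to_float(value):
    # json.dump raises TypeError on np.float32/np.float64 scalars,
    # so cast NumPy floats to plain Python floats first.
    return float(value) if isinstance(value, np.floating) else value

def save_samples(samples, path):
    with open(path, 'w') as fout:
        json.dump([{'src': s.src, 'res': s.res,
                    'loss': to_float(s.loss), 'wer': to_float(s.wer)}
                   for s in samples], fout, indent=2)
```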