diff --git a/DeepSpeech.py b/DeepSpeech.py index c8afc030..0f3f5d9f 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -655,8 +655,7 @@ def test(): hdf5_cache_path=FLAGS.test_cached_features_path) graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1) - - evaluate.evaluate(test_data, graph, Config.alphabet) + evaluate.evaluate(test_data, graph) def create_inference_graph(batch_size=1, n_steps=16, tflite=False): diff --git a/evaluate.py b/evaluate.py index e6967343..ebbbd4f8 100755 --- a/evaluate.py +++ b/evaluate.py @@ -82,7 +82,7 @@ def calculate_report(labels, decodings, distances, losses): return samples_wer, samples -def evaluate(test_data, inference_graph, alphabet): +def evaluate(test_data, inference_graph): scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet) @@ -175,10 +175,10 @@ def evaluate(test_data, inference_graph, alphabet): # Second pass, decode logits and compute WER and edit distance metrics for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))): seq_lengths = batch['features_len'].values.astype(np.int32) - decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width, + decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width, num_processes=num_processes, scorer=scorer) - ground_truths.extend(alphabet.decode(l) for l in batch['transcript']) + ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript']) predictions.extend(d[0][1] for d in decoded) distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)] @@ -211,14 +211,11 @@ def main(_): 'the --test_files flag.') exit(1) - global alphabet - alphabet = Alphabet(FLAGS.alphabet_config_path) - # sort examples by length, improves packing of batches and timesteps test_data = preprocess( FLAGS.test_files.split(','), FLAGS.test_batch_size, - alphabet=alphabet, + alphabet=Config.alphabet, numcep=Config.n_input, numcontext=Config.n_context, hdf5_cache_path=FLAGS.hdf5_test_set).sort_values( @@ -228,7 +225,7 @@ def main(_): from DeepSpeech import create_inference_graph graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1) - samples = evaluate(test_data, graph, alphabet) + samples = evaluate(test_data, graph) if FLAGS.test_output_file: # Save decoded tuples as JSON, converting NumPy floats to Python floats