diff --git a/DeepSpeech.py b/DeepSpeech.py
index 7f0bf7b3..ba43d385 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -896,7 +896,9 @@ def do_single_file_inference(input_file_path):
                         Config.alphabet)
     else:
         scorer = None
-    decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer)
+    decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
+                                      scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
+                                      cutoff_top_n=FLAGS.cutoff_top_n)

     # Print highest probability result
     print(decoded[0][1])
diff --git a/evaluate.py b/evaluate.py
index e511c8b6..ac59c034 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -116,7 +116,8 @@ def evaluate(test_csvs, create_model, try_loading):
                 break

             decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
-                                                    num_processes=num_processes, scorer=scorer)
+                                                    num_processes=num_processes, scorer=scorer,
+                                                    cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
             predictions.extend(d[0][1] for d in decoded)
             ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
             wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
diff --git a/util/flags.py b/util/flags.py
index e5b3afdd..3d59f831 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -138,6 +138,8 @@ def create_flags():
     f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
     f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
     f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
+    f.DEFINE_float('cutoff_prob', 1.0, 'only consider characters until this probability mass is reached. 1.0 = disabled.')
+    f.DEFINE_integer('cutoff_top_n', 300, 'only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.')

     # Inference mode
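
For reference, a minimal standalone sketch of the decoder call with the new pruning parameters. This is not part of the patch: the import paths, the alphabet file path, the dummy logits, and the example values 0.99/40 are assumptions for illustration; only the cutoff_prob and cutoff_top_n keyword arguments are the ones wired through above.

# Sketch only, not part of the patch. Import paths and inputs are assumptions.
import numpy as np
from ds_ctcdecoder import ctc_beam_search_decoder   # assumed import path
from util.text import Alphabet                       # assumed import path

alphabet = Alphabet('data/alphabet.txt')             # assumed alphabet file
num_classes = alphabet.size() + 1                    # characters plus the CTC blank (size() assumed)

# Fake acoustic-model output: 50 time steps of per-character probabilities.
logits = np.random.rand(50, num_classes).astype(np.float32)
logits /= logits.sum(axis=1, keepdims=True)

decoded = ctc_beam_search_decoder(logits, alphabet, 1024,
                                  scorer=None,       # decode without a language model
                                  cutoff_prob=0.99,  # stop expanding once 99% of the probability mass is covered
                                  cutoff_top_n=40)   # consider at most the 40 most likely characters per step
print(decoded[0][1])                                 # highest-probability transcription

With the defaults (cutoff_prob=1.0, cutoff_top_n=300 on a much smaller alphabet) both prunings are effectively disabled, so existing behaviour is unchanged unless the flags are set explicitly.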