Expose cutoff_prob and cutoff_top_n as flags

This commit is contained in:
Reuben Morais 2019-10-22 14:35:29 +02:00
parent 6e287bd340
commit 12baf5ffbc
3 changed files with 7 additions and 2 deletions

View File

@ -896,7 +896,9 @@ def do_single_file_inference(input_file_path):
Config.alphabet) Config.alphabet)
else: else:
scorer = None scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer) decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result # Print highest probability result
print(decoded[0][1]) print(decoded[0][1])

View File

@ -116,7 +116,8 @@ def evaluate(test_csvs, create_model, try_loading):
break break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width, decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer) num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded) predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)) ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames) wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)

View File

@ -138,6 +138,8 @@ def create_flags():
f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
f.DEFINE_float('cutoff_prob', 1.0, 'only consider characters until this probability mass is reached. 1.0 = disabled.')
f.DEFINE_integer('cutoff_top_n', 300, 'only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.')
# Inference mode # Inference mode