Expose cutoff_prob and cutoff_top_n as flags
This commit is contained in:
parent 6e287bd340
commit 12baf5ffbc
@ -896,7 +896,9 @@ def do_single_file_inference(input_file_path):
|
|||||||
Config.alphabet)
|
Config.alphabet)
|
||||||
else:
|
else:
|
||||||
scorer = None
|
scorer = None
|
||||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer)
|
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
||||||
|
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||||
|
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||||
# Print highest probability result
|
# Print highest probability result
|
||||||
print(decoded[0][1])
|
print(decoded[0][1])
|
||||||
|
|
||||||
|
@ -116,7 +116,8 @@ def evaluate(test_csvs, create_model, try_loading):
|
|||||||
break
|
break
|
||||||
|
|
||||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||||
num_processes=num_processes, scorer=scorer)
|
num_processes=num_processes, scorer=scorer,
|
||||||
|
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
||||||
predictions.extend(d[0][1] for d in decoded)
|
predictions.extend(d[0][1] for d in decoded)
|
||||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||||
|
@ -138,6 +138,8 @@ def create_flags():
|
|||||||
f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
|
f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
|
||||||
f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
|
f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
|
||||||
f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
|
f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
|
||||||
|
f.DEFINE_float('cutoff_prob', 1.0, 'only consider characters until this probability mass is reached. 1.0 = disabled.')
|
||||||
|
f.DEFINE_integer('cutoff_top_n', 300, 'only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.')
|
||||||
|
|
||||||
# Inference mode
|
# Inference mode
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user