Make language model scoring optional in Python inference code
This commit is contained in:
parent
14c0db7294
commit
4302a5f767
|
@ -870,9 +870,12 @@ def do_single_file_inference(input_file_path):
|
||||||
|
|
||||||
logits = np.squeeze(logits)
|
logits = np.squeeze(logits)
|
||||||
|
|
||||||
|
if FLAGS.lm_binary_path:
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||||
Config.alphabet)
|
Config.alphabet)
|
||||||
|
else:
|
||||||
|
scorer = None
|
||||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer)
|
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer)
|
||||||
# Print highest probability result
|
# Print highest probability result
|
||||||
print(decoded[0][1])
|
print(decoded[0][1])
|
||||||
|
|
|
@ -42,9 +42,12 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||||
|
|
||||||
|
|
||||||
def evaluate(test_csvs, create_model, try_loading):
|
def evaluate(test_csvs, create_model, try_loading):
|
||||||
|
if FLAGS.lm_binary_path:
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||||
Config.alphabet)
|
Config.alphabet)
|
||||||
|
else:
|
||||||
|
scorer = None
|
||||||
|
|
||||||
test_csvs = FLAGS.test_files.split(',')
|
test_csvs = FLAGS.test_files.split(',')
|
||||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||||
|
|
|
@ -143,14 +143,6 @@ def create_flags():
|
||||||
|
|
||||||
# Register validators for paths which require a file to be specified
|
# Register validators for paths which require a file to be specified
|
||||||
|
|
||||||
f.register_validator('lm_binary_path',
|
|
||||||
os.path.isfile,
|
|
||||||
message='The file pointed to by --lm_binary_path must exist and be readable.')
|
|
||||||
|
|
||||||
f.register_validator('lm_trie_path',
|
|
||||||
os.path.isfile,
|
|
||||||
message='The file pointed to by --lm_trie_path must exist and be readable.')
|
|
||||||
|
|
||||||
f.register_validator('alphabet_config_path',
|
f.register_validator('alphabet_config_path',
|
||||||
os.path.isfile,
|
os.path.isfile,
|
||||||
message='The file pointed to by --alphabet_config_path must exist and be readable.')
|
message='The file pointed to by --alphabet_config_path must exist and be readable.')
|
||||||
|
|
Loading…
Reference in New Issue