From 4302a5f7677d583d98d2119c8e3068d30cff3406 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 30 Sep 2019 11:43:00 +0200 Subject: [PATCH] Make language model scoring optional in Python inference code --- DeepSpeech.py | 9 ++++++--- evaluate.py | 9 ++++++--- util/flags.py | 8 -------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 8ba9d02f..a2dd045a 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -870,9 +870,12 @@ def do_single_file_inference(input_file_path): logits = np.squeeze(logits) - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, - FLAGS.lm_binary_path, FLAGS.lm_trie_path, - Config.alphabet) + if FLAGS.lm_binary_path: + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, + FLAGS.lm_binary_path, FLAGS.lm_trie_path, + Config.alphabet) + else: + scorer = None decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, scorer=scorer) # Print highest probability result print(decoded[0][1]) diff --git a/evaluate.py b/evaluate.py index 32c45367..4cda4523 100755 --- a/evaluate.py +++ b/evaluate.py @@ -42,9 +42,12 @@ def sparse_tuple_to_texts(sp_tuple, alphabet): def evaluate(test_csvs, create_model, try_loading): - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, - FLAGS.lm_binary_path, FLAGS.lm_trie_path, - Config.alphabet) + if FLAGS.lm_binary_path: + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, + FLAGS.lm_binary_path, FLAGS.lm_trie_path, + Config.alphabet) + else: + scorer = None test_csvs = FLAGS.test_files.split(',') test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs] diff --git a/util/flags.py b/util/flags.py index fd86fa6c..80251cb0 100644 --- a/util/flags.py +++ b/util/flags.py @@ -143,14 +143,6 @@ def create_flags(): # Register validators for paths which require a file to be specified - f.register_validator('lm_binary_path', - os.path.isfile, - message='The file pointed to by --lm_binary_path must exist and be readable.') - - f.register_validator('lm_trie_path', - os.path.isfile, - message='The file pointed to by --lm_trie_path must exist and be readable.') - f.register_validator('alphabet_config_path', os.path.isfile, message='The file pointed to by --alphabet_config_path must exist and be readable.')