diff --git a/DeepSpeech.py b/DeepSpeech.py
index ba43d385..e57267ff 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -9,6 +9,7 @@ LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
 
 import absl.app
+import json
 import numpy as np
 import progressbar
 import shutil
@@ -393,15 +394,18 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)
 
 
-def try_loading(session, saver, checkpoint_filename, caption):
+def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
         if not checkpoint:
             return False
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-        restored_step = session.run(tfv1.train.get_global_step())
-        log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
+        if load_step:
+            restored_step = session.run(tfv1.train.get_global_step())
+            log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
+        else:
+            log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
     except tf.errors.InvalidArgumentError as e:
         log_error(str(e))
@@ -484,11 +488,9 @@ def train():
     # Checkpointing
     checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
     checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
-    checkpoint_filename = 'checkpoint'
 
     best_dev_saver = tfv1.train.Saver(max_to_keep=1)
     best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
-    best_dev_filename = 'best_dev_checkpoint'
 
     # Save flags next to checkpoints
     os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
@@ -514,7 +516,7 @@ def train():
                       'a CPU-capable graph. If your system is capable of '
                       'using CuDNN RNN, you can just specify the CuDNN RNN '
                       'checkpoint normally with --checkpoint_dir.')
-            exit(1)
+            sys.exit(1)
 
         log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
         ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
@@ -532,7 +534,7 @@ def train():
             log_error('Tried to load a CuDNN RNN checkpoint but there were '
                       'more missing variables than just the Adam moment '
                       'tensors.')
-            exit(1)
+            sys.exit(1)
 
         # Initialize Adam moment tensors from scratch to allow use of CuDNN
         # RNN checkpoints.
@@ -544,9 +546,9 @@ def train():
         tfv1.get_default_graph().finalize()
 
         if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
+            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
         if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
+            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
         if not loaded:
             if FLAGS.load in ['auto', 'init']:
                 log_info('Initializing variables...')
@@ -643,7 +645,7 @@ def train():
 
                 if dev_loss < best_dev_loss:
                     best_dev_loss = dev_loss
-                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
+                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                     log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
 
                 # Early stopping
@@ -668,6 +670,9 @@ def train():
 
 def test():
-    evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    if FLAGS.test_output_file:
+        # Save decoded tuples as JSON, converting NumPy floats to Python floats
+        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
 
 
 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
@@ -859,15 +864,14 @@ def do_single_file_inference(input_file_path):
         saver = tfv1.train.Saver()
 
         # Restore variables from training checkpoint
-        # TODO: This restores the most recent checkpoint, but if we use validation to counteract
-        # over-fitting, we may want to restore an earlier checkpoint.
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-        if not checkpoint:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
-            exit(1)
-
-        checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
+        loaded = False
+        if not loaded and FLAGS.load in ['auto', 'last']:
+            loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
+        if not loaded and FLAGS.load in ['auto', 'best']:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
+        if not loaded:
+            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
+            sys.exit(1)
 
         features, features_len = audiofile_to_features(input_file_path)
         previous_state_c = np.zeros([1, Config.n_cell_dim])
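The `default=float` argument in the new `test()` code is what makes the dump work at all: the test epoch yields NumPy scalars (e.g. `np.float32` losses), which the stdlib `json` encoder rejects. A minimal sketch of the mechanism, with hypothetical sample data standing in for the real decoded tuples:

```python
import json
import numpy as np

# Hypothetical stand-in for the samples returned by evaluate();
# the fields here are illustrative, not DeepSpeech's actual schema.
samples = [{'wav_filename': 'a.wav', 'loss': np.float32(12.5), 'res': 'hello'}]

# Without default=, json.dump raises:
#   TypeError: Object of type float32 is not JSON serializable
# default= is called for every object the encoder cannot handle, so
# float() quietly coerces NumPy scalars to plain Python floats.
with open('test_output.json', 'w') as fout:
    json.dump(samples, fout, default=float)
```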
diff --git a/evaluate.py b/evaluate.py
index 679d3fe2..435a6be8 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -85,12 +85,14 @@ def evaluate(test_csvs, create_model, try_loading):
 
     with tfv1.Session(config=Config.session_config) as session:
         # Restore variables from training checkpoint
-        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded:
+        loaded = False
+        if not loaded and FLAGS.load in ['auto', 'best']:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
+        if not loaded and FLAGS.load in ['auto', 'last']:
             loaded = try_loading(session, saver, 'checkpoint', 'most recent')
         if not loaded:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
-            exit(1)
+            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
+            sys.exit(1)
 
         def run_test(init_op, dataset):
             wav_filenames = []
@@ -160,7 +162,7 @@ def main(_):
     if not FLAGS.test_files:
         log_error('You need to specify what files to use for evaluation via '
                   'the --test_files flag.')
-        exit(1)
+        sys.exit(1)
 
     from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
     samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
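The fallback pattern works because two independent "checkpoint state" index files can coexist in one directory: `train()` writes `checkpoint` via `checkpoint_saver` and `best_dev_checkpoint` via `best_dev_saver` (the `latest_filename` argument to `Saver.save`), and `tf.train.get_checkpoint_state` selects one by name. A small sketch of that selection, assuming a hypothetical directory `/tmp/deepspeech_ckpt` produced by `train()`:

```python
import tensorflow.compat.v1 as tfv1

# Each state file points at a different saved model in the same directory:
#   checkpoint           -> the most recent training checkpoint
#   best_dev_checkpoint  -> the best-validation checkpoint
for state_file in ['checkpoint', 'best_dev_checkpoint']:
    ckpt = tfv1.train.get_checkpoint_state('/tmp/deepspeech_ckpt', state_file)
    if ckpt:
        # The path a Saver would restore from for this state file
        print(state_file, '->', ckpt.model_checkpoint_path)
```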
diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
index 64d0e338..3c481002 100644
--- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp
+++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
@@ -60,8 +60,10 @@ DecoderState::next(const double *probs,
     bool full_beam = false;
     if (ext_scorer_ != nullptr) {
       size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
-      std::sort(
-          prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
+      std::partial_sort(prefixes_.begin(),
+                        prefixes_.begin() + num_prefixes,
+                        prefixes_.end(),
+                        prefix_compare);
 
       min_cutoff = prefixes_[num_prefixes - 1]->score +
                    std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
@@ -169,7 +171,8 @@ DecoderState::decode() const
       if (!prefix->is_empty() && prefix->character != space_id_) {
         float score = 0.0;
         std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
-        score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
+        bool bos = ngram.size() < ext_scorer_->get_max_order();
+        score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
         score += ext_scorer_->beta;
         scores[prefix] += score;
       }
@@ -178,7 +181,10 @@ DecoderState::decode() const
 
   using namespace std::placeholders;
   size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
-  std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
+  std::partial_sort(prefixes_copy.begin(),
+                    prefixes_copy.begin() + num_prefixes,
+                    prefixes_copy.end(),
+                    std::bind(prefix_compare_external, _1, _2, scores));
 
   //TODO: expose this as an API parameter
   const int top_paths = 1;
diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp
index 49a6d794..c265b430 100644
--- a/native_client/ctcdecode/scorer.cpp
+++ b/native_client/ctcdecode/scorer.cpp
@@ -296,7 +296,9 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool ad
   fst::StdVectorFst dictionary;
   // For each unigram convert to ints and put in trie
   for (const auto& word : vocabulary) {
-    add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+    if (word != START_TOKEN && word != UNK_TOKEN && word != END_TOKEN) {
+      add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+    }
   }
 
   /* Simplify FST
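The `std::sort` to `std::partial_sort` swaps above are a correctness fix, not just a cleanup: `std::sort(first, first + n)` orders only the first n elements and never inspects the rest of the range, so a well-scoring prefix sitting past index n is silently excluded from the beam. `std::partial_sort(first, middle, last)` instead selects the best n elements of the whole range and leaves them sorted at the front. A Python analogy of the two selection semantics (`heapq.nsmallest` plays the role of `partial_sort` here; the scores are made-up stand-ins for prefix scores, with lower treated as better):

```python
import heapq

scores = [0.7, 0.1, 0.9, 0.05, 0.3]   # hypothetical prefix scores
n = 2

# What sorting only the first n elements effectively did: the candidate
# 0.05 at index 3 is never considered.
buggy = sorted(scores[:n])            # [0.1, 0.7]

# What partial_sort over the whole range does: the true best n
# candidates, in sorted order.
fixed = heapq.nsmallest(n, scores)    # [0.05, 0.1]
```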