Merge pull request #2435 from mozilla/uplift-utf8-fixes

Uplift general fixes from UTF-8 work
Authored by Reuben Morais on 2019-10-25 09:09:48 +00:00; committed by GitHub
commit 44a605c8b7
4 changed files with 43 additions and 29 deletions

View File

@@ -9,6 +9,7 @@ LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
import absl.app
import json
import numpy as np
import progressbar
import shutil
@@ -393,15 +394,18 @@ def log_grads_and_vars(grads_and_vars):
log_variable(variable, gradient=gradient)
def try_loading(session, saver, checkpoint_filename, caption):
def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
try:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
if not checkpoint:
return False
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
restored_step = session.run(tfv1.train.get_global_step())
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
if load_step:
restored_step = session.run(tfv1.train.get_global_step())
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
else:
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
return True
except tf.errors.InvalidArgumentError as e:
log_error(str(e))
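
The new load_step flag lets callers restore a checkpoint without reading the global step, which is useful for graphs that never define one (for example a graph built only for inference). A usage sketch, assuming a session and a matching saver from the surrounding code:

loaded = try_loading(session, saver, 'checkpoint', 'most recent')  # logs the restored step
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation',
                     load_step=False)  # graph defines no global step to report
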
@@ -484,11 +488,9 @@ def train():
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
checkpoint_filename = 'checkpoint'
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
best_dev_filename = 'best_dev_checkpoint'
# Save flags next to checkpoints
os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
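
Dropping the checkpoint_filename and best_dev_filename locals works because the same literal names are passed both when writing (Saver.save(..., latest_filename=...)) and when reading (tf.train.get_checkpoint_state(dir, filename)). A minimal, self-contained sketch of that pairing; the directory and variable are placeholders:

import os
import tensorflow.compat.v1 as tfv1  # same tfv1 alias as the surrounding code
tfv1.disable_eager_execution()

ckpt_dir = '/tmp/ckpts'  # placeholder directory
os.makedirs(ckpt_dir, exist_ok=True)
v = tfv1.get_variable('v', initializer=0.0)
saver = tfv1.train.Saver(max_to_keep=1)
with tfv1.Session() as session:
    session.run(tfv1.global_variables_initializer())
    # Writing: the checkpoint-state file inside ckpt_dir is named 'best_dev_checkpoint'
    saver.save(session, os.path.join(ckpt_dir, 'best_dev'), global_step=0,
               latest_filename='best_dev_checkpoint')
    # Reading: look the state up again under the same name
    state = tfv1.train.get_checkpoint_state(ckpt_dir, 'best_dev_checkpoint')
    print(state.model_checkpoint_path)
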
@@ -514,7 +516,7 @@ def train():
'a CPU-capable graph. If your system is capable of '
'using CuDNN RNN, you can just specify the CuDNN RNN '
'checkpoint normally with --checkpoint_dir.')
exit(1)
sys.exit(1)
log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
@@ -532,7 +534,7 @@ def train():
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors.')
exit(1)
sys.exit(1)
# Initialize Adam moment tensors from scratch to allow use of CuDNN
# RNN checkpoints.
@@ -544,9 +546,9 @@ def train():
tfv1.get_default_graph().finalize()
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
if not loaded:
if FLAGS.load in ['auto', 'init']:
log_info('Initializing variables...')
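
The --load flag drives this fallback chain: 'last' and 'best' each try a single checkpoint, while 'auto' tries the most recent checkpoint, then the best-validation one, and finally falls through to fresh initialization. A hypothetical condensation of the same pattern (not in the tree), shown only to make the ordering explicit:

def load_or_init(session, candidates):
    # candidates: ordered list of (saver, checkpoint_filename, caption) tuples
    for saver, filename, caption in candidates:
        if try_loading(session, saver, filename, caption):
            return True
    return False  # caller then initializes variables from scratch
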
@@ -643,7 +645,7 @@ def train():
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
@@ -668,6 +670,9 @@ def train():
def test():
evaluate(FLAGS.test_files.split(','), create_model, try_loading)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
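
default=float is what makes the dump work: the decoded samples carry NumPy scalars (for example np.float32 losses), which the json module refuses to serialize, and default is invoked for any object it cannot handle. A small standalone illustration with made-up sample data:

import json
import numpy as np

sample = {'wav_filename': 'audio.wav', 'loss': np.float32(1.25)}
# json.dumps(sample) raises TypeError: Object of type float32 is not JSON serializable
print(json.dumps(sample, default=float))  # {"wav_filename": "audio.wav", "loss": 1.25}
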
@@ -859,15 +864,14 @@ def do_single_file_inference(input_file_path):
saver = tfv1.train.Saver()
# Restore variables from training checkpoint
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
# over-fitting, we may want to restore an earlier checkpoint.
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if not checkpoint:
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
exit(1)
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
loaded = False
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
if not loaded:
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
sys.exit(1)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])

View File

@@ -85,12 +85,14 @@ def evaluate(test_csvs, create_model, try_loading):
with tfv1.Session(config=Config.session_config) as session:
# Restore variables from training checkpoint
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
if not loaded:
loaded = False
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, saver, 'checkpoint', 'most recent')
if not loaded:
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
exit(1)
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
sys.exit(1)
def run_test(init_op, dataset):
wav_filenames = []
@@ -160,7 +162,7 @@ def main(_):
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
exit(1)
sys.exit(1)
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)

View File

@@ -60,8 +60,10 @@ DecoderState::next(const double *probs,
bool full_beam = false;
if (ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
std::sort(
prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
std::partial_sort(prefixes_.begin(),
prefixes_.begin() + num_prefixes,
prefixes_.end(),
prefix_compare);
min_cutoff = prefixes_[num_prefixes - 1]->score +
std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
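
The switch to std::partial_sort is not just cosmetic: std::sort(begin, begin + num_prefixes, cmp) only orders the first num_prefixes elements and never looks at candidates beyond that position, whereas std::partial_sort(begin, begin + num_prefixes, end, cmp) selects the best num_prefixes entries from the whole range and leaves the rest unordered, which is all the beam needs. A rough Python analogue of that intent using heapq, with toy scores standing in for prefix_compare:

import heapq

# Toy beam: (score, prefix) pairs, higher score = better candidate.
prefixes = [(-3.2, 'a'), (-1.1, 'ab'), (-7.4, 'b'), (-0.9, 'ba'), (-2.5, 'bb')]
beam_size = 3
# Like partial_sort with a "higher score first" comparator: the best beam_size
# entries drawn from the WHOLE list, returned in order (here as a new list
# rather than reordered in place).
best = heapq.nlargest(beam_size, prefixes, key=lambda p: p[0])
print(best)  # [(-0.9, 'ba'), (-1.1, 'ab'), (-2.5, 'bb')]
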
@@ -169,7 +171,8 @@ DecoderState::decode() const
if (!prefix->is_empty() && prefix->character != space_id_) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
score += ext_scorer_->beta;
scores[prefix] += score;
}
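
Passing bos when the n-gram is shorter than the scorer's maximum order appears to tell the language model that the prefix starts at the beginning of the sentence, so it is scored against the <s> context rather than as a fragment with missing history. The effect can be seen with the KenLM Python bindings, which back the native scorer; the model path is a placeholder:

import kenlm

lm = kenlm.Model('lm.binary')  # placeholder path to a KenLM model
# Scoring a short history as sentence-initial vs. as a fragment with unknown
# preceding context generally yields different log10 probabilities:
print(lm.score('this is', bos=True, eos=False))
print(lm.score('this is', bos=False, eos=False))
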
@@ -178,7 +181,10 @@ DecoderState::decode() const
using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
std::partial_sort(prefixes_copy.begin(),
prefixes_copy.begin() + num_prefixes,
prefixes_copy.end(),
std::bind(prefix_compare_external, _1, _2, scores));
//TODO: expose this as an API parameter
const int top_paths = 1;

View File

@@ -296,7 +296,9 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool ad
fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie
for (const auto& word : vocabulary) {
add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
if (word != START_TOKEN && word != UNK_TOKEN && word != END_TOKEN) {
add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
}
}
/* Simplify FST
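
Skipping START_TOKEN, UNK_TOKEN and END_TOKEN keeps the LM's internal markers (typically <s>, <unk> and </s>) out of the decoding trie, since they can never appear in a transcript. A hypothetical Python mirror of the loop, with add_word standing in for add_word_to_dictionary:

START_TOKEN, UNK_TOKEN, END_TOKEN = '<s>', '<unk>', '</s>'  # assumed marker values

def fill_dictionary(vocabulary, add_word):
    for word in vocabulary:
        # Only real vocabulary entries go into the dictionary FST
        if word not in (START_TOKEN, UNK_TOKEN, END_TOKEN):
            add_word(word)
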