Merge pull request #2435 from mozilla/uplift-utf8-fixes
Uplift general fixes from UTF-8 work
commit 44a605c8b7

DeepSpeech.py
@@ -9,6 +9,7 @@ LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'

 import absl.app
+import json
 import numpy as np
 import progressbar
 import shutil
@@ -393,15 +394,18 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)


-def try_loading(session, saver, checkpoint_filename, caption):
+def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
         if not checkpoint:
             return False
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-        restored_step = session.run(tfv1.train.get_global_step())
-        log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
+        if load_step:
+            restored_step = session.run(tfv1.train.get_global_step())
+            log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
+        else:
+            log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
     except tf.errors.InvalidArgumentError as e:
         log_error(str(e))
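A note on the new load_step flag: the training graph defines a global step tensor, but the inference graph built by create_inference_graph() does not appear to, which is presumably why the inference path further down passes load_step=False to skip the session.run(tfv1.train.get_global_step()) fetch. A minimal usage sketch, with session and saver assumed from the surrounding code:

    # Sketch only: restore the most recent training checkpoint into an
    # inference session without touching the (absent) global step tensor.
    loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
    if not loaded:
        sys.exit(1)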
@@ -484,11 +488,9 @@ def train():
     # Checkpointing
     checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
     checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
-    checkpoint_filename = 'checkpoint'

     best_dev_saver = tfv1.train.Saver(max_to_keep=1)
     best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
-    best_dev_filename = 'best_dev_checkpoint'

     # Save flags next to checkpoints
     os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
@@ -514,7 +516,7 @@ def train():
                       'a CPU-capable graph. If your system is capable of '
                       'using CuDNN RNN, you can just specify the CuDNN RNN '
                       'checkpoint normally with --checkpoint_dir.')
-            exit(1)
+            sys.exit(1)

         log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
         ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
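On the exit(1) to sys.exit(1) changes repeated through this diff: the bare exit() builtin is injected by the site module for interactive use and is not guaranteed to exist (for example under python -S or in some frozen builds), while sys.exit() is always available and simply raises SystemExit. A runnable illustration:

    import sys

    try:
        sys.exit(3)
    except SystemExit as e:
        # sys.exit() just raises SystemExit; the exit code rides on the exception.
        print('exit code would be', e.code)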
@@ -532,7 +534,7 @@ def train():
             log_error('Tried to load a CuDNN RNN checkpoint but there were '
                       'more missing variables than just the Adam moment '
                       'tensors.')
-            exit(1)
+            sys.exit(1)

         # Initialize Adam moment tensors from scratch to allow use of CuDNN
         # RNN checkpoints.
@@ -544,9 +546,9 @@ def train():
         tfv1.get_default_graph().finalize()

         if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
+            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
         if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
+            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
         if not loaded:
             if FLAGS.load in ['auto', 'init']:
                 log_info('Initializing variables...')
@@ -643,7 +645,7 @@ def train():

                 if dev_loss < best_dev_loss:
                     best_dev_loss = dev_loss
-                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
+                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                     log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                 # Early stopping
@@ -668,6 +670,9 @@ def train():

 def test():
-    evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    if FLAGS.test_output_file:
+        # Save decoded tuples as JSON, converting NumPy floats to Python floats
+        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
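The comment in test() above is doing real work: the decoded samples carry NumPy float losses, which the json module refuses to serialize, and the default=float hook converts them on the way out. A runnable sketch (the field names here are illustrative, not the exact tuple layout):

    import json
    import numpy as np

    samples = [{'wav_filename': 'a.wav', 'loss': np.float32(1.25)}]
    # json calls default for objects it cannot serialize; float() handles
    # NumPy scalars fine.
    print(json.dumps(samples, default=float))  # [{"wav_filename": "a.wav", "loss": 1.25}]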
@@ -859,15 +864,14 @@ def do_single_file_inference(input_file_path):
         saver = tfv1.train.Saver()

         # Restore variables from training checkpoint
-        # TODO: This restores the most recent checkpoint, but if we use validation to counteract
-        # over-fitting, we may want to restore an earlier checkpoint.
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-        if not checkpoint:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
-            exit(1)
-
-        checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
+        loaded = False
+        if not loaded and FLAGS.load in ['auto', 'last']:
+            loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
+        if not loaded and FLAGS.load in ['auto', 'best']:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
+        if not loaded:
+            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
+            sys.exit(1)

         features, features_len = audiofile_to_features(input_file_path)
         previous_state_c = np.zeros([1, Config.n_cell_dim])
evaluate.py
@@ -85,12 +85,14 @@ def evaluate(test_csvs, create_model, try_loading):

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
-        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded:
-            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
-        if not loaded:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
-            exit(1)
+        loaded = False
+        if not loaded and FLAGS.load in ['auto', 'best']:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
+        if not loaded and FLAGS.load in ['auto', 'last']:
+            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
+        if not loaded:
+            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
+            sys.exit(1)

    def run_test(init_op, dataset):
        wav_filenames = []
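Note the deliberate asymmetry against train() above: evaluation tries the best-validation checkpoint before the most recent one, while training resumes from the most recent first, and both paths now honor FLAGS.load. A sketch of the selection order (candidate_order is an illustrative helper, not something in the codebase):

    def candidate_order(purpose, load='auto'):
        # Training resumes where it left off; evaluation wants the best model.
        order = {'train': ['last', 'best'], 'evaluate': ['best', 'last']}[purpose]
        return [kind for kind in order if load in ('auto', kind)]

    print(candidate_order('evaluate'))        # ['best', 'last']
    print(candidate_order('train', 'last'))   # ['last']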
@@ -160,7 +162,7 @@ def main(_):
    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
-        exit(1)
+        sys.exit(1)

    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
ctc_beam_search_decoder.cpp

@@ -60,8 +60,10 @@ DecoderState::next(const double *probs,
  bool full_beam = false;
  if (ext_scorer_ != nullptr) {
    size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
-    std::sort(
-        prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
+    std::partial_sort(prefixes_.begin(),
+                      prefixes_.begin() + num_prefixes,
+                      prefixes_.end(),
+                      prefix_compare);

    min_cutoff = prefixes_[num_prefixes - 1]->score +
                 std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
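The std::sort call here was a real bug, not a style fix: sorting only [begin, begin + num_prefixes) orders whatever happens to occupy the first num_prefixes slots, while std::partial_sort(begin, begin + num_prefixes, end, cmp) selects the best num_prefixes candidates from the entire range and orders them. A Python sketch of the difference, with plain numeric order standing in for prefix_compare:

    import heapq

    scores = [5, 1, 4, 2, 3]
    k = 3
    # Old behaviour: sort only the first k slots; 2 and 3 are never considered.
    buggy = sorted(scores[:k])          # [1, 4, 5]
    # partial_sort behaviour: best k drawn from the whole range.
    fixed = heapq.nsmallest(k, scores)  # [1, 2, 3]
    print(buggy, fixed)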
@@ -169,7 +171,8 @@ DecoderState::decode() const
    if (!prefix->is_empty() && prefix->character != space_id_) {
      float score = 0.0;
      std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
-      score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
+      bool bos = ngram.size() < ext_scorer_->get_max_order();
+      score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
      score += ext_scorer_->beta;
      scores[prefix] += score;
    }
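My reading of the added bos flag: it marks prefixes whose n-gram context is shorter than the language model order, i.e. prefixes that start at the beginning of the utterance, so the scorer can apply begin-of-sentence context instead of treating the short n-gram as a full-order history. As a sketch, with max_order standing in for ext_scorer_->get_max_order():

    # For a trigram model (max_order = 3), a two-word n-gram can only occur
    # at the start of the sentence, so it should be scored with <s> context.
    def is_bos(ngram, max_order=3):
        return len(ngram) < max_order

    print(is_bos(['hello', 'world']))   # True: begin-of-sentence scoring
    print(is_bos(['a', 'b', 'c']))      # False: full-order history available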
@@ -178,7 +181,10 @@ DecoderState::decode() const

  using namespace std::placeholders;
  size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
-  std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
+  std::partial_sort(prefixes_copy.begin(),
+                    prefixes_copy.begin() + num_prefixes,
+                    prefixes_copy.end(),
+                    std::bind(prefix_compare_external, _1, _2, scores));

  //TODO: expose this as an API parameter
  const int top_paths = 1;
scorer.cpp

@@ -296,7 +296,9 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
  fst::StdVectorFst dictionary;
  // For each unigram convert to ints and put in trie
  for (const auto& word : vocabulary) {
-    add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+    if (word != START_TOKEN && word != UNK_TOKEN && word != END_TOKEN) {
+      add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+    }
  }

  /* Simplify FST
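This guard keeps the language model's pseudo-words out of the decoding trie: vocabularies extracted from an ARPA/KenLM model include sentence markers and an unknown-word token, which are not spellable words and should never be produced by the decoder. Assuming the conventional spellings behind START_TOKEN, UNK_TOKEN and END_TOKEN ("<s>", "<unk>", "</s>"), the change reduces to a vocabulary filter:

    # Sketch only; the real code compares against the constants, not
    # hard-coded strings.
    SPECIAL = {'<s>', '<unk>', '</s>'}
    vocabulary = ['<s>', 'hello', 'world', '<unk>', '</s>']
    words_for_trie = [w for w in vocabulary if w not in SPECIAL]
    print(words_for_trie)  # ['hello', 'world']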