Merge pull request #2435 from mozilla/uplift-utf8-fixes

Uplift general fixes from UTF-8 work

commit 44a605c8b7
DeepSpeech.py

@@ -9,6 +9,7 @@ LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'

 import absl.app
+import json
 import numpy as np
 import progressbar
 import shutil
@@ -393,15 +394,18 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)


-def try_loading(session, saver, checkpoint_filename, caption):
+def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
         if not checkpoint:
             return False
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-        restored_step = session.run(tfv1.train.get_global_step())
-        log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
+        if load_step:
+            restored_step = session.run(tfv1.train.get_global_step())
+            log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
+        else:
+            log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
     except tf.errors.InvalidArgumentError as e:
         log_error(str(e))
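The new load_step flag defaults to True, so existing training callers keep the step-logging behavior; inference callers (see the do_single_file_inference hunk below) pass load_step=False because a pure inference graph never creates a global step to read. A minimal sketch of the guard, assuming the tfv1 alias used in this file:

```python
import tensorflow.compat.v1 as tfv1

# Sketch: in a graph that never created a global step,
# tfv1.train.get_global_step() returns None, and session.run(None) raises,
# so the step read must be skipped when restoring inference-only graphs.
def read_step_if_present(session):
    step_tensor = tfv1.train.get_global_step()
    return session.run(step_tensor) if step_tensor is not None else None
```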
@@ -484,11 +488,9 @@ def train():
     # Checkpointing
     checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
     checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
-    checkpoint_filename = 'checkpoint'

     best_dev_saver = tfv1.train.Saver(max_to_keep=1)
     best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
-    best_dev_filename = 'best_dev_checkpoint'

     # Save flags next to checkpoints
     os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
@@ -514,7 +516,7 @@ def train():
                           'a CPU-capable graph. If your system is capable of '
                           'using CuDNN RNN, you can just specify the CuDNN RNN '
                           'checkpoint normally with --checkpoint_dir.')
-                exit(1)
+                sys.exit(1)

            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
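Replacing the bare exit(1) calls with sys.exit(1), here and in the hunks that follow, is more than a style fix: exit() is injected by the site module for interactive use and is not guaranteed to exist (for example under python -S or in an embedded interpreter), while sys.exit() is the documented API and raises a catchable SystemExit. A minimal illustration:

```python
import sys

# sys.exit raises SystemExit with the given status code; unlike the
# site-module helper exit(), it is always available once sys is imported.
def main():
    sys.exit(1)

if __name__ == '__main__':
    main()
```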
@@ -532,7 +534,7 @@ def train():
                log_error('Tried to load a CuDNN RNN checkpoint but there were '
                          'more missing variables than just the Adam moment '
                          'tensors.')
-                exit(1)
+                sys.exit(1)

            # Initialize Adam moment tensors from scratch to allow use of CuDNN
            # RNN checkpoints.
@@ -544,9 +546,9 @@ def train():
        tfv1.get_default_graph().finalize()

        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
+            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
+            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
        if not loaded:
            if FLAGS.load in ['auto', 'init']:
                log_info('Initializing variables...')
@@ -643,7 +645,7 @@ def train():

                if dev_loss < best_dev_loss:
                    best_dev_loss = dev_loss
-                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
+                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                    log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                # Early stopping
@@ -668,6 +670,9 @@ def train():

 def test():
-    evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    if FLAGS.test_output_file:
+        # Save decoded tuples as JSON, converting NumPy floats to Python floats
+        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
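The new --test_output_file path leans on json's default= hook: per-sample losses come back as NumPy scalars, which the stdlib encoder rejects, and default=float coerces them on the way out. A small standalone check:

```python
import json
import numpy as np

# json.dumps calls the default= hook for any value it cannot serialize;
# float(np.float32(...)) turns the NumPy scalar into a plain Python float.
sample = {'wav_filename': 'sample.wav', 'loss': np.float32(1.25)}
print(json.dumps(sample, default=float))
# {"wav_filename": "sample.wav", "loss": 1.25}
```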
@@ -859,15 +864,14 @@ def do_single_file_inference(input_file_path):
        saver = tfv1.train.Saver()

        # Restore variables from training checkpoint
-        # TODO: This restores the most recent checkpoint, but if we use validation to counteract
-        # over-fitting, we may want to restore an earlier checkpoint.
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-        if not checkpoint:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
-            exit(1)
-
-        checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
+        loaded = False
+        if not loaded and FLAGS.load in ['auto', 'last']:
+            loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
+        if not loaded and FLAGS.load in ['auto', 'best']:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
+        if not loaded:
+            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
+            sys.exit(1)

        features, features_len = audiofile_to_features(input_file_path)
        previous_state_c = np.zeros([1, Config.n_cell_dim])
evaluate.py (12 changed lines)

@@ -85,12 +85,14 @@ def evaluate(test_csvs, create_model, try_loading):

    with tfv1.Session(config=Config.session_config) as session:
        # Restore variables from training checkpoint
-        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded:
+        loaded = False
+        if not loaded and FLAGS.load in ['auto', 'best']:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
+        if not loaded and FLAGS.load in ['auto', 'last']:
             loaded = try_loading(session, saver, 'checkpoint', 'most recent')
         if not loaded:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
-            exit(1)
+            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
+            sys.exit(1)

    def run_test(init_op, dataset):
        wav_filenames = []
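Evaluation, training, and single-file inference now share the same --load cascade and differ only in priority: evaluation prefers the best-validation checkpoint, while the training paths prefer the most recent one. A compact sketch of the shared pattern; try_loading here stands in for the function defined in DeepSpeech.py:

```python
# Sketch of the --load cascade; 'order' is the per-caller priority and
# try_loading is assumed to be the helper from DeepSpeech.py.
def load_checkpoint(session, saver, load_flag, order=('best', 'last')):
    names = {'best': ('best_dev_checkpoint', 'best validation'),
             'last': ('checkpoint', 'most recent')}
    for kind in order:
        if load_flag in ('auto', kind) and try_loading(session, saver, *names[kind]):
            return True
    return False
```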
@@ -160,7 +162,7 @@ def main(_):
    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
-        exit(1)
+        sys.exit(1)

    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
native_client/ctcdecode/ctc_beam_search_decoder.cpp

@@ -60,8 +60,10 @@ DecoderState::next(const double *probs,
    bool full_beam = false;
    if (ext_scorer_ != nullptr) {
      size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
-      std::sort(
-          prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
+      std::partial_sort(prefixes_.begin(),
+                        prefixes_.begin() + num_prefixes,
+                        prefixes_.end(),
+                        prefix_compare);

      min_cutoff = prefixes_[num_prefixes - 1]->score +
                   std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
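This fixes a real selection bug: std::sort over [begin, begin + num_prefixes) only orders the first num_prefixes elements among themselves, so whenever the prefix list was longer than the beam, better-scoring prefixes past that boundary could never enter the beam. std::partial_sort selects the top num_prefixes from the entire range (and does less work than a full sort); the same fix is applied in DecoderState::decode() below. A Python analogue of the bug and the fix:

```python
import heapq

# Analogue only: sorting just the first k entries (the old sub-range
# std::sort) never looks at later candidates, while a top-k selection over
# the whole list (what std::partial_sort does) keeps the true best k.
scores = [0.9, 0.1, 0.8, 0.05, 0.3]
k = 2
buggy = sorted(scores[:k])          # [0.1, 0.9] -- 0.05 and 0.3 never seen
fixed = heapq.nsmallest(k, scores)  # [0.05, 0.1]
print(buggy, fixed)
```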
@@ -169,7 +171,8 @@ DecoderState::decode() const
      if (!prefix->is_empty() && prefix->character != space_id_) {
        float score = 0.0;
        std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
-        score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
+        bool bos = ngram.size() < ext_scorer_->get_max_order();
+        score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
        score += ext_scorer_->beta;
        scores[prefix] += score;
      }
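When make_ngram returns fewer words than the model order, the prefix sits at the start of the utterance, so the query should be conditioned on the begin-of-sentence context instead of a silently truncated history. A sketch of the decision in Python; lm and its log_cond_prob(ngram, bos) stand in for the C++ scorer:

```python
# Sketch: an n-gram shorter than the LM's max order means there is no more
# history to condition on, so score it as sentence-initial (bos=True).
# 'lm.log_cond_prob' mirrors the Scorer::get_log_cond_prob call here.
def prefix_lm_score(ngram, max_order, lm, alpha, beta):
    bos = len(ngram) < max_order
    return lm.log_cond_prob(ngram, bos) * alpha + beta
```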
@@ -178,7 +181,10 @@ DecoderState::decode() const

  using namespace std::placeholders;
  size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
-  std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
+  std::partial_sort(prefixes_copy.begin(),
+                    prefixes_copy.begin() + num_prefixes,
+                    prefixes_copy.end(),
+                    std::bind(prefix_compare_external, _1, _2, scores));

  //TODO: expose this as an API parameter
  const int top_paths = 1;
native_client/ctcdecode/scorer.cpp

@@ -296,7 +296,9 @@ void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary, bool add_space)
  fst::StdVectorFst dictionary;
  // For each unigram convert to ints and put in trie
  for (const auto& word : vocabulary) {
-    add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+    if (word != START_TOKEN && word != UNK_TOKEN && word != END_TOKEN) {
+      add_word_to_dictionary(word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
+    }
  }

  /* Simplify FST
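fill_dictionary builds the decoder trie from the LM vocabulary, which includes the language model's reserved tokens; without the guard they would become decodable "words". Assuming the conventional KenLM values for these constants, the filter amounts to:

```python
# Python analogue of the guard. The token strings are the conventional KenLM
# values and are an assumption here, not read from the C++ headers.
START_TOKEN, UNK_TOKEN, END_TOKEN = '<s>', '<unk>', '</s>'

def dictionary_words(vocabulary):
    """Keep only real words; LM-internal tokens must not be decodable."""
    return [w for w in vocabulary
            if w not in (START_TOKEN, UNK_TOKEN, END_TOKEN)]
```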