Merge pull request #3395 from CatalinVoss/patch-1

Minor Training Variable Consistency fix
This commit is contained in:
Reuben Morais 2020-11-03 21:50:59 +01:00 committed by GitHub
commit b72e2643c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 6 deletions

View File

@ -730,7 +730,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
logits = tf.squeeze(logits, [1]) logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder # Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits') probs = tf.nn.softmax(logits, name='logits')
if batch_size <= 0: if batch_size <= 0:
if tflite: if tflite:
@ -743,7 +743,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'input_lengths': seq_length, 'input_lengths': seq_length,
}, },
{ {
'outputs': logits, 'outputs': probs,
}, },
layers layers
) )
@ -763,7 +763,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
inputs['input_lengths'] = seq_length inputs['input_lengths'] = seq_length
outputs = { outputs = {
'outputs': logits, 'outputs': probs,
'new_state_c': new_state_c, 'new_state_c': new_state_c,
'new_state_h': new_state_h, 'new_state_h': new_state_h,
'mfccs': mfccs, 'mfccs': mfccs,
@ -900,21 +900,21 @@ def do_single_file_inference(input_file_path):
features = create_overlapping_windows(features).eval(session=session) features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session) features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={ probs = outputs['outputs'].eval(feed_dict={
inputs['input']: features, inputs['input']: features,
inputs['input_lengths']: features_len, inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c, inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h, inputs['previous_state_h']: previous_state_h,
}, session=session) }, session=session)
logits = np.squeeze(logits) probs = np.squeeze(probs)
if FLAGS.scorer_path: if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet) FLAGS.scorer_path, Config.alphabet)
else: else:
scorer = None scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob, scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n) cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result # Print highest probability result