Merge pull request #3395 from CatalinVoss/patch-1
Minor Training Variable Consistency fix
This commit is contained in:
commit
b72e2643c4
|
@ -730,7 +730,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||||
logits = tf.squeeze(logits, [1])
|
logits = tf.squeeze(logits, [1])
|
||||||
|
|
||||||
# Apply softmax for CTC decoder
|
# Apply softmax for CTC decoder
|
||||||
logits = tf.nn.softmax(logits, name='logits')
|
probs = tf.nn.softmax(logits, name='logits')
|
||||||
|
|
||||||
if batch_size <= 0:
|
if batch_size <= 0:
|
||||||
if tflite:
|
if tflite:
|
||||||
|
@ -743,7 +743,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||||
'input_lengths': seq_length,
|
'input_lengths': seq_length,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'outputs': logits,
|
'outputs': probs,
|
||||||
},
|
},
|
||||||
layers
|
layers
|
||||||
)
|
)
|
||||||
|
@ -763,7 +763,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||||
inputs['input_lengths'] = seq_length
|
inputs['input_lengths'] = seq_length
|
||||||
|
|
||||||
outputs = {
|
outputs = {
|
||||||
'outputs': logits,
|
'outputs': probs,
|
||||||
'new_state_c': new_state_c,
|
'new_state_c': new_state_c,
|
||||||
'new_state_h': new_state_h,
|
'new_state_h': new_state_h,
|
||||||
'mfccs': mfccs,
|
'mfccs': mfccs,
|
||||||
|
@ -900,21 +900,21 @@ def do_single_file_inference(input_file_path):
|
||||||
features = create_overlapping_windows(features).eval(session=session)
|
features = create_overlapping_windows(features).eval(session=session)
|
||||||
features_len = features_len.eval(session=session)
|
features_len = features_len.eval(session=session)
|
||||||
|
|
||||||
logits = outputs['outputs'].eval(feed_dict={
|
probs = outputs['outputs'].eval(feed_dict={
|
||||||
inputs['input']: features,
|
inputs['input']: features,
|
||||||
inputs['input_lengths']: features_len,
|
inputs['input_lengths']: features_len,
|
||||||
inputs['previous_state_c']: previous_state_c,
|
inputs['previous_state_c']: previous_state_c,
|
||||||
inputs['previous_state_h']: previous_state_h,
|
inputs['previous_state_h']: previous_state_h,
|
||||||
}, session=session)
|
}, session=session)
|
||||||
|
|
||||||
logits = np.squeeze(logits)
|
probs = np.squeeze(probs)
|
||||||
|
|
||||||
if FLAGS.scorer_path:
|
if FLAGS.scorer_path:
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||||
FLAGS.scorer_path, Config.alphabet)
|
FLAGS.scorer_path, Config.alphabet)
|
||||||
else:
|
else:
|
||||||
scorer = None
|
scorer = None
|
||||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width,
|
||||||
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||||
cutoff_top_n=FLAGS.cutoff_top_n)
|
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||||
# Print highest probability result
|
# Print highest probability result
|
||||||
|
|
Loading…
Reference in New Issue