diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index ded45d2d..8bf7a354 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -730,7 +730,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): logits = tf.squeeze(logits, [1]) # Apply softmax for CTC decoder - logits = tf.nn.softmax(logits, name='logits') + probs = tf.nn.softmax(logits, name='logits') if batch_size <= 0: if tflite: @@ -743,7 +743,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): 'input_lengths': seq_length, }, { - 'outputs': logits, + 'outputs': probs, }, layers ) @@ -763,7 +763,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): inputs['input_lengths'] = seq_length outputs = { - 'outputs': logits, + 'outputs': probs, 'new_state_c': new_state_c, 'new_state_h': new_state_h, 'mfccs': mfccs,