diff --git a/DeepSpeech.py b/DeepSpeech.py index e128ac42..4a09383b 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -1785,8 +1785,8 @@ def train(server=None): sys.exit(1) def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False): - # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context] - input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, n_input + 2*n_input*n_context], name='input_node') + # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] + input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node') seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths') previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)