Fix export graph

This commit is contained in:
Reuben Morais 2018-09-14 11:42:42 -03:00
parent 25a9e76afc
commit 778f5deb9d

View File

@ -1785,8 +1785,8 @@ def train(server=None):
sys.exit(1)
def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
# Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, n_input + 2*n_input*n_context], name='input_node')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)