Fix export graph
This commit is contained in:
parent
25a9e76afc
commit
778f5deb9d
@ -1785,8 +1785,8 @@ def train(server=None):
|
||||
sys.exit(1)
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
|
||||
# Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, n_input + 2*n_input*n_context], name='input_node')
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
|
||||
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user