Merge pull request #1926 from nicolaspanel/I1923

enable dynamic batch size through --export_batch_size -1 (see #1923)

commit eb1c0f9853
@@ -664,18 +664,23 @@ def test():
 
 
 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
+    batch_size = batch_size if batch_size > 0 else None
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
     input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*Config.n_context+1, Config.n_input], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
 
-    if not tflite:
-        previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
-        previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
+    if batch_size <= 0:
+        # no state management since n_step is expected to be dynamic too (see below)
+        previous_state = previous_state_c = previous_state_h = None
     else:
-        previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
-        previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
+        if not tflite:
+            previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
+            previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
+        else:
+            previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+            previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
 
-    previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
+        previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
 
     no_dropout = [0.0] * 6
 
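For context: the hunk above works because TensorFlow 1.x treats a `None` entry in a placeholder shape as dynamic, to be fixed only at feed time, and the new `batch_size = batch_size if batch_size > 0 else None` line maps the CLI convention (-1 means dynamic) onto that placeholder convention. A minimal standalone sketch of the technique, not part of the commit; the 2*9+1 context window and 26 input features mirror DeepSpeech defaults and are assumptions for illustration:

import numpy as np
import tensorflow as tf  # TF 1.x API, as used by DeepSpeech at the time

# A None dimension in a placeholder shape is left dynamic, so a single graph
# accepts feeds with any batch size and any step count.
input_node = tf.placeholder(tf.float32, [None, None, 2*9+1, 26], name='input_node')
mean_per_step = tf.reduce_mean(input_node, axis=[2, 3])  # stand-in for the acoustic model

with tf.Session() as session:
    for batch_size, n_steps in [(1, 16), (8, 250)]:
        features = np.zeros((batch_size, n_steps, 2*9+1, 26), dtype=np.float32)
        out = session.run(mean_per_step, feed_dict={input_node: features})
        print(out.shape)  # (1, 16), then (8, 250)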
@@ -696,9 +701,23 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     # Apply softmax for CTC decoder
     logits = tf.nn.softmax(logits)
 
-    new_state_c, new_state_h = layers['rnn_output_state']
+    if batch_size <= 0:
+        if tflite:
+            raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
+        if n_steps > 0:
+            raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
+        return (
+            {
+                'input': input_tensor,
+                'input_lengths': seq_length,
+            },
+            {
+                'outputs': tf.identity(logits, name='logits'),
+            },
+            layers
+        )
 
-    # Initial zero state
+    new_state_c, new_state_h = layers['rnn_output_state']
     if not tflite:
         zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
         initialize_c = tf.assign(previous_state_c, zero_state)
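The early return above skips all RNN state management, so a dynamically batched graph exposes only input_node, input_lengths and logits, with no initialize_state op. A hedged usage sketch of the new guards (hypothetical driver, not part of the commit; it assumes DeepSpeech.py's Config and FLAGS have been initialized as in its own main()):

from DeepSpeech import create_inference_graph  # assumes initialized Config/FLAGS

# Valid: batch and step dimensions are both dynamic.
inputs, outputs, layers = create_inference_graph(batch_size=-1, n_steps=-1)

# Unsupported combinations now fail fast instead of exporting a broken graph:
# create_inference_graph(batch_size=-1, n_steps=16)               # raises NotImplementedError
# create_inference_graph(batch_size=-1, n_steps=-1, tflite=True)  # raises NotImplementedError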
@@ -857,7 +876,6 @@ def do_single_file_inference(input_file_path):
 
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-
         session.run(outputs['initialize_state'])
 
         features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)
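do_single_file_inference itself keeps a fixed batch of one, so it still runs initialize_state. For a graph exported with --export_batch_size -1, loading and feeding an arbitrary batch might look like the following sketch; the node names input_node, input_lengths and logits come from the diff, while the file name output_graph.pb and the feature geometry are assumptions for illustration:

import numpy as np
import tensorflow as tf  # TF 1.x

# Load a frozen graph assumed to have been exported with --export_batch_size -1.
with tf.gfile.GFile('output_graph.pb', 'rb') as f:  # assumed default export name
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

# Any batch size and step count should work with the same exported graph.
features = np.zeros((8, 250, 2*9+1, 26), dtype=np.float32)  # assumed geometry
lengths = np.full((8,), 250, dtype=np.int32)

with tf.Session(graph=graph) as session:
    logits = session.run('logits:0', feed_dict={
        'input_node:0': features,
        'input_lengths:0': lengths,
    })
print(logits.shape)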