Merge pull request #1926 from nicolaspanel/I1923

enable dynamic batch size through --export_batch_size -1 (see #1923)
This commit is contained in:
Reuben Morais 2019-03-06 19:53:43 -03:00 committed by GitHub
commit eb1c0f9853
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -664,18 +664,23 @@ def test():
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*Config.n_context+1, Config.n_input], name='input_node')
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
if not tflite:
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = previous_state_c = previous_state_h = None
else:
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
if not tflite:
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
else:
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
no_dropout = [0.0] * 6
@ -696,9 +701,23 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits)
new_state_c, new_state_h = layers['rnn_output_state']
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': tf.identity(logits, name='logits'),
},
layers
)
# Initial zero state
new_state_c, new_state_h = layers['rnn_output_state']
if not tflite:
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
initialize_c = tf.assign(previous_state_c, zero_state)
@ -857,7 +876,6 @@ def do_single_file_inference(input_file_path):
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
session.run(outputs['initialize_state'])
features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)