Merge pull request #1926 from nicolaspanel/I1923
enable dynamic batch size through --export_batch_size -1 (see #1923)
This commit is contained in:
commit
eb1c0f9853
@ -664,18 +664,23 @@ def test():
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
batch_size = batch_size if batch_size > 0 else None
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*Config.n_context+1, Config.n_input], name='input_node')
|
||||
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if not tflite:
|
||||
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = previous_state_c = previous_state_h = None
|
||||
else:
|
||||
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
if not tflite:
|
||||
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
else:
|
||||
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
|
||||
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
no_dropout = [0.0] * 6
|
||||
|
||||
@ -696,9 +701,23 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
||||
if n_steps > 0:
|
||||
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': tf.identity(logits, name='logits'),
|
||||
},
|
||||
layers
|
||||
)
|
||||
|
||||
# Initial zero state
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
if not tflite:
|
||||
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
|
||||
initialize_c = tf.assign(previous_state_c, zero_state)
|
||||
@ -857,7 +876,6 @@ def do_single_file_inference(input_file_path):
|
||||
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
saver.restore(session, checkpoint_path)
|
||||
|
||||
session.run(outputs['initialize_state'])
|
||||
|
||||
features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)
|
||||
|
Loading…
x
Reference in New Issue
Block a user