From a7f8e4af213888040c1d79de70379b271b575159 Mon Sep 17 00:00:00 2001 From: nicolaspanel Date: Fri, 1 Mar 2019 18:56:50 +0100 Subject: [PATCH] enable dynamic batch size through --export_batch_size -1 (see #1923) --- DeepSpeech.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 20a69326..9f9b5ce1 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -664,18 +664,23 @@ def test(): def create_inference_graph(batch_size=1, n_steps=16, tflite=False): + batch_size = batch_size if batch_size > 0 else None # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*Config.n_context+1, Config.n_input], name='input_node') seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths') - if not tflite: - previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None) - previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None) + if batch_size <= 0: + # no state management since n_step is expected to be dynamic too (see below) + previous_state = previous_state_c = previous_state_h = None else: - previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c') - previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h') + if not tflite: + previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None) + previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None) + else: + previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c') + previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h') - previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h) + previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h) no_dropout = [0.0] * 6 @@ -696,9 +701,23 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): # Apply softmax for CTC decoder logits = tf.nn.softmax(logits) - new_state_c, new_state_h = layers['rnn_output_state'] + if batch_size <= 0: + if tflite: + raise NotImplementedError('dynamic batch_size does not support tflite nor streaming') + if n_steps > 0: + raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too') + return ( + { + 'input': input_tensor, + 'input_lengths': seq_length, + }, + { + 'outputs': tf.identity(logits, name='logits'), + }, + layers + ) - # Initial zero state + new_state_c, new_state_h = layers['rnn_output_state'] if not tflite: zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32) initialize_c = tf.assign(previous_state_c, zero_state) @@ -857,7 +876,6 @@ def do_single_file_inference(input_file_path): checkpoint_path = checkpoint.model_checkpoint_path saver.restore(session, checkpoint_path) - session.run(outputs['initialize_state']) features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)