Remove sequence lenghts input (now using masking)

This commit is contained in:
Reuben Morais 2021-01-02 17:21:35 +00:00
parent 9b738dd70d
commit 966b9971cf
2 changed files with 0 additions and 10 deletions

View File

@ -213,14 +213,10 @@ TFModelState::infer(const std::vector<float>& mfcc,
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor input_lengths(DT_INT32, TensorShape({1}));
input_lengths.scalar<int>()() = n_frames;
vector<Tensor> outputs;
Status status = session_->Run(
{
{"input_node", input},
{"input_lengths", input_lengths},
{"previous_state_c", previous_state_c_t},
{"previous_state_h", previous_state_h_t}
},

View File

@ -636,7 +636,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
@ -668,7 +667,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': probs,
@ -688,9 +686,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': probs,
'new_state_c': new_state_c,
@ -844,7 +839,6 @@ def do_single_file_inference(input_file_path):
(outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']),
feed_dict={
inputs['input']: input_chunk,
inputs['input_lengths']: [chunk_len],
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
})