Remove sequence lenghts input (now using masking)
This commit is contained in:
parent
9b738dd70d
commit
966b9971cf
@ -213,14 +213,10 @@ TFModelState::infer(const std::vector<float>& mfcc,
|
||||
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
|
||||
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
||||
input_lengths.scalar<int>()() = n_frames;
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session_->Run(
|
||||
{
|
||||
{"input_node", input},
|
||||
{"input_lengths", input_lengths},
|
||||
{"previous_state_c", previous_state_c_t},
|
||||
{"previous_state_h", previous_state_h_t}
|
||||
},
|
||||
|
@ -636,7 +636,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
@ -668,7 +667,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': probs,
|
||||
@ -688,9 +686,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
inputs['input_lengths'] = seq_length
|
||||
|
||||
outputs = {
|
||||
'outputs': probs,
|
||||
'new_state_c': new_state_c,
|
||||
@ -844,7 +839,6 @@ def do_single_file_inference(input_file_path):
|
||||
(outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']),
|
||||
feed_dict={
|
||||
inputs['input']: input_chunk,
|
||||
inputs['input_lengths']: [chunk_len],
|
||||
inputs['previous_state_c']: previous_state_c,
|
||||
inputs['previous_state_h']: previous_state_h,
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user