Remove sequence lengths input (now using masking)
parent 9b738dd70d
commit 966b9971cf
@@ -213,14 +213,10 @@ TFModelState::infer(const std::vector<float>& mfcc,
   Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
   Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
 
-  Tensor input_lengths(DT_INT32, TensorShape({1}));
-  input_lengths.scalar<int>()() = n_frames;
-
   vector<Tensor> outputs;
   Status status = session_->Run(
     {
       {"input_node", input},
-      {"input_lengths", input_lengths},
       {"previous_state_c", previous_state_c_t},
       {"previous_state_h", previous_state_h_t}
     },
@@ -636,7 +636,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     # value of n_steps, n_context and n_input. Make sure you update the code
     # there if this shape is changed.
     input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
-    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
 
     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
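The placeholder deleted above was the graph's only sequence-length input; per the commit title, masking now carries that information instead. The masking code itself is not part of this diff, so the sketch below is purely illustrative and rests on one assumption the commit does not state: that padded time steps are all-zero, in which case an equivalent per-step mask can be recovered from input_node directly, with no separate input_lengths feed.

import tensorflow as tf

# Illustrative sketch only -- the real masking mechanism is not shown in
# this diff. ASSUMPTION: padded time steps are entirely zero.
def mask_from_input(input_tensor):
    # input_tensor: [batch_size, n_steps, 2*Config.n_context + 1, Config.n_input]
    # A time step counts as real if any value in its feature window is non-zero.
    return tf.reduce_any(tf.not_equal(input_tensor, 0.0), axis=[2, 3])  # -> [batch_size, n_steps] bool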
@@ -668,7 +667,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         return (
             {
                 'input': input_tensor,
-                'input_lengths': seq_length,
             },
             {
                 'outputs': probs,
@@ -688,9 +686,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         'input_samples': input_samples,
     }
 
-    if not FLAGS.export_tflite:
-        inputs['input_lengths'] = seq_length
-
     outputs = {
         'outputs': probs,
         'new_state_c': new_state_c,
@@ -844,7 +839,6 @@ def do_single_file_inference(input_file_path):
            (outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']),
            feed_dict={
                inputs['input']: input_chunk,
-               inputs['input_lengths']: [chunk_len],
                inputs['previous_state_c']: previous_state_c,
                inputs['previous_state_h']: previous_state_h,
            })
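Taken together, the hunks above remove every producer and consumer of input_lengths, so callers feed only the features and the recurrent state. A minimal sketch of the resulting do_single_file_inference call after applying this diff; the session.run call site and the names bound on the left are assumed from the surrounding code, not shown in the hunk:

# Hedged sketch: only the fetch tuple and feed_dict contents come from the
# diff; the assignment targets and session.run prefix are assumed.
probs, previous_state_c, previous_state_h = session.run(
    (outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']),
    feed_dict={
        inputs['input']: input_chunk,
        inputs['previous_state_c']: previous_state_c,
        inputs['previous_state_h']: previous_state_h,
    })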