Bugfixes.
PiperOrigin-RevId: 217168657
This commit is contained in:
parent
4371a68427
commit
09c208bd4a
@ -955,9 +955,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
||||
bw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
|
||||
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
|
||||
fw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
|
||||
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
|
||||
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
||||
bw_projection_weights, bw_projection_bias, &lstm_params,
|
||||
/*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
|
||||
|
@ -769,7 +769,6 @@ TfLiteStatus EvalFloat(
|
||||
float* aux_input_to_cell_weights_ptr = nullptr;
|
||||
float* aux_input_to_output_weights_ptr = nullptr;
|
||||
if (aux_input_size > 0) {
|
||||
aux_input_ptr = aux_input->data.f;
|
||||
if (!use_cifg) {
|
||||
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
|
||||
}
|
||||
@ -787,6 +786,9 @@ TfLiteStatus EvalFloat(
|
||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||
if (aux_input) {
|
||||
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
||||
}
|
||||
float* output_ptr_time =
|
||||
output->data.f + t_rel * output_step + output_offset;
|
||||
|
||||
@ -967,7 +969,6 @@ TfLiteStatus EvalHybrid(
|
||||
float aux_input_to_cell_weights_scale = 0.0f;
|
||||
float aux_input_to_output_weights_scale = 0.0f;
|
||||
if (aux_input_size > 0) {
|
||||
aux_input_ptr = aux_input->data.f;
|
||||
if (!use_cifg) {
|
||||
aux_input_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
|
||||
@ -998,6 +999,9 @@ TfLiteStatus EvalHybrid(
|
||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||
if (aux_input) {
|
||||
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
||||
}
|
||||
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
|
||||
|
||||
LstmStepWithAuxInput(
|
||||
|
Loading…
Reference in New Issue
Block a user