Bugfixes.
PiperOrigin-RevId: 217168657
This commit is contained in:
parent
4371a68427
commit
09c208bd4a
@ -955,9 +955,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||||
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
||||||
bw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
|
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
|
||||||
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
|
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
|
||||||
fw_aux_input_to_output_weights, bw_input_gate_bias,
|
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||||
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
||||||
bw_projection_weights, bw_projection_bias, &lstm_params,
|
bw_projection_weights, bw_projection_bias, &lstm_params,
|
||||||
/*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
|
/*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
|
||||||
|
@ -769,7 +769,6 @@ TfLiteStatus EvalFloat(
|
|||||||
float* aux_input_to_cell_weights_ptr = nullptr;
|
float* aux_input_to_cell_weights_ptr = nullptr;
|
||||||
float* aux_input_to_output_weights_ptr = nullptr;
|
float* aux_input_to_output_weights_ptr = nullptr;
|
||||||
if (aux_input_size > 0) {
|
if (aux_input_size > 0) {
|
||||||
aux_input_ptr = aux_input->data.f;
|
|
||||||
if (!use_cifg) {
|
if (!use_cifg) {
|
||||||
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
|
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
|
||||||
}
|
}
|
||||||
@ -787,6 +786,9 @@ TfLiteStatus EvalFloat(
|
|||||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
// If this is the forward_sequence, step forward, otherwise step backwards.
|
||||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||||
|
if (aux_input) {
|
||||||
|
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
||||||
|
}
|
||||||
float* output_ptr_time =
|
float* output_ptr_time =
|
||||||
output->data.f + t_rel * output_step + output_offset;
|
output->data.f + t_rel * output_step + output_offset;
|
||||||
|
|
||||||
@ -967,7 +969,6 @@ TfLiteStatus EvalHybrid(
|
|||||||
float aux_input_to_cell_weights_scale = 0.0f;
|
float aux_input_to_cell_weights_scale = 0.0f;
|
||||||
float aux_input_to_output_weights_scale = 0.0f;
|
float aux_input_to_output_weights_scale = 0.0f;
|
||||||
if (aux_input_size > 0) {
|
if (aux_input_size > 0) {
|
||||||
aux_input_ptr = aux_input->data.f;
|
|
||||||
if (!use_cifg) {
|
if (!use_cifg) {
|
||||||
aux_input_to_input_weights_ptr =
|
aux_input_to_input_weights_ptr =
|
||||||
reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
|
reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
|
||||||
@ -998,6 +999,9 @@ TfLiteStatus EvalHybrid(
|
|||||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
// If this is the forward_sequence, step forward, otherwise step backwards.
|
||||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||||
|
if (aux_input) {
|
||||||
|
aux_input_ptr = aux_input->data.f + t_rel * input_step;
|
||||||
|
}
|
||||||
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
|
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
|
||||||
|
|
||||||
LstmStepWithAuxInput(
|
LstmStepWithAuxInput(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user