Bugfixes.

PiperOrigin-RevId: 217168657
This commit is contained in:
A. Unique TensorFlower 2018-10-15 10:44:19 -07:00 committed by TensorFlower Gardener
parent 4371a68427
commit 09c208bd4a
2 changed files with 9 additions and 5 deletions

View File

@ -955,9 +955,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
bw_cell_to_input_weights, bw_cell_to_forget_weights,
bw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, bw_input_gate_bias,
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
bw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
bw_projection_weights, bw_projection_bias, &lstm_params,
/*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,

View File

@ -769,7 +769,6 @@ TfLiteStatus EvalFloat(
float* aux_input_to_cell_weights_ptr = nullptr;
float* aux_input_to_output_weights_ptr = nullptr;
if (aux_input_size > 0) {
aux_input_ptr = aux_input->data.f;
if (!use_cifg) {
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
}
@ -787,6 +786,9 @@ TfLiteStatus EvalFloat(
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
if (aux_input) {
aux_input_ptr = aux_input->data.f + t_rel * input_step;
}
float* output_ptr_time =
output->data.f + t_rel * output_step + output_offset;
@ -967,7 +969,6 @@ TfLiteStatus EvalHybrid(
float aux_input_to_cell_weights_scale = 0.0f;
float aux_input_to_output_weights_scale = 0.0f;
if (aux_input_size > 0) {
aux_input_ptr = aux_input->data.f;
if (!use_cifg) {
aux_input_to_input_weights_ptr =
reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
@ -998,6 +999,9 @@ TfLiteStatus EvalHybrid(
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
if (aux_input) {
aux_input_ptr = aux_input->data.f + t_rel * input_step;
}
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
LstmStepWithAuxInput(