diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index f1a77e2b1cb..1ce131a96ac 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -138,15 +138,19 @@ enum TemporaryTensor { kBwActivationStateQuantized = 4, kFwCellStateQuantized = 5, kBwCellStateQuantized = 6, - kScalingFactors = 7, - kProductScalingFactors = 8, - kRecoveredCellWeights = 9, - kAccumScratchBuffer = 10, - kZeroPoints = 11, - kFwRowSums = 12, - kBwRowSums = 13, - kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input. - kNumTemporaryTensors = 15 + kInputScalingFactors = 7, + kAuxInputScalingFactors = 8, + kOutputStateScalingFactors = 9, + kProductScalingFactors = 10, + kRecoveredCellWeights = 11, + kAccumScratchBuffer = 12, + kInputZeroPoints = 13, + kAuxInputZeroPoints = 14, + kOutputStateZeroPoints = 15, + kFwRowSums = 16, + kBwRowSums = 17, + kAuxInputQuantized = 18, // Optional, quantized tensor for auxiliary input. + kNumTemporaryTensors = 19, }; struct OpData { @@ -699,18 +703,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // a vector once (which produces the scaling factors) and multiply it with // different matrices (which requires multiplying the scaling factors with // the scaling factor of the matrix). - node->temporaries->data[kScalingFactors] = - op_data->scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; + node->temporaries->data[kInputScalingFactors] = + op_data->scratch_tensor_index + kInputScalingFactors; + TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); + input_sf->type = kTfLiteFloat32; + input_sf->allocation_type = kTfLiteArenaRw; int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); + if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); + input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_sf, input_sf_size)); + } + node->temporaries->data[kAuxInputScalingFactors] = + op_data->scratch_tensor_index + kAuxInputScalingFactors; + TfLiteTensor* aux_input_sf = + GetTemporary(context, node, kAuxInputScalingFactors); + aux_input_sf->type = kTfLiteFloat32; + aux_input_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1); + aux_input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf, + aux_input_sf_size)); + } + node->temporaries->data[kOutputStateScalingFactors] = + op_data->scratch_tensor_index + kOutputStateScalingFactors; + TfLiteTensor* output_state_sf = + GetTemporary(context, node, kOutputStateScalingFactors); + output_state_sf->type = kTfLiteFloat32; + output_state_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); + 
output_state_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, + output_state_sf_size)); } node->temporaries->data[kProductScalingFactors] = op_data->scratch_tensor_index + kProductScalingFactors; @@ -768,16 +795,40 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // Allocate temporary tensors for storing zero-points. - node->temporaries->data[kZeroPoints] = - op_data->scratch_tensor_index + kZeroPoints; - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); - zero_points->type = kTfLiteFloat32; - zero_points->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { - TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); - zero_points_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, - zero_points_size)); + node->temporaries->data[kInputZeroPoints] = + op_data->scratch_tensor_index + kInputZeroPoints; + TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); + input_zp->type = kTfLiteFloat32; + input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); + input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_zp, input_zp_size)); + } + node->temporaries->data[kAuxInputZeroPoints] = + op_data->scratch_tensor_index + kAuxInputZeroPoints; + TfLiteTensor* aux_input_zp = + GetTemporary(context, node, kAuxInputZeroPoints); + aux_input_zp->type = kTfLiteFloat32; + aux_input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1); + aux_input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp, + aux_input_zp_size)); + } + node->temporaries->data[kOutputStateZeroPoints] = + op_data->scratch_tensor_index + kOutputStateZeroPoints; + TfLiteTensor* output_state_zp = + GetTemporary(context, node, kOutputStateZeroPoints); + output_state_zp->type = kTfLiteFloat32; + output_state_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); + output_state_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, + output_state_zp_size)); } // Allocate temporary tensors for caching row sums for hybrid zero-point @@ -1071,8 +1122,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kFwCellStateQuantized); TfLiteTensor* bw_cell_state_quantized = GetTemporary(context, node, kBwCellStateQuantized); - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); TfLiteTensor* prod_scaling_factors = GetTemporary(context, node, kProductScalingFactors); TfLiteTensor* recovered_cell_weights = @@ -1082,7 +1131,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { : nullptr; TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratchBuffer); - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); const int fw_row_sums_size = fw_row_sums->dims->data[0]; @@ -1104,12 +1152,17 @@ TfLiteStatus 
Eval(TfLiteContext* context, TfLiteNode* node) { fw_output_gate_bias, fw_projection_weights, fw_projection_bias, &lstm_params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, - fw_scratch_buffer, scaling_factors, prod_scaling_factors, - recovered_cell_weights, input_quantized, aux_input_quantized, - fw_activation_state_quantized, fw_cell_state_quantized, - fw_activation_state, fw_cell_state, accum_scratch, fw_output, - zero_points, fw_row_sums, fw_row_sums_size, - &op_data->compute_fw_row_sums, + fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), + GetTemporary(context, node, kAuxInputScalingFactors), + GetTemporary(context, node, kOutputStateScalingFactors), + prod_scaling_factors, recovered_cell_weights, input_quantized, + aux_input_quantized, fw_activation_state_quantized, + fw_cell_state_quantized, fw_activation_state, fw_cell_state, + accum_scratch, fw_output, + GetTemporary(context, node, kInputZeroPoints), + GetTemporary(context, node, kAuxInputZeroPoints), + GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums, + fw_row_sums_size, &op_data->compute_fw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, fw_pass_status); @@ -1130,12 +1183,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_output_gate_bias, bw_projection_weights, bw_projection_bias, &lstm_params, /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset, - bw_scratch_buffer, scaling_factors, prod_scaling_factors, - recovered_cell_weights, input_quantized, aux_input_quantized, - bw_activation_state_quantized, bw_cell_state_quantized, - bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output, - zero_points, bw_row_sums, bw_row_sums_size, - &op_data->compute_bw_row_sums, + bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), + GetTemporary(context, node, kAuxInputScalingFactors), + GetTemporary(context, node, kOutputStateScalingFactors), + prod_scaling_factors, recovered_cell_weights, input_quantized, + aux_input_quantized, bw_activation_state_quantized, + bw_cell_state_quantized, bw_activation_state, bw_cell_state, + accum_scratch, actual_bw_output, + GetTemporary(context, node, kInputZeroPoints), + GetTemporary(context, node, kAuxInputZeroPoints), + GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums, + bw_row_sums_size, &op_data->compute_bw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index 75de587774a..c39f715446b 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -66,13 +66,15 @@ enum HybridTemporaryTensor { kInputQuantized = 1, kOutputStateQuantized = 2, kCellStateQuantized = 3, - kScalingFactors = 4, - kProductScalingFactors = 5, - kRecoveredCellWeights = 6, - kAccumScratch = 7, - kZeroPoints = 8, - kRowSums = 9, - kNumHybridTemporaryTensors = 10, + kInputScalingFactors = 4, + kOutputStateScalingFactors = 5, + kProductScalingFactors = 6, + kRecoveredCellWeights = 7, + kAccumScratch = 8, + kInputZeroPoints = 9, + kOutputStateZeroPoints = 10, + kRowSums = 11, + kNumHybridTemporaryTensors = 12, }; TfLiteStatus PopulateQuantizedLstmParams8x8_16( @@ -1333,18 +1335,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // a vector once (which produces the scaling factors) and multiply it with // different matrices (which requires multiplying the scaling factors with // the 
scaling factor of the matrix). - node->temporaries->data[kScalingFactors] = - op_data->scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; + node->temporaries->data[kInputScalingFactors] = + op_data->scratch_tensor_index + kInputScalingFactors; + TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); + input_sf->type = kTfLiteFloat32; + input_sf->allocation_type = kTfLiteArenaRw; int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); + if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); + input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_sf, input_sf_size)); + } + node->temporaries->data[kOutputStateScalingFactors] = + op_data->scratch_tensor_index + kOutputStateScalingFactors; + TfLiteTensor* output_state_sf = + GetTemporary(context, node, kOutputStateScalingFactors); + output_state_sf->type = kTfLiteFloat32; + output_state_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); + output_state_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, + output_state_sf_size)); } node->temporaries->data[kProductScalingFactors] = op_data->scratch_tensor_index + kProductScalingFactors; @@ -1394,18 +1407,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } - - node->temporaries->data[kZeroPoints] = - op_data->scratch_tensor_index + kZeroPoints; - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); - zero_points->type = kTfLiteFloat32; - zero_points->allocation_type = kTfLiteArenaRw; - int zero_points_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { - TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); - zero_points_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, - zero_points_size)); + node->temporaries->data[kInputZeroPoints] = + op_data->scratch_tensor_index + kInputZeroPoints; + TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); + input_zp->type = kTfLiteFloat32; + input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); + input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_zp, input_zp_size)); + } + node->temporaries->data[kOutputStateZeroPoints] = + op_data->scratch_tensor_index + kOutputStateZeroPoints; + TfLiteTensor* output_state_zp = + GetTemporary(context, node, kOutputStateZeroPoints); + output_state_zp->type = kTfLiteFloat32; + output_state_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); + 
output_state_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, + output_state_zp_size)); } node->temporaries->data[kRowSums] = @@ -1621,7 +1644,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { projection_weights, projection_bias, params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer), - GetTemporary(context, node, kScalingFactors), + GetTemporary(context, node, kInputScalingFactors), + /*aux_input_sf=*/nullptr, + GetTemporary(context, node, kOutputStateScalingFactors), GetTemporary(context, node, kProductScalingFactors), GetTemporary(context, node, kRecoveredCellWeights), GetTemporary(context, node, kInputQuantized), @@ -1629,8 +1654,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kOutputStateQuantized), GetTemporary(context, node, kCellStateQuantized), output_state, cell_state, GetTemporary(context, node, kAccumScratch), output, - GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size, - &op_data->compute_row_sums, + GetTemporary(context, node, kInputZeroPoints), + /*aux_input_zp=*/nullptr, + GetTemporary(context, node, kOutputStateZeroPoints), row_sums, + row_sums_size, &op_data->compute_row_sums, CpuBackendContext::GetFromContext(context)); } else { const int num_intermediate_tensors = node->intermediates->size; diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 9b24dc04c26..f97411a3a97 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -785,14 +785,15 @@ inline void LstmStepHybrid( const float* projection_bias_ptr, const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float* scratch0, float* scratch1, - float* scratch2, float* scratch3, float* scaling_factors, - float* scaling_factors_scratch, float* recovered_cell_weights, - int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr, - int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch, - float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, - float* output_ptr, int32_t* zero_points, int32_t* row_sums, - int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs, - CpuBackendContext* context) { + float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf, + float* output_state_sf, float* scaling_factors_scratch, + float* recovered_cell_weights, int8_t* quantized_input_ptr, + int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr, + int8_t* quantized_output_scratch, float* output_state_ptr, + float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, + int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp, + int32_t* row_sums, int row_sums_size, bool* compute_row_sums, + bool asymmetric_quantize_inputs, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("LstmStepHybrid"); // Since we have already checked that weights are all there or none, we // can check the existence of only one to the get the condition. 
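For reference, each of the three float sources that LstmStepHybrid quantizes (input, aux input, output state) now carries its own per-batch scaling factors and zero points instead of sharing one buffer. Below is a minimal, hypothetical sketch of what a per-row asymmetric quantization step computes when asymmetric_quantize_inputs is set; the helper name and clamping details are illustrative and do not reproduce tensor_utils::BatchQuantizeFloats.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize one batch row of floats to int8 with an asymmetric range,
// producing the per-row scaling factor and zero point that the hybrid
// matmuls consume (illustrative only).
void QuantizeRowAsymmetric(const float* values, int size, int8_t* quantized,
                           float* scaling_factor, int32_t* zero_point) {
  const auto minmax = std::minmax_element(values, values + size);
  const float rmin = std::min(0.0f, *minmax.first);
  const float rmax = std::max(0.0f, *minmax.second);
  if (rmin == rmax) {  // All-zero row: nothing to scale.
    std::fill(quantized, quantized + size, 0);
    *scaling_factor = 1.0f;
    *zero_point = 0;
    return;
  }
  const float scale = (rmax - rmin) / 255.0f;           // 256 int8 levels.
  const float zp = std::round(-128.0f - rmin / scale);  // Map rmin to -128.
  *scaling_factor = scale;
  *zero_point = static_cast<int32_t>(zp);
  for (int i = 0; i < size; ++i) {
    const float q = std::round(values[i] / scale) + zp;
    quantized[i] =
        static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, q)));
  }
}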
@@ -897,38 +898,37 @@ inline void LstmStepHybrid( if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) { tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input, - quantized_input_ptr, scaling_factors, - zero_points, asymmetric_quantize_inputs); + quantized_input_ptr, input_sf, input_zp, + asymmetric_quantize_inputs); if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_input_weights_scale, scaling_factors, n_batch, - input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, - accum_scratch_ptr, input_to_input_row_sums, compute_row_sums, - scaling_factors_scratch, context); + input_to_input_weights_scale, input_sf, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, + input_to_input_row_sums, compute_row_sums, scaling_factors_scratch, + context); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_forget_weights_scale, scaling_factors, n_batch, - forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, - accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums, - scaling_factors_scratch, context); + input_to_forget_weights_scale, input_sf, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, + input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, + context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_cell_weights_scale, scaling_factors, n_batch, - cell_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_cell_weights_scale, input_sf, n_batch, cell_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_output_weights_scale, scaling_factors, n_batch, - output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, - accum_scratch_ptr, input_to_output_row_sums, compute_row_sums, - scaling_factors_scratch, context); + input_to_output_weights_scale, input_sf, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, + input_to_output_row_sums, compute_row_sums, scaling_factors_scratch, + context); } // For each batch and cell: compute aux_input_weight * aux_input. 
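Each quantized vector is multiplied against several weight matrices, which is why its scaling factor and zero point travel together through every MatrixBatchVectorMultiplyAccumulate call above. A minimal sketch of the zero-point corrected hybrid accumulate for one (weight row, batch) pair follows; the names are illustrative, and the real routine batches this and caches the weight row sums rather than recomputing them.

#include <cstdint>

// Dequantized dot product with asymmetric-input correction (illustrative).
float HybridDotWithZeroPoint(const int8_t* weight_row, const int8_t* quantized,
                             int size, float weight_scale, float input_scale,
                             int32_t input_zero_point, int32_t row_sum) {
  int32_t acc = 0;
  for (int i = 0; i < size; ++i) acc += weight_row[i] * quantized[i];
  // Inputs were quantized as q = x / scale + zp, so zp * row_sum has to be
  // subtracted before rescaling; row_sum is the sum of weight_row entries.
  acc -= input_zero_point * row_sum;
  return static_cast<float>(acc) * weight_scale * input_scale;
}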
@@ -936,15 +936,15 @@ inline void LstmStepHybrid( if (aux_input_ptr != nullptr && !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) { tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input, - quantized_aux_input_ptr, scaling_factors, - zero_points, asymmetric_quantize_inputs); + quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, asymmetric_quantize_inputs); if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_input_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, aux_input_to_input_weights_scale, - scaling_factors, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_sf, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -952,24 +952,23 @@ inline void LstmStepHybrid( tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_forget_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, aux_input_to_forget_weights_scale, - scaling_factors, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_sf, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_cell_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, aux_input_to_cell_weights_scale, - scaling_factors, n_batch, cell_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, - aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, - context); + quantized_aux_input_ptr, aux_input_to_cell_weights_scale, aux_input_sf, + n_batch, cell_gate_scratch, /*per_channel_scale=*/nullptr, aux_input_zp, + accum_scratch_ptr, aux_input_to_cell_row_sums, compute_row_sums, + scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_output_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, aux_input_to_output_weights_scale, - scaling_factors, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_sf, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -978,14 +977,14 @@ inline void LstmStepHybrid( // Save quantization and matmul computation for all zero input. tensor_utils::BatchQuantizeFloats( output_state_ptr, n_batch, n_output, quantized_output_state_ptr, - scaling_factors, zero_points, asymmetric_quantize_inputs); + output_state_sf, output_state_zp, asymmetric_quantize_inputs); // For each batch and cell: compute recurrent_weight * output_state. 
if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_input_weights_scale, - scaling_factors, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -993,24 +992,24 @@ inline void LstmStepHybrid( tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_forget_weights_scale, - scaling_factors, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_cell_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_cell_weights_scale, - scaling_factors, n_batch, cell_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, cell_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_output_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_output_weights_scale, - scaling_factors, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -1102,7 +1101,7 @@ inline void LstmStepHybrid( params->activation, projection_weights_ptr, projection_weights_scale, projection_bias_ptr, params->proj_clip, output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums, - context, scratch2, quantized_output_scratch, scaling_factors, zero_points, + context, scratch2, quantized_output_scratch, input_sf, input_zp, accum_scratch_ptr); // Copy output_state_ptr to the output. 
Note that the output batch rows may @@ -1892,14 +1891,16 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, - int output_offset, TfLiteTensor* scratch_buffer, - TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, - TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, - TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, - TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, - int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) { + int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, + TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, + TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output_scratch_buffer, TfLiteTensor* output, + TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp, + TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size, + bool* compute_row_sums, CpuBackendContext* context) { TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); const int n_input = input->dims->data[input->dims->size - 1]; int max_time, n_batch; @@ -1939,10 +1940,14 @@ TfLiteStatus EvalHybrid( const int output_batch_leading_dim = output->dims->data[output->dims->size - 1]; - int32_t* zero_points_ptr = nullptr; + int32_t* input_zp_ptr = nullptr; + int32_t* aux_input_zp_ptr = nullptr; + int32_t* output_state_zp_ptr = nullptr; int32_t* row_sums_ptr = nullptr; if (params->asymmetric_quantize_inputs) { - zero_points_ptr = GetTensorData(zero_points); + input_zp_ptr = GetTensorData(input_zp); + aux_input_zp_ptr = GetTensorData(aux_input_zp); + output_state_zp_ptr = GetTensorData(output_state_zp); row_sums_ptr = GetTensorData(row_sums); } @@ -2005,7 +2010,9 @@ TfLiteStatus EvalHybrid( GetTensorData(projection_bias), params, n_batch, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim, input_gate_scratch, forget_gate_scratch, cell_gate_scratch, - output_gate_scratch, GetTensorData(scaling_factors), + output_gate_scratch, GetTensorData(input_sf), + GetTensorData(aux_input_sf), + GetTensorData(output_state_sf), GetTensorData(prod_scaling_factors), GetTensorData(recovered_cell_weights), GetTensorData(input_quantized), @@ -2014,8 +2021,9 @@ TfLiteStatus EvalHybrid( GetTensorData(cell_state_quantized), GetTensorData(output_state), GetTensorData(cell_state), GetTensorData(output_scratch_buffer), output_ptr, - zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums, - params->asymmetric_quantize_inputs, context); + input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr, + row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs, + context); } } else { for (int b = 0; b < n_batch; b++) { @@ -2092,7 +2100,9 @@ TfLiteStatus EvalHybrid( /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim, input_gate_scratch_ptr, forget_gate_scratch_ptr, cell_gate_scratch_ptr, - output_gate_scratch_ptr, GetTensorData(scaling_factors), + output_gate_scratch_ptr, 
GetTensorData(input_sf), + GetTensorData(aux_input_sf), + GetTensorData(output_state_sf), GetTensorData(prod_scaling_factors), GetTensorData(recovered_cell_weights), GetTensorData(input_quantized), @@ -2100,8 +2110,9 @@ TfLiteStatus EvalHybrid( GetTensorData(output_state_quantized), GetTensorData(cell_state_quantized), output_state_ptr, cell_state_ptr, GetTensorData(output_scratch_buffer), - output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size, - compute_row_sums, params->asymmetric_quantize_inputs, context); + output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, + row_sums_ptr, row_sums_size, compute_row_sums, + params->asymmetric_quantize_inputs, context); } } } diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index 9b3bd0c54ec..d3fdf037b5c 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -148,14 +148,16 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, - int output_offset, TfLiteTensor* scratch_buffer, - TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, - TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, - TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, - TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, - int row_sums_size, bool* compute_row_sums, CpuBackendContext* context); + int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, + TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, + TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output_scratch_buffer, TfLiteTensor* output, + TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp, + TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size, + bool* compute_row_sums, CpuBackendContext* context); TfLiteStatus EvalInteger8x8_16( const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, diff --git a/tensorflow/lite/kernels/lstm_eval_test.cc b/tensorflow/lite/kernels/lstm_eval_test.cc index 78459117859..459315bd1c9 100644 --- a/tensorflow/lite/kernels/lstm_eval_test.cc +++ b/tensorflow/lite/kernels/lstm_eval_test.cc @@ -654,15 +654,27 @@ class HybridLstmParam : public BaseLstmParam { scratch_buffer_tensor_.data.f = scratch_buffer_.data(); return &scratch_buffer_tensor_; } - TfLiteTensor* GetScalingFactors() { - PackWeightToTensor(&scaling_factors_tensor_, scaling_factors_, - scaling_factors_size_); - scaling_factors_tensor_.data.f = scaling_factors_.data(); - return &scaling_factors_tensor_; + TfLiteTensor* GetInputScalingFactors() { + PackWeightToTensor(&input_sf_tensor_, input_sf_, + quantization_extra_scratch_buffer_sizes_); + input_sf_tensor_.data.f = input_sf_.data(); + return &input_sf_tensor_; + } + TfLiteTensor* GetAuxInputScalingFactors() { + PackWeightToTensor(&aux_input_sf_tensor_, aux_input_sf_, + quantization_extra_scratch_buffer_sizes_); + aux_input_sf_tensor_.data.f = aux_input_sf_.data(); + return &aux_input_sf_tensor_; + } + 
TfLiteTensor* GetOutputStateScalingFactors() { + PackWeightToTensor(&output_state_sf_tensor_, output_state_sf_, + quantization_extra_scratch_buffer_sizes_); + output_state_sf_tensor_.data.f = output_state_sf_.data(); + return &output_state_sf_tensor_; } TfLiteTensor* GetProdScalingFactors() { PackWeightToTensor(&prod_scaling_factors_tensor_, prod_scaling_factors_, - prod_scaling_factors_size_); + quantization_extra_scratch_buffer_sizes_); prod_scaling_factors_tensor_.data.f = prod_scaling_factors_.data(); return &prod_scaling_factors_tensor_; } @@ -682,10 +694,23 @@ class HybridLstmParam : public BaseLstmParam { cell_quantized_tensor_.data.int8 = cell_quantized_.data(); return &cell_quantized_tensor_; } - TfLiteTensor* GetZeroPoints() { - PackWeightToTensor(&zero_points_tensor_, zero_points_, zero_points_size_); - zero_points_tensor_.data.i32 = zero_points_.data(); - return &zero_points_tensor_; + TfLiteTensor* GetInputZeroPoints() { + PackWeightToTensor(&zero_points_tensor0_, input_zp_, + quantization_extra_scratch_buffer_sizes_); + zero_points_tensor0_.data.i32 = input_zp_.data(); + return &zero_points_tensor0_; + } + TfLiteTensor* GetAuxInputZeroPoints() { + PackWeightToTensor(&zero_points_tensor1_, aux_input_zp_, + quantization_extra_scratch_buffer_sizes_); + zero_points_tensor1_.data.i32 = aux_input_zp_.data(); + return &zero_points_tensor1_; + } + TfLiteTensor* GetOutputStateZeroPoints() { + PackWeightToTensor(&zero_points_tensor2_, output_state_zp_, + quantization_extra_scratch_buffer_sizes_); + zero_points_tensor2_.data.i32 = output_state_zp_.data(); + return &zero_points_tensor2_; } TfLiteTensor* GetRowSums() { PackWeightToTensor(&row_sums_tensor_, row_sums_, row_sums_size_); @@ -776,12 +801,16 @@ class HybridLstmParam : public BaseLstmParam { ~HybridLstmParam() { TfLiteIntArrayFree(scratch_buffer_tensor_.dims); TfLiteIntArrayFree(accum_scratch_tensor_.dims); - TfLiteIntArrayFree(scaling_factors_tensor_.dims); + TfLiteIntArrayFree(input_sf_tensor_.dims); + TfLiteIntArrayFree(aux_input_sf_tensor_.dims); + TfLiteIntArrayFree(output_state_sf_tensor_.dims); TfLiteIntArrayFree(prod_scaling_factors_tensor_.dims); TfLiteIntArrayFree(input_quantized_tensor_.dims); TfLiteIntArrayFree(activation_quantized_tensor_.dims); TfLiteIntArrayFree(cell_quantized_tensor_.dims); - TfLiteIntArrayFree(zero_points_tensor_.dims); + TfLiteIntArrayFree(zero_points_tensor0_.dims); + TfLiteIntArrayFree(zero_points_tensor1_.dims); + TfLiteIntArrayFree(zero_points_tensor2_.dims); TfLiteIntArrayFree(row_sums_tensor_.dims); } @@ -792,14 +821,24 @@ class HybridLstmParam : public BaseLstmParam { std::vector scratch_buffer_size_ = {n_batch_, n_cell_ * 4}; TfLiteTensor scratch_buffer_tensor_; - std::vector scaling_factors_; - std::vector scaling_factors_size_ = {n_batch_}; - TfLiteTensor scaling_factors_tensor_; + std::vector quantization_extra_scratch_buffer_sizes_ = {n_batch_}; + std::vector input_sf_; + TfLiteTensor input_sf_tensor_; + std::vector aux_input_sf_; + TfLiteTensor aux_input_sf_tensor_; + std::vector output_state_sf_; + TfLiteTensor output_state_sf_tensor_; std::vector prod_scaling_factors_; - std::vector prod_scaling_factors_size_ = {n_batch_}; TfLiteTensor prod_scaling_factors_tensor_; + std::vector input_zp_; + TfLiteTensor zero_points_tensor0_; + std::vector aux_input_zp_; + TfLiteTensor zero_points_tensor1_; + std::vector output_state_zp_; + TfLiteTensor zero_points_tensor2_; + std::vector input_quantized_; TfLiteTensor input_quantized_tensor_; @@ -813,10 +852,6 @@ class HybridLstmParam : public 
BaseLstmParam { 16, 4, 5, 6, 1, 1, 3, 4, -5, 6, 1, 14, 5, 6, 1, 1, 3, 4, -5, 6, }; - std::vector zero_points_; - std::vector zero_points_size_ = {n_batch_}; - TfLiteTensor zero_points_tensor_; - std::vector row_sums_; std::vector row_sums_size_ = {n_row_sums_, n_cell_}; TfLiteTensor row_sums_tensor_; @@ -896,13 +931,17 @@ void TestOneHybridAsymmLSTM() { /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, one_parameter.GetScratchBuffer(), - one_parameter.GetScalingFactors(), one_parameter.GetProdScalingFactors(), + one_parameter.GetInputScalingFactors(), + one_parameter.GetAuxInputScalingFactors(), + one_parameter.GetOutputStateScalingFactors(), + one_parameter.GetProdScalingFactors(), /*recovered_cell_weights=*/nullptr, one_parameter.GetInputQuantized(), /*aux_input_quantized=*/nullptr, one_parameter.GetActivationStateQuantized(), one_parameter.GetCellStateQuantized(), activation, cell, one_parameter.GetAccumScratchBuffer(), output, - one_parameter.GetZeroPoints(), one_parameter.GetRowSums(), + one_parameter.GetInputZeroPoints(), one_parameter.GetAuxInputZeroPoints(), + one_parameter.GetOutputStateZeroPoints(), one_parameter.GetRowSums(), one_parameter.GetNumRowSums(), &compute_row_sums, &context); const std::vector expected_cell = { 7.83134, 1.96158, 2.18285, 3.28739, 0.483214, diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 026b2452aef..0849c6dc0e4 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -45,13 +45,15 @@ enum TemporaryTensor { kInputQuantized = 1, kOutputStateQuantized = 2, kCellStateQuantized = 3, - kScalingFactors = 4, - kProductScalingFactors = 5, - kRecoveredCellWeights = 6, - kAccumScratch = 7, - kZeroPoints = 8, - kRowSums = 9, - kNumTemporaryTensors = 10 + kInputScalingFactors = 4, + kOutputStateScalingFactors = 5, + kProductScalingFactors = 6, + kRecoveredCellWeights = 7, + kAccumScratch = 8, + kInputZeroPoints = 9, + kOutputStateZeroPoints = 10, + kRowSums = 11, + kNumTemporaryTensors = 12, }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -416,18 +418,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // a vector once (which produces the scaling factors) and multiply it with // different matrices (which requires multiplying the scaling factors with // the scaling factor of the matrix). 
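The comment above is why the scaling factors are kept per batch row rather than per tensor: the vector is quantized once, and its per-batch scale is later combined with each weight matrix's own scale. A hypothetical sketch of that combination, assuming illustrative names; in the kernel the kProductScalingFactors temporary is handed to the accumulate routine as scratch for exactly this product.

#include <vector>

std::vector<float> CombineScales(
    const std::vector<float>& per_batch_input_scale,
    float weight_matrix_scale) {
  std::vector<float> product(per_batch_input_scale.size());
  for (size_t b = 0; b < product.size(); ++b) {
    product[b] = per_batch_input_scale[b] * weight_matrix_scale;
  }
  return product;
}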
- node->temporaries->data[kScalingFactors] = - scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; + node->temporaries->data[kInputScalingFactors] = + op_data->scratch_tensor_index + kInputScalingFactors; + TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); + input_sf->type = kTfLiteFloat32; + input_sf->allocation_type = kTfLiteArenaRw; int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); + if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); + input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_sf, input_sf_size)); + } + node->temporaries->data[kOutputStateScalingFactors] = + op_data->scratch_tensor_index + kOutputStateScalingFactors; + TfLiteTensor* output_state_sf = + GetTemporary(context, node, kOutputStateScalingFactors); + output_state_sf->type = kTfLiteFloat32; + output_state_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); + output_state_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, + output_state_sf_size)); } node->temporaries->data[kProductScalingFactors] = scratch_tensor_index + kProductScalingFactors; @@ -477,15 +490,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } - node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints; - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); - zero_points->type = kTfLiteFloat32; - zero_points->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { - TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); - zero_points_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, - zero_points_size)); + node->temporaries->data[kInputZeroPoints] = + op_data->scratch_tensor_index + kInputZeroPoints; + TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); + input_zp->type = kTfLiteFloat32; + input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); + input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_zp, input_zp_size)); + } + node->temporaries->data[kOutputStateZeroPoints] = + op_data->scratch_tensor_index + kOutputStateZeroPoints; + TfLiteTensor* output_state_zp = + GetTemporary(context, node, kOutputStateZeroPoints); + output_state_zp->type = kTfLiteFloat32; + output_state_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); + output_state_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, + 
output_state_zp_size)); } node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums; TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); @@ -640,7 +666,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { projection_weights, projection_bias, &lstm_params, /*forward_sequence=*/true, time_major, /*output_offset=*/0, scratch_buffer, - GetTemporary(context, node, kScalingFactors), + GetTemporary(context, node, kInputScalingFactors), + /*aux_input_sf=*/nullptr, + GetTemporary(context, node, kOutputStateScalingFactors), GetTemporary(context, node, kProductScalingFactors), GetTemporary(context, node, kRecoveredCellWeights), GetTemporary(context, node, kInputQuantized), @@ -648,8 +676,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kOutputStateQuantized), GetTemporary(context, node, kCellStateQuantized), output_state, cell_state, GetTemporary(context, node, kAccumScratch), output, - GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size, - &op_data->compute_row_sums, + GetTemporary(context, node, kInputZeroPoints), + /*aux_input_zp=*/nullptr, + GetTemporary(context, node, kOutputStateZeroPoints), row_sums, + row_sums_size, &op_data->compute_row_sums, CpuBackendContext::GetFromContext(context)); } default:
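Every kernel in this patch renumbers its temporary-tensor enum (for example, kNumTemporaryTensors moves from 15 to 19 in the bidirectional op), and each enum value doubles as both the index into node->temporaries and the offset from scratch_tensor_index. A small guard of the kind sketched below is one way to keep the trailing count in sync with the last real temporary; it is not in the patch, and the values simply mirror the bidirectional enum above.

enum TemporaryTensor {
  kFwScratchBuffer = 0,
  // ... intermediate temporaries elided ...
  kAuxInputQuantized = 18,  // Optional quantized auxiliary input.
  kNumTemporaryTensors = 19,
};
static_assert(kNumTemporaryTensors == kAuxInputQuantized + 1,
              "kNumTemporaryTensors must be one past the last temporary.");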