From 61a6e22f5fd3296c3838da13ec1fb28fccc5f83c Mon Sep 17 00:00:00 2001
From: Robert David
Date: Fri, 26 Jun 2020 16:22:10 -0700
Subject: [PATCH] Add additional scratch buffers for Hybrid LSTM, storing
 temporary quantization information separately for different inputs.

This allows the LSTM gates to be processed independently, instead of strictly
sequentially based on the different inputs.

PiperOrigin-RevId: 318563361
Change-Id: I27b4d14ec9e93083a5ad48729260d7b4a1d43cde
---
 .../kernels/bidirectional_sequence_lstm.cc    | 148 ++++++++++++------
 tensorflow/lite/kernels/lstm.cc               |  93 +++++++----
 tensorflow/lite/kernels/lstm_eval.cc          | 139 ++++++++--------
 tensorflow/lite/kernels/lstm_eval.h           |  18 ++-
 tensorflow/lite/kernels/lstm_eval_test.cc     |  83 +++++++---
 .../kernels/unidirectional_sequence_lstm.cc   |  90 +++++++----
 6 files changed, 369 insertions(+), 202 deletions(-)

diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
index f1a77e2b1cb..1ce131a96ac 100644
--- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc
@@ -138,15 +138,19 @@ enum TemporaryTensor {
   kBwActivationStateQuantized = 4,
   kFwCellStateQuantized = 5,
   kBwCellStateQuantized = 6,
-  kScalingFactors = 7,
-  kProductScalingFactors = 8,
-  kRecoveredCellWeights = 9,
-  kAccumScratchBuffer = 10,
-  kZeroPoints = 11,
-  kFwRowSums = 12,
-  kBwRowSums = 13,
-  kAuxInputQuantized = 14,  // Optional, quantized tensor for auxiliary input.
-  kNumTemporaryTensors = 15
+  kInputScalingFactors = 7,
+  kAuxInputScalingFactors = 8,
+  kOutputStateScalingFactors = 9,
+  kProductScalingFactors = 10,
+  kRecoveredCellWeights = 11,
+  kAccumScratchBuffer = 12,
+  kInputZeroPoints = 13,
+  kAuxInputZeroPoints = 14,
+  kOutputStateZeroPoints = 15,
+  kFwRowSums = 16,
+  kBwRowSums = 17,
+  kAuxInputQuantized = 18,  // Optional, quantized tensor for auxiliary input.
+  kNumTemporaryTensors = 19,
 };
 
 struct OpData {
@@ -699,18 +703,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // a vector once (which produces the scaling factors) and multiply it with
   // different matrices (which requires multiplying the scaling factors with
   // the scaling factor of the matrix).
- node->temporaries->data[kScalingFactors] = - op_data->scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; + node->temporaries->data[kInputScalingFactors] = + op_data->scratch_tensor_index + kInputScalingFactors; + TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); + input_sf->type = kTfLiteFloat32; + input_sf->allocation_type = kTfLiteArenaRw; int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); + if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); + input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_sf, input_sf_size)); + } + node->temporaries->data[kAuxInputScalingFactors] = + op_data->scratch_tensor_index + kAuxInputScalingFactors; + TfLiteTensor* aux_input_sf = + GetTemporary(context, node, kAuxInputScalingFactors); + aux_input_sf->type = kTfLiteFloat32; + aux_input_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1); + aux_input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf, + aux_input_sf_size)); + } + node->temporaries->data[kOutputStateScalingFactors] = + op_data->scratch_tensor_index + kOutputStateScalingFactors; + TfLiteTensor* output_state_sf = + GetTemporary(context, node, kOutputStateScalingFactors); + output_state_sf->type = kTfLiteFloat32; + output_state_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); + output_state_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, + output_state_sf_size)); } node->temporaries->data[kProductScalingFactors] = op_data->scratch_tensor_index + kProductScalingFactors; @@ -768,16 +795,40 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // Allocate temporary tensors for storing zero-points. 
- node->temporaries->data[kZeroPoints] = - op_data->scratch_tensor_index + kZeroPoints; - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); - zero_points->type = kTfLiteFloat32; - zero_points->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { - TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); - zero_points_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, - zero_points_size)); + node->temporaries->data[kInputZeroPoints] = + op_data->scratch_tensor_index + kInputZeroPoints; + TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); + input_zp->type = kTfLiteFloat32; + input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); + input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_zp, input_zp_size)); + } + node->temporaries->data[kAuxInputZeroPoints] = + op_data->scratch_tensor_index + kAuxInputZeroPoints; + TfLiteTensor* aux_input_zp = + GetTemporary(context, node, kAuxInputZeroPoints); + aux_input_zp->type = kTfLiteFloat32; + aux_input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1); + aux_input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp, + aux_input_zp_size)); + } + node->temporaries->data[kOutputStateZeroPoints] = + op_data->scratch_tensor_index + kOutputStateZeroPoints; + TfLiteTensor* output_state_zp = + GetTemporary(context, node, kOutputStateZeroPoints); + output_state_zp->type = kTfLiteFloat32; + output_state_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); + output_state_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, + output_state_zp_size)); } // Allocate temporary tensors for caching row sums for hybrid zero-point @@ -1071,8 +1122,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kFwCellStateQuantized); TfLiteTensor* bw_cell_state_quantized = GetTemporary(context, node, kBwCellStateQuantized); - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); TfLiteTensor* prod_scaling_factors = GetTemporary(context, node, kProductScalingFactors); TfLiteTensor* recovered_cell_weights = @@ -1082,7 +1131,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { : nullptr; TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratchBuffer); - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); const int fw_row_sums_size = fw_row_sums->dims->data[0]; @@ -1104,12 +1152,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_output_gate_bias, fw_projection_weights, fw_projection_bias, &lstm_params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, - fw_scratch_buffer, scaling_factors, prod_scaling_factors, - recovered_cell_weights, input_quantized, aux_input_quantized, - fw_activation_state_quantized, fw_cell_state_quantized, - fw_activation_state, 
fw_cell_state, accum_scratch, fw_output, - zero_points, fw_row_sums, fw_row_sums_size, - &op_data->compute_fw_row_sums, + fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), + GetTemporary(context, node, kAuxInputScalingFactors), + GetTemporary(context, node, kOutputStateScalingFactors), + prod_scaling_factors, recovered_cell_weights, input_quantized, + aux_input_quantized, fw_activation_state_quantized, + fw_cell_state_quantized, fw_activation_state, fw_cell_state, + accum_scratch, fw_output, + GetTemporary(context, node, kInputZeroPoints), + GetTemporary(context, node, kAuxInputZeroPoints), + GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums, + fw_row_sums_size, &op_data->compute_fw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, fw_pass_status); @@ -1130,12 +1183,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_output_gate_bias, bw_projection_weights, bw_projection_bias, &lstm_params, /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset, - bw_scratch_buffer, scaling_factors, prod_scaling_factors, - recovered_cell_weights, input_quantized, aux_input_quantized, - bw_activation_state_quantized, bw_cell_state_quantized, - bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output, - zero_points, bw_row_sums, bw_row_sums_size, - &op_data->compute_bw_row_sums, + bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), + GetTemporary(context, node, kAuxInputScalingFactors), + GetTemporary(context, node, kOutputStateScalingFactors), + prod_scaling_factors, recovered_cell_weights, input_quantized, + aux_input_quantized, bw_activation_state_quantized, + bw_cell_state_quantized, bw_activation_state, bw_cell_state, + accum_scratch, actual_bw_output, + GetTemporary(context, node, kInputZeroPoints), + GetTemporary(context, node, kAuxInputZeroPoints), + GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums, + bw_row_sums_size, &op_data->compute_bw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index 75de587774a..c39f715446b 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -66,13 +66,15 @@ enum HybridTemporaryTensor { kInputQuantized = 1, kOutputStateQuantized = 2, kCellStateQuantized = 3, - kScalingFactors = 4, - kProductScalingFactors = 5, - kRecoveredCellWeights = 6, - kAccumScratch = 7, - kZeroPoints = 8, - kRowSums = 9, - kNumHybridTemporaryTensors = 10, + kInputScalingFactors = 4, + kOutputStateScalingFactors = 5, + kProductScalingFactors = 6, + kRecoveredCellWeights = 7, + kAccumScratch = 8, + kInputZeroPoints = 9, + kOutputStateZeroPoints = 10, + kRowSums = 11, + kNumHybridTemporaryTensors = 12, }; TfLiteStatus PopulateQuantizedLstmParams8x8_16( @@ -1333,18 +1335,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // a vector once (which produces the scaling factors) and multiply it with // different matrices (which requires multiplying the scaling factors with // the scaling factor of the matrix). 
- node->temporaries->data[kScalingFactors] = - op_data->scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; + node->temporaries->data[kInputScalingFactors] = + op_data->scratch_tensor_index + kInputScalingFactors; + TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); + input_sf->type = kTfLiteFloat32; + input_sf->allocation_type = kTfLiteArenaRw; int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); + if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); + input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_sf, input_sf_size)); + } + node->temporaries->data[kOutputStateScalingFactors] = + op_data->scratch_tensor_index + kOutputStateScalingFactors; + TfLiteTensor* output_state_sf = + GetTemporary(context, node, kOutputStateScalingFactors); + output_state_sf->type = kTfLiteFloat32; + output_state_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); + output_state_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, + output_state_sf_size)); } node->temporaries->data[kProductScalingFactors] = op_data->scratch_tensor_index + kProductScalingFactors; @@ -1394,18 +1407,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } - - node->temporaries->data[kZeroPoints] = - op_data->scratch_tensor_index + kZeroPoints; - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); - zero_points->type = kTfLiteFloat32; - zero_points->allocation_type = kTfLiteArenaRw; - int zero_points_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { - TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); - zero_points_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, - zero_points_size)); + node->temporaries->data[kInputZeroPoints] = + op_data->scratch_tensor_index + kInputZeroPoints; + TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); + input_zp->type = kTfLiteFloat32; + input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); + input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_zp, input_zp_size)); + } + node->temporaries->data[kOutputStateZeroPoints] = + op_data->scratch_tensor_index + kOutputStateZeroPoints; + TfLiteTensor* output_state_zp = + GetTemporary(context, node, kOutputStateZeroPoints); + output_state_zp->type = kTfLiteFloat32; + output_state_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); + output_state_zp_size->data[0] = n_batch; + 
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, + output_state_zp_size)); } node->temporaries->data[kRowSums] = @@ -1621,7 +1644,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { projection_weights, projection_bias, params, /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer), - GetTemporary(context, node, kScalingFactors), + GetTemporary(context, node, kInputScalingFactors), + /*aux_input_sf=*/nullptr, + GetTemporary(context, node, kOutputStateScalingFactors), GetTemporary(context, node, kProductScalingFactors), GetTemporary(context, node, kRecoveredCellWeights), GetTemporary(context, node, kInputQuantized), @@ -1629,8 +1654,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kOutputStateQuantized), GetTemporary(context, node, kCellStateQuantized), output_state, cell_state, GetTemporary(context, node, kAccumScratch), output, - GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size, - &op_data->compute_row_sums, + GetTemporary(context, node, kInputZeroPoints), + /*aux_input_zp=*/nullptr, + GetTemporary(context, node, kOutputStateZeroPoints), row_sums, + row_sums_size, &op_data->compute_row_sums, CpuBackendContext::GetFromContext(context)); } else { const int num_intermediate_tensors = node->intermediates->size; diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 9b24dc04c26..f97411a3a97 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -785,14 +785,15 @@ inline void LstmStepHybrid( const float* projection_bias_ptr, const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, int output_batch_leading_dim, float* scratch0, float* scratch1, - float* scratch2, float* scratch3, float* scaling_factors, - float* scaling_factors_scratch, float* recovered_cell_weights, - int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr, - int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch, - float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, - float* output_ptr, int32_t* zero_points, int32_t* row_sums, - int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs, - CpuBackendContext* context) { + float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf, + float* output_state_sf, float* scaling_factors_scratch, + float* recovered_cell_weights, int8_t* quantized_input_ptr, + int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr, + int8_t* quantized_output_scratch, float* output_state_ptr, + float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, + int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp, + int32_t* row_sums, int row_sums_size, bool* compute_row_sums, + bool asymmetric_quantize_inputs, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("LstmStepHybrid"); // Since we have already checked that weights are all there or none, we // can check the existence of only one to the get the condition. 
@@ -897,38 +898,37 @@ inline void LstmStepHybrid( if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) { tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input, - quantized_input_ptr, scaling_factors, - zero_points, asymmetric_quantize_inputs); + quantized_input_ptr, input_sf, input_zp, + asymmetric_quantize_inputs); if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_input_weights_scale, scaling_factors, n_batch, - input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, - accum_scratch_ptr, input_to_input_row_sums, compute_row_sums, - scaling_factors_scratch, context); + input_to_input_weights_scale, input_sf, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, + input_to_input_row_sums, compute_row_sums, scaling_factors_scratch, + context); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_forget_weights_scale, scaling_factors, n_batch, - forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, - accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums, - scaling_factors_scratch, context); + input_to_forget_weights_scale, input_sf, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, + input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, + context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_cell_weights_scale, scaling_factors, n_batch, - cell_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_cell_weights_scale, input_sf, n_batch, cell_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, - input_to_output_weights_scale, scaling_factors, n_batch, - output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, - accum_scratch_ptr, input_to_output_row_sums, compute_row_sums, - scaling_factors_scratch, context); + input_to_output_weights_scale, input_sf, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr, + input_to_output_row_sums, compute_row_sums, scaling_factors_scratch, + context); } // For each batch and cell: compute aux_input_weight * aux_input. 
@@ -936,15 +936,15 @@ inline void LstmStepHybrid( if (aux_input_ptr != nullptr && !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) { tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input, - quantized_aux_input_ptr, scaling_factors, - zero_points, asymmetric_quantize_inputs); + quantized_aux_input_ptr, aux_input_sf, + aux_input_zp, asymmetric_quantize_inputs); if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_input_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, aux_input_to_input_weights_scale, - scaling_factors, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_sf, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -952,24 +952,23 @@ inline void LstmStepHybrid( tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_forget_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, aux_input_to_forget_weights_scale, - scaling_factors, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_sf, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_cell_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, aux_input_to_cell_weights_scale, - scaling_factors, n_batch, cell_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, - aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, - context); + quantized_aux_input_ptr, aux_input_to_cell_weights_scale, aux_input_sf, + n_batch, cell_gate_scratch, /*per_channel_scale=*/nullptr, aux_input_zp, + accum_scratch_ptr, aux_input_to_cell_row_sums, compute_row_sums, + scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_output_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, aux_input_to_output_weights_scale, - scaling_factors, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_sf, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -978,14 +977,14 @@ inline void LstmStepHybrid( // Save quantization and matmul computation for all zero input. tensor_utils::BatchQuantizeFloats( output_state_ptr, n_batch, n_output, quantized_output_state_ptr, - scaling_factors, zero_points, asymmetric_quantize_inputs); + output_state_sf, output_state_zp, asymmetric_quantize_inputs); // For each batch and cell: compute recurrent_weight * output_state. 
if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_input_weights_scale, - scaling_factors, n_batch, input_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -993,24 +992,24 @@ inline void LstmStepHybrid( tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_forget_weights_scale, - scaling_factors, n_batch, forget_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_cell_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_cell_weights_scale, - scaling_factors, n_batch, cell_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, cell_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch, context); tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_output_weights_ptr, n_cell, n_output, quantized_output_state_ptr, recurrent_to_output_weights_scale, - scaling_factors, n_batch, output_gate_scratch, - /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + output_state_sf, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch, context); } @@ -1102,7 +1101,7 @@ inline void LstmStepHybrid( params->activation, projection_weights_ptr, projection_weights_scale, projection_bias_ptr, params->proj_clip, output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums, - context, scratch2, quantized_output_scratch, scaling_factors, zero_points, + context, scratch2, quantized_output_scratch, input_sf, input_zp, accum_scratch_ptr); // Copy output_state_ptr to the output. 
Note that the output batch rows may
@@ -1892,14 +1891,16 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
-    int output_offset, TfLiteTensor* scratch_buffer,
-    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
-    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
-    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
-    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-    TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
-    int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
+    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
+    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
+    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
+    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
+    bool* compute_row_sums, CpuBackendContext* context) {
   TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
   const int n_input = input->dims->data[input->dims->size - 1];
   int max_time, n_batch;
@@ -1939,10 +1940,14 @@ TfLiteStatus EvalHybrid(
   const int output_batch_leading_dim =
       output->dims->data[output->dims->size - 1];
 
-  int32_t* zero_points_ptr = nullptr;
+  int32_t* input_zp_ptr = nullptr;
+  int32_t* aux_input_zp_ptr = nullptr;
+  int32_t* output_state_zp_ptr = nullptr;
   int32_t* row_sums_ptr = nullptr;
   if (params->asymmetric_quantize_inputs) {
-    zero_points_ptr = GetTensorData<int32_t>(zero_points);
+    input_zp_ptr = GetTensorData<int32_t>(input_zp);
+    aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
+    output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
     row_sums_ptr = GetTensorData<int32_t>(row_sums);
   }
@@ -2005,7 +2010,9 @@ TfLiteStatus EvalHybrid(
           GetTensorData<float>(projection_bias), params, n_batch, n_cell,
           n_input, aux_input_size, n_output, output_batch_leading_dim,
           input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
-          output_gate_scratch, GetTensorData<float>(scaling_factors),
+          output_gate_scratch, GetTensorData<float>(input_sf),
+          GetTensorData<float>(aux_input_sf),
+          GetTensorData<float>(output_state_sf),
           GetTensorData<float>(prod_scaling_factors),
           GetTensorData<float>(recovered_cell_weights),
           GetTensorData<int8_t>(input_quantized),
@@ -2014,8 +2021,9 @@ TfLiteStatus EvalHybrid(
           GetTensorData<int8_t>(cell_state_quantized),
           GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
           GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
-          zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
-          params->asymmetric_quantize_inputs, context);
+          input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
+          row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
+          context);
     }
   } else {
     for (int b = 0; b < n_batch; b++) {
@@ -2092,7 +2100,9 @@ TfLiteStatus EvalHybrid(
           /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
           output_batch_leading_dim, input_gate_scratch_ptr,
           forget_gate_scratch_ptr, cell_gate_scratch_ptr,
-          output_gate_scratch_ptr, GetTensorData<float>(scaling_factors),
+          output_gate_scratch_ptr, GetTensorData<float>(input_sf),
+          GetTensorData<float>(aux_input_sf),
+          GetTensorData<float>(output_state_sf),
           GetTensorData<float>(prod_scaling_factors),
           GetTensorData<float>(recovered_cell_weights),
           GetTensorData<int8_t>(input_quantized),
@@ -2100,8 +2110,9 @@ TfLiteStatus EvalHybrid(
           GetTensorData<int8_t>(output_state_quantized),
           GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
           cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
-          output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size,
-          compute_row_sums, params->asymmetric_quantize_inputs, context);
+          output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
+          row_sums_ptr, row_sums_size, compute_row_sums,
+          params->asymmetric_quantize_inputs, context);
     }
   }
 }
diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h
index 9b3bd0c54ec..d3fdf037b5c 100644
--- a/tensorflow/lite/kernels/lstm_eval.h
+++ b/tensorflow/lite/kernels/lstm_eval.h
@@ -148,14 +148,16 @@ TfLiteStatus EvalHybrid(
     const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
     const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
     const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
-    int output_offset, TfLiteTensor* scratch_buffer,
-    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
-    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
-    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
-    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-    TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
-    int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
+    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
+    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
+    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
+    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
+    bool* compute_row_sums, CpuBackendContext* context);
 
 TfLiteStatus EvalInteger8x8_16(
     const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
diff --git a/tensorflow/lite/kernels/lstm_eval_test.cc b/tensorflow/lite/kernels/lstm_eval_test.cc
index 78459117859..459315bd1c9 100644
--- a/tensorflow/lite/kernels/lstm_eval_test.cc
+++ b/tensorflow/lite/kernels/lstm_eval_test.cc
@@ -654,15 +654,27 @@ class HybridLstmParam : public BaseLstmParam {
     scratch_buffer_tensor_.data.f = scratch_buffer_.data();
     return &scratch_buffer_tensor_;
   }
-  TfLiteTensor* GetScalingFactors() {
-    PackWeightToTensor(&scaling_factors_tensor_, scaling_factors_,
-                       scaling_factors_size_);
-    scaling_factors_tensor_.data.f = scaling_factors_.data();
-    return &scaling_factors_tensor_;
+  TfLiteTensor* GetInputScalingFactors() {
+    PackWeightToTensor(&input_sf_tensor_, input_sf_,
+                       quantization_extra_scratch_buffer_sizes_);
+    input_sf_tensor_.data.f = input_sf_.data();
+    return &input_sf_tensor_;
+  }
+  TfLiteTensor* GetAuxInputScalingFactors() {
+    PackWeightToTensor(&aux_input_sf_tensor_, aux_input_sf_,
+                       quantization_extra_scratch_buffer_sizes_);
+    aux_input_sf_tensor_.data.f = aux_input_sf_.data();
+    return &aux_input_sf_tensor_;
+  }
+  TfLiteTensor* GetOutputStateScalingFactors() {
+    PackWeightToTensor(&output_state_sf_tensor_, output_state_sf_,
+                       quantization_extra_scratch_buffer_sizes_);
+    output_state_sf_tensor_.data.f = output_state_sf_.data();
+    return &output_state_sf_tensor_;
   }
   TfLiteTensor* GetProdScalingFactors() {
     PackWeightToTensor(&prod_scaling_factors_tensor_, prod_scaling_factors_,
-                       prod_scaling_factors_size_);
+                       quantization_extra_scratch_buffer_sizes_);
     prod_scaling_factors_tensor_.data.f = prod_scaling_factors_.data();
     return &prod_scaling_factors_tensor_;
   }
@@ -682,10 +694,23 @@ class HybridLstmParam : public BaseLstmParam {
     cell_quantized_tensor_.data.int8 = cell_quantized_.data();
     return &cell_quantized_tensor_;
   }
-  TfLiteTensor* GetZeroPoints() {
-    PackWeightToTensor(&zero_points_tensor_, zero_points_, zero_points_size_);
-    zero_points_tensor_.data.i32 = zero_points_.data();
-    return &zero_points_tensor_;
+  TfLiteTensor* GetInputZeroPoints() {
+    PackWeightToTensor(&zero_points_tensor0_, input_zp_,
+                       quantization_extra_scratch_buffer_sizes_);
+    zero_points_tensor0_.data.i32 = input_zp_.data();
+    return &zero_points_tensor0_;
+  }
+  TfLiteTensor* GetAuxInputZeroPoints() {
+    PackWeightToTensor(&zero_points_tensor1_, aux_input_zp_,
+                       quantization_extra_scratch_buffer_sizes_);
+    zero_points_tensor1_.data.i32 = aux_input_zp_.data();
+    return &zero_points_tensor1_;
+  }
+  TfLiteTensor* GetOutputStateZeroPoints() {
+    PackWeightToTensor(&zero_points_tensor2_, output_state_zp_,
+                       quantization_extra_scratch_buffer_sizes_);
+    zero_points_tensor2_.data.i32 = output_state_zp_.data();
+    return &zero_points_tensor2_;
   }
   TfLiteTensor* GetRowSums() {
     PackWeightToTensor(&row_sums_tensor_, row_sums_, row_sums_size_);
@@ -776,12 +801,16 @@ class HybridLstmParam : public BaseLstmParam {
   ~HybridLstmParam() {
     TfLiteIntArrayFree(scratch_buffer_tensor_.dims);
     TfLiteIntArrayFree(accum_scratch_tensor_.dims);
-    TfLiteIntArrayFree(scaling_factors_tensor_.dims);
+    TfLiteIntArrayFree(input_sf_tensor_.dims);
+    TfLiteIntArrayFree(aux_input_sf_tensor_.dims);
+    TfLiteIntArrayFree(output_state_sf_tensor_.dims);
     TfLiteIntArrayFree(prod_scaling_factors_tensor_.dims);
     TfLiteIntArrayFree(input_quantized_tensor_.dims);
     TfLiteIntArrayFree(activation_quantized_tensor_.dims);
     TfLiteIntArrayFree(cell_quantized_tensor_.dims);
-    TfLiteIntArrayFree(zero_points_tensor_.dims);
+    TfLiteIntArrayFree(zero_points_tensor0_.dims);
+    TfLiteIntArrayFree(zero_points_tensor1_.dims);
+    TfLiteIntArrayFree(zero_points_tensor2_.dims);
     TfLiteIntArrayFree(row_sums_tensor_.dims);
   }
 
@@ -792,14 +821,24 @@ class HybridLstmParam : public BaseLstmParam {
   std::vector<int32_t> scratch_buffer_size_ = {n_batch_, n_cell_ * 4};
   TfLiteTensor scratch_buffer_tensor_;
 
-  std::vector<float> scaling_factors_;
-  std::vector<int32_t> scaling_factors_size_ = {n_batch_};
-  TfLiteTensor scaling_factors_tensor_;
+  std::vector<int32_t> quantization_extra_scratch_buffer_sizes_ = {n_batch_};
+  std::vector<float> input_sf_;
+  TfLiteTensor input_sf_tensor_;
+  std::vector<float> aux_input_sf_;
+  TfLiteTensor aux_input_sf_tensor_;
+  std::vector<float> output_state_sf_;
+  TfLiteTensor output_state_sf_tensor_;
 
   std::vector<float> prod_scaling_factors_;
-  std::vector<int32_t> prod_scaling_factors_size_ = {n_batch_};
   TfLiteTensor prod_scaling_factors_tensor_;
 
+  std::vector<int32_t> input_zp_;
+  TfLiteTensor zero_points_tensor0_;
+  std::vector<int32_t> aux_input_zp_;
+  TfLiteTensor zero_points_tensor1_;
+  std::vector<int32_t> output_state_zp_;
+  TfLiteTensor zero_points_tensor2_;
+
   std::vector<int8_t> input_quantized_;
   TfLiteTensor input_quantized_tensor_;
@@ -813,10 +852,6 @@ class HybridLstmParam : public BaseLstmParam {
       16, 4, 5, 6, 1, 1, 3, 4, -5, 6,
       1, 14, 5, 6, 1, 1, 3, 4, -5, 6,
   };
-  std::vector<int32_t> zero_points_;
-  std::vector<int32_t> zero_points_size_ = {n_batch_};
-  TfLiteTensor zero_points_tensor_;
-
   std::vector<int32_t> row_sums_;
   std::vector<int32_t> row_sums_size_ = {n_row_sums_, n_cell_};
   TfLiteTensor row_sums_tensor_;
@@ -896,13 +931,17 @@ void TestOneHybridAsymmLSTM() {
       /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
       one_parameter.GetScratchBuffer(),
-      one_parameter.GetScalingFactors(), one_parameter.GetProdScalingFactors(),
+      one_parameter.GetInputScalingFactors(),
+      one_parameter.GetAuxInputScalingFactors(),
+      one_parameter.GetOutputStateScalingFactors(),
+      one_parameter.GetProdScalingFactors(),
       /*recovered_cell_weights=*/nullptr, one_parameter.GetInputQuantized(),
       /*aux_input_quantized=*/nullptr,
       one_parameter.GetActivationStateQuantized(),
       one_parameter.GetCellStateQuantized(), activation, cell,
       one_parameter.GetAccumScratchBuffer(), output,
-      one_parameter.GetZeroPoints(), one_parameter.GetRowSums(),
+      one_parameter.GetInputZeroPoints(), one_parameter.GetAuxInputZeroPoints(),
+      one_parameter.GetOutputStateZeroPoints(), one_parameter.GetRowSums(),
       one_parameter.GetNumRowSums(), &compute_row_sums, &context);
   const std::vector<float> expected_cell = {
       7.83134, 1.96158, 2.18285, 3.28739, 0.483214,
diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
index 026b2452aef..0849c6dc0e4 100644
--- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc
@@ -45,13 +45,15 @@ enum TemporaryTensor {
   kInputQuantized = 1,
   kOutputStateQuantized = 2,
   kCellStateQuantized = 3,
-  kScalingFactors = 4,
-  kProductScalingFactors = 5,
-  kRecoveredCellWeights = 6,
-  kAccumScratch = 7,
-  kZeroPoints = 8,
-  kRowSums = 9,
-  kNumTemporaryTensors = 10
+  kInputScalingFactors = 4,
+  kOutputStateScalingFactors = 5,
+  kProductScalingFactors = 6,
+  kRecoveredCellWeights = 7,
+  kAccumScratch = 8,
+  kInputZeroPoints = 9,
+  kOutputStateZeroPoints = 10,
+  kRowSums = 11,
+  kNumTemporaryTensors = 12,
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -416,18 +418,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // a vector once (which produces the scaling factors) and multiply it with
   // different matrices (which requires multiplying the scaling factors with
   // the scaling factor of the matrix).
- node->temporaries->data[kScalingFactors] = - scratch_tensor_index + kScalingFactors; - TfLiteTensor* scaling_factors = - GetTemporary(context, node, kScalingFactors); - scaling_factors->type = kTfLiteFloat32; - scaling_factors->allocation_type = kTfLiteArenaRw; + node->temporaries->data[kInputScalingFactors] = + op_data->scratch_tensor_index + kInputScalingFactors; + TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors); + input_sf->type = kTfLiteFloat32; + input_sf->allocation_type = kTfLiteArenaRw; int scaling_dims[1] = {n_batch}; - if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { - TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); - scaling_factors_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, - scaling_factors_size)); + if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); + input_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_sf, input_sf_size)); + } + node->temporaries->data[kOutputStateScalingFactors] = + op_data->scratch_tensor_index + kOutputStateScalingFactors; + TfLiteTensor* output_state_sf = + GetTemporary(context, node, kOutputStateScalingFactors); + output_state_sf->type = kTfLiteFloat32; + output_state_sf->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); + output_state_sf_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, + output_state_sf_size)); } node->temporaries->data[kProductScalingFactors] = scratch_tensor_index + kProductScalingFactors; @@ -477,15 +490,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } - node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints; - TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); - zero_points->type = kTfLiteFloat32; - zero_points->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { - TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); - zero_points_size->data[0] = n_batch; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, - zero_points_size)); + node->temporaries->data[kInputZeroPoints] = + op_data->scratch_tensor_index + kInputZeroPoints; + TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints); + input_zp->type = kTfLiteFloat32; + input_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); + input_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, input_zp, input_zp_size)); + } + node->temporaries->data[kOutputStateZeroPoints] = + op_data->scratch_tensor_index + kOutputStateZeroPoints; + TfLiteTensor* output_state_zp = + GetTemporary(context, node, kOutputStateZeroPoints); + output_state_zp->type = kTfLiteFloat32; + output_state_zp->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { + TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); + output_state_zp_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, + 
output_state_zp_size)); } node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums; TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); @@ -640,7 +666,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { projection_weights, projection_bias, &lstm_params, /*forward_sequence=*/true, time_major, /*output_offset=*/0, scratch_buffer, - GetTemporary(context, node, kScalingFactors), + GetTemporary(context, node, kInputScalingFactors), + /*aux_input_sf=*/nullptr, + GetTemporary(context, node, kOutputStateScalingFactors), GetTemporary(context, node, kProductScalingFactors), GetTemporary(context, node, kRecoveredCellWeights), GetTemporary(context, node, kInputQuantized), @@ -648,8 +676,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kOutputStateQuantized), GetTemporary(context, node, kCellStateQuantized), output_state, cell_state, GetTemporary(context, node, kAccumScratch), output, - GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size, - &op_data->compute_row_sums, + GetTemporary(context, node, kInputZeroPoints), + /*aux_input_zp=*/nullptr, + GetTemporary(context, node, kOutputStateZeroPoints), row_sums, + row_sums_size, &op_data->compute_row_sums, CpuBackendContext::GetFromContext(context)); } default:
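
Note on the rationale (not part of the patch): the commit message says that giving each quantized operand (input, auxiliary input, output state) its own scaling-factor and zero-point scratch lets the hybrid kernel quantize each operand once and then accumulate the gate matmuls in any order, instead of interleaving them around a single shared scratch that each new quantization would overwrite. The C++ sketch below illustrates that idea under simplified assumptions: per-row symmetric quantization and plain loops. The helper names QuantizeRows and MatVecAccum are hypothetical and only stand in for the tensor_utils::BatchQuantizeFloats and tensor_utils::MatrixBatchVectorMultiplyAccumulate calls visible in the diff above.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical per-batch symmetric quantization: each row gets its own scale.
void QuantizeRows(const std::vector<float>& src, int rows, int cols,
                  std::vector<int8_t>* dst, std::vector<float>* scales) {
  dst->resize(rows * cols);
  scales->resize(rows);
  for (int r = 0; r < rows; ++r) {
    float max_abs = 1e-6f;
    for (int c = 0; c < cols; ++c)
      max_abs = std::fmax(max_abs, std::fabs(src[r * cols + c]));
    const float scale = max_abs / 127.0f;
    (*scales)[r] = scale;
    for (int c = 0; c < cols; ++c)
      (*dst)[r * cols + c] =
          static_cast<int8_t>(std::lround(src[r * cols + c] / scale));
  }
}

// Accumulate (weights x quantized row) into a float gate buffer, rescaling by
// the weight scale and the per-row input scale.
void MatVecAccum(const std::vector<int8_t>& weights, float weight_scale,
                 int n_out, int n_in, const std::vector<int8_t>& rows,
                 const std::vector<float>& row_scales, int n_batch,
                 std::vector<float>* gate) {
  for (int b = 0; b < n_batch; ++b)
    for (int o = 0; o < n_out; ++o) {
      int32_t acc = 0;
      for (int i = 0; i < n_in; ++i)
        acc += weights[o * n_in + i] * rows[b * n_in + i];
      (*gate)[b * n_out + o] += acc * weight_scale * row_scales[b];
    }
}

int main() {
  const int n_batch = 2, n_input = 3, n_output = 3, n_cell = 2;
  std::vector<float> input = {0.5f, -1.0f, 0.25f, 2.0f, 0.1f, -0.4f};
  std::vector<float> output_state = {0.9f, -0.2f, 0.3f, -0.7f, 0.8f, 0.05f};
  std::vector<int8_t> input_weights = {10, -3, 7, 2, 5, -8};      // n_cell x n_input
  std::vector<int8_t> recurrent_weights = {4, 6, -2, -9, 1, 3};   // n_cell x n_output

  // Separate scratch per operand: quantizing the recurrent state no longer
  // clobbers the input's scales, so the two contributions can be computed
  // independently and added in either order.
  std::vector<int8_t> input_q, output_state_q;
  std::vector<float> input_sf, output_state_sf;
  QuantizeRows(input, n_batch, n_input, &input_q, &input_sf);
  QuantizeRows(output_state, n_batch, n_output, &output_state_q,
               &output_state_sf);

  std::vector<float> forget_gate(n_batch * n_cell, 0.0f);
  MatVecAccum(input_weights, 0.01f, n_cell, n_input, input_q, input_sf,
              n_batch, &forget_gate);
  MatVecAccum(recurrent_weights, 0.02f, n_cell, n_output, output_state_q,
              output_state_sf, n_batch, &forget_gate);
  for (float v : forget_gate) std::printf("%f\n", v);
  return 0;
}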