Add additional scratch buffers for Hybrid LSTM, storing temporary quantization data (scaling factors and zero points) separately for the different inputs.

This allows the LSTM gate contributions from the different inputs to be processed independently, instead of strictly sequentially, as sketched below.

PiperOrigin-RevId: 318563361
Change-Id: I27b4d14ec9e93083a5ad48729260d7b4a1d43cde
Robert David 2020-06-26 16:22:10 -07:00 committed by TensorFlower Gardener
parent 4d7d1a8c34
commit 61a6e22f5f
6 changed files with 369 additions and 202 deletions
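
Before this change, a single scaling-factor/zero-point scratch buffer was shared by the input, auxiliary-input, and output-state quantization passes, so each pass had to finish before the next could overwrite the buffer. The sketch below illustrates the idea of giving each stream its own scratch; it is a minimal standalone example with illustrative names (QuantScratch, QuantizeBatch), not the actual TFLite helpers used in the diff.

#include <cstdint>
#include <vector>

// Per-stream quantization scratch: one scaling factor and zero point per
// batch row, mirroring the new kInput*/kAuxInput*/kOutputState* temporaries.
struct QuantScratch {
  std::vector<float> scaling_factors;
  std::vector<int32_t> zero_points;
};

// Stand-in for a per-batch asymmetric quantization routine (the real kernel
// uses tensor_utils::BatchQuantizeFloats); the body is only a placeholder.
void QuantizeBatch(const float* src, int n_batch, int n_cols,
                   std::vector<int8_t>* dst, QuantScratch* scratch) {
  dst->assign(static_cast<size_t>(n_batch) * n_cols, 0);
  scratch->scaling_factors.assign(n_batch, 1.0f);
  scratch->zero_points.assign(n_batch, 0);
  (void)src;  // real code would compute per-row min/max, scale and zero point
}

int main() {
  const int n_batch = 2, n_input = 4, n_aux_input = 3, n_output = 5;
  std::vector<float> input(n_batch * n_input, 0.5f);
  std::vector<float> aux_input(n_batch * n_aux_input, -0.25f);
  std::vector<float> output_state(n_batch * n_output, 0.125f);

  // With separate scratch buffers, the three quantization passes no longer
  // clobber each other's scaling factors / zero points, so the gate matmuls
  // that consume them can be ordered or dispatched independently.
  QuantScratch input_scratch, aux_input_scratch, output_state_scratch;
  std::vector<int8_t> q_input, q_aux_input, q_output_state;
  QuantizeBatch(input.data(), n_batch, n_input, &q_input, &input_scratch);
  QuantizeBatch(aux_input.data(), n_batch, n_aux_input, &q_aux_input,
                &aux_input_scratch);
  QuantizeBatch(output_state.data(), n_batch, n_output, &q_output_state,
                &output_state_scratch);
  return 0;
}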

@ -138,15 +138,19 @@ enum TemporaryTensor {
kBwActivationStateQuantized = 4,
kFwCellStateQuantized = 5,
kBwCellStateQuantized = 6,
kScalingFactors = 7,
kProductScalingFactors = 8,
kRecoveredCellWeights = 9,
kAccumScratchBuffer = 10,
kZeroPoints = 11,
kFwRowSums = 12,
kBwRowSums = 13,
kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors = 15
kInputScalingFactors = 7,
kAuxInputScalingFactors = 8,
kOutputStateScalingFactors = 9,
kProductScalingFactors = 10,
kRecoveredCellWeights = 11,
kAccumScratchBuffer = 12,
kInputZeroPoints = 13,
kAuxInputZeroPoints = 14,
kOutputStateZeroPoints = 15,
kFwRowSums = 16,
kBwRowSums = 17,
kAuxInputQuantized = 18, // Optional, quantized tensor for auxiliary input.
kNumTemporaryTensors = 19,
};
struct OpData {
@ -699,18 +703,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// a vector once (which produces the scaling factors) and multiply it with
// different matrices (which requires multiplying the scaling factors with
// the scaling factor of the matrix).
node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
scaling_factors_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
input_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, input_sf, input_sf_size));
}
node->temporaries->data[kAuxInputScalingFactors] =
op_data->scratch_tensor_index + kAuxInputScalingFactors;
TfLiteTensor* aux_input_sf =
GetTemporary(context, node, kAuxInputScalingFactors);
aux_input_sf->type = kTfLiteFloat32;
aux_input_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1);
aux_input_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf,
aux_input_sf_size));
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
output_state_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
output_state_sf_size));
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
@ -768,16 +795,40 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Allocate temporary tensors for storing zero-points.
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
input_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, input_zp, input_zp_size));
}
node->temporaries->data[kAuxInputZeroPoints] =
op_data->scratch_tensor_index + kAuxInputZeroPoints;
TfLiteTensor* aux_input_zp =
GetTemporary(context, node, kAuxInputZeroPoints);
aux_input_zp->type = kTfLiteFloat32;
aux_input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1);
aux_input_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp,
aux_input_zp_size));
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
output_state_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
output_state_zp_size));
}
// Allocate temporary tensors for caching row sums for hybrid zero-point
@ -1071,8 +1122,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, kFwCellStateQuantized);
TfLiteTensor* bw_cell_state_quantized =
GetTemporary(context, node, kBwCellStateQuantized);
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
TfLiteTensor* prod_scaling_factors =
GetTemporary(context, node, kProductScalingFactors);
TfLiteTensor* recovered_cell_weights =
@ -1082,7 +1131,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
: nullptr;
TfLiteTensor* accum_scratch =
GetTemporary(context, node, kAccumScratchBuffer);
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
const int fw_row_sums_size = fw_row_sums->dims->data[0];
@ -1104,12 +1152,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
&lstm_params,
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
fw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
fw_activation_state_quantized, fw_cell_state_quantized,
fw_activation_state, fw_cell_state, accum_scratch, fw_output,
zero_points, fw_row_sums, fw_row_sums_size,
&op_data->compute_fw_row_sums,
fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
GetTemporary(context, node, kAuxInputScalingFactors),
GetTemporary(context, node, kOutputStateScalingFactors),
prod_scaling_factors, recovered_cell_weights, input_quantized,
aux_input_quantized, fw_activation_state_quantized,
fw_cell_state_quantized, fw_activation_state, fw_cell_state,
accum_scratch, fw_output,
GetTemporary(context, node, kInputZeroPoints),
GetTemporary(context, node, kAuxInputZeroPoints),
GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums,
fw_row_sums_size, &op_data->compute_fw_row_sums,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, fw_pass_status);
@ -1130,12 +1183,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
&lstm_params,
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
bw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
bw_activation_state_quantized, bw_cell_state_quantized,
bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
zero_points, bw_row_sums, bw_row_sums_size,
&op_data->compute_bw_row_sums,
bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
GetTemporary(context, node, kAuxInputScalingFactors),
GetTemporary(context, node, kOutputStateScalingFactors),
prod_scaling_factors, recovered_cell_weights, input_quantized,
aux_input_quantized, bw_activation_state_quantized,
bw_cell_state_quantized, bw_activation_state, bw_cell_state,
accum_scratch, actual_bw_output,
GetTemporary(context, node, kInputZeroPoints),
GetTemporary(context, node, kAuxInputZeroPoints),
GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums,
bw_row_sums_size, &op_data->compute_bw_row_sums,
CpuBackendContext::GetFromContext(context));
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;

@ -66,13 +66,15 @@ enum HybridTemporaryTensor {
kInputQuantized = 1,
kOutputStateQuantized = 2,
kCellStateQuantized = 3,
kScalingFactors = 4,
kProductScalingFactors = 5,
kRecoveredCellWeights = 6,
kAccumScratch = 7,
kZeroPoints = 8,
kRowSums = 9,
kNumHybridTemporaryTensors = 10,
kInputScalingFactors = 4,
kOutputStateScalingFactors = 5,
kProductScalingFactors = 6,
kRecoveredCellWeights = 7,
kAccumScratch = 8,
kInputZeroPoints = 9,
kOutputStateZeroPoints = 10,
kRowSums = 11,
kNumHybridTemporaryTensors = 12,
};
TfLiteStatus PopulateQuantizedLstmParams8x8_16(
@ -1333,18 +1335,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// a vector once (which produces the scaling factors) and multiply it with
// different matrices (which requires multiplying the scaling factors with
// the scaling factor of the matrix).
node->temporaries->data[kScalingFactors] =
op_data->scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
scaling_factors_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
input_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, input_sf, input_sf_size));
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
output_state_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
output_state_sf_size));
}
node->temporaries->data[kProductScalingFactors] =
op_data->scratch_tensor_index + kProductScalingFactors;
@ -1394,18 +1407,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
node->temporaries->data[kZeroPoints] =
op_data->scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
int zero_points_dims[1] = {n_batch};
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
input_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, input_zp, input_zp_size));
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
output_state_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
output_state_zp_size));
}
node->temporaries->data[kRowSums] =
@ -1621,7 +1644,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
projection_weights, projection_bias, params,
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
GetTemporary(context, node, kScratchBuffer),
GetTemporary(context, node, kScalingFactors),
GetTemporary(context, node, kInputScalingFactors),
/*aux_input_sf=*/nullptr,
GetTemporary(context, node, kOutputStateScalingFactors),
GetTemporary(context, node, kProductScalingFactors),
GetTemporary(context, node, kRecoveredCellWeights),
GetTemporary(context, node, kInputQuantized),
@ -1629,8 +1654,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, kOutputStateQuantized),
GetTemporary(context, node, kCellStateQuantized), output_state,
cell_state, GetTemporary(context, node, kAccumScratch), output,
GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size,
&op_data->compute_row_sums,
GetTemporary(context, node, kInputZeroPoints),
/*aux_input_zp=*/nullptr,
GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
row_sums_size, &op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
} else {
const int num_intermediate_tensors = node->intermediates->size;

@ -785,14 +785,15 @@ inline void LstmStepHybrid(
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* scratch0, float* scratch1,
float* scratch2, float* scratch3, float* scaling_factors,
float* scaling_factors_scratch, float* recovered_cell_weights,
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
float* output_ptr, int32_t* zero_points, int32_t* row_sums,
int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs,
CpuBackendContext* context) {
float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf,
float* output_state_sf, float* scaling_factors_scratch,
float* recovered_cell_weights, int8_t* quantized_input_ptr,
int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
int8_t* quantized_output_scratch, float* output_state_ptr,
float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp,
int32_t* row_sums, int row_sums_size, bool* compute_row_sums,
bool asymmetric_quantize_inputs, CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to the get the condition.
@ -897,38 +898,37 @@ inline void LstmStepHybrid(
if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input,
quantized_input_ptr, scaling_factors,
zero_points, asymmetric_quantize_inputs);
quantized_input_ptr, input_sf, input_zp,
asymmetric_quantize_inputs);
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
input_to_input_weights_scale, scaling_factors, n_batch,
input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
accum_scratch_ptr, input_to_input_row_sums, compute_row_sums,
scaling_factors_scratch, context);
input_to_input_weights_scale, input_sf, n_batch, input_gate_scratch,
/*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
input_to_input_row_sums, compute_row_sums, scaling_factors_scratch,
context);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
input_to_forget_weights_scale, scaling_factors, n_batch,
forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums,
scaling_factors_scratch, context);
input_to_forget_weights_scale, input_sf, n_batch, forget_gate_scratch,
/*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
input_to_cell_weights_scale, scaling_factors, n_batch,
cell_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
input_to_cell_weights_scale, input_sf, n_batch, cell_gate_scratch,
/*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
input_to_output_weights_scale, scaling_factors, n_batch,
output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
accum_scratch_ptr, input_to_output_row_sums, compute_row_sums,
scaling_factors_scratch, context);
input_to_output_weights_scale, input_sf, n_batch, output_gate_scratch,
/*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
input_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
context);
}
// For each batch and cell: compute aux_input_weight * aux_input.
@ -936,15 +936,15 @@ inline void LstmStepHybrid(
if (aux_input_ptr != nullptr &&
!tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input,
quantized_aux_input_ptr, scaling_factors,
zero_points, asymmetric_quantize_inputs);
quantized_aux_input_ptr, aux_input_sf,
aux_input_zp, asymmetric_quantize_inputs);
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, aux_input_to_input_weights_scale,
scaling_factors, n_batch, input_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
aux_input_sf, n_batch, input_gate_scratch,
/*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr,
aux_input_to_input_row_sums, compute_row_sums,
scaling_factors_scratch, context);
}
@ -952,24 +952,23 @@ inline void LstmStepHybrid(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, aux_input_to_forget_weights_scale,
scaling_factors, n_batch, forget_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
aux_input_sf, n_batch, forget_gate_scratch,
/*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr,
aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
scaling_factors, n_batch, cell_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
context);
quantized_aux_input_ptr, aux_input_to_cell_weights_scale, aux_input_sf,
n_batch, cell_gate_scratch, /*per_channel_scale=*/nullptr, aux_input_zp,
accum_scratch_ptr, aux_input_to_cell_row_sums, compute_row_sums,
scaling_factors_scratch, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, aux_input_to_output_weights_scale,
scaling_factors, n_batch, output_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
aux_input_sf, n_batch, output_gate_scratch,
/*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr,
aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
context);
}
@ -978,14 +977,14 @@ inline void LstmStepHybrid(
// Save quantization and matmul computation for all zero input.
tensor_utils::BatchQuantizeFloats(
output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
scaling_factors, zero_points, asymmetric_quantize_inputs);
output_state_sf, output_state_zp, asymmetric_quantize_inputs);
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, recurrent_to_input_weights_scale,
scaling_factors, n_batch, input_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
output_state_sf, n_batch, input_gate_scratch,
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
recurrent_to_input_row_sums, compute_row_sums,
scaling_factors_scratch, context);
}
@ -993,24 +992,24 @@ inline void LstmStepHybrid(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, recurrent_to_forget_weights_scale,
scaling_factors, n_batch, forget_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
output_state_sf, n_batch, forget_gate_scratch,
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, recurrent_to_cell_weights_scale,
scaling_factors, n_batch, cell_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
output_state_sf, n_batch, cell_gate_scratch,
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, recurrent_to_output_weights_scale,
scaling_factors, n_batch, output_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
output_state_sf, n_batch, output_gate_scratch,
/*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
context);
}
@ -1102,7 +1101,7 @@ inline void LstmStepHybrid(
params->activation, projection_weights_ptr, projection_weights_scale,
projection_bias_ptr, params->proj_clip, output_state_ptr,
asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums,
context, scratch2, quantized_output_scratch, scaling_factors, zero_points,
context, scratch2, quantized_output_scratch, input_sf, input_zp,
accum_scratch_ptr);
// Copy output_state_ptr to the output. Note that the output batch rows may
@ -1892,14 +1891,16 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
TfLiteTensor* output_state, TfLiteTensor* cell_state,
TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
bool* compute_row_sums, CpuBackendContext* context) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;
@ -1939,10 +1940,14 @@ TfLiteStatus EvalHybrid(
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
int32_t* zero_points_ptr = nullptr;
int32_t* input_zp_ptr = nullptr;
int32_t* aux_input_zp_ptr = nullptr;
int32_t* output_state_zp_ptr = nullptr;
int32_t* row_sums_ptr = nullptr;
if (params->asymmetric_quantize_inputs) {
zero_points_ptr = GetTensorData<int32_t>(zero_points);
input_zp_ptr = GetTensorData<int32_t>(input_zp);
aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
row_sums_ptr = GetTensorData<int32_t>(row_sums);
}
@ -2005,7 +2010,9 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
output_gate_scratch, GetTensorData<float>(scaling_factors),
output_gate_scratch, GetTensorData<float>(input_sf),
GetTensorData<float>(aux_input_sf),
GetTensorData<float>(output_state_sf),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
@ -2014,8 +2021,9 @@ TfLiteStatus EvalHybrid(
GetTensorData<int8_t>(cell_state_quantized),
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
params->asymmetric_quantize_inputs, context);
input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
context);
}
} else {
for (int b = 0; b < n_batch; b++) {
@ -2092,7 +2100,9 @@ TfLiteStatus EvalHybrid(
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_gate_scratch_ptr,
output_gate_scratch_ptr, GetTensorData<float>(scaling_factors),
output_gate_scratch_ptr, GetTensorData<float>(input_sf),
GetTensorData<float>(aux_input_sf),
GetTensorData<float>(output_state_sf),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
@ -2100,8 +2110,9 @@ TfLiteStatus EvalHybrid(
GetTensorData<int8_t>(output_state_quantized),
GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size,
compute_row_sums, params->asymmetric_quantize_inputs, context);
output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
row_sums_ptr, row_sums_size, compute_row_sums,
params->asymmetric_quantize_inputs, context);
}
}
}

@ -148,14 +148,16 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
TfLiteTensor* output_state, TfLiteTensor* cell_state,
TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
bool* compute_row_sums, CpuBackendContext* context);
TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,

@ -654,15 +654,27 @@ class HybridLstmParam : public BaseLstmParam {
scratch_buffer_tensor_.data.f = scratch_buffer_.data();
return &scratch_buffer_tensor_;
}
TfLiteTensor* GetScalingFactors() {
PackWeightToTensor(&scaling_factors_tensor_, scaling_factors_,
scaling_factors_size_);
scaling_factors_tensor_.data.f = scaling_factors_.data();
return &scaling_factors_tensor_;
TfLiteTensor* GetInputScalingFactors() {
PackWeightToTensor(&input_sf_tensor_, input_sf_,
quantization_extra_scratch_buffer_sizes_);
input_sf_tensor_.data.f = input_sf_.data();
return &input_sf_tensor_;
}
TfLiteTensor* GetAuxInputScalingFactors() {
PackWeightToTensor(&aux_input_sf_tensor_, aux_input_sf_,
quantization_extra_scratch_buffer_sizes_);
aux_input_sf_tensor_.data.f = aux_input_sf_.data();
return &aux_input_sf_tensor_;
}
TfLiteTensor* GetOutputStateScalingFactors() {
PackWeightToTensor(&output_state_sf_tensor_, output_state_sf_,
quantization_extra_scratch_buffer_sizes_);
output_state_sf_tensor_.data.f = output_state_sf_.data();
return &output_state_sf_tensor_;
}
TfLiteTensor* GetProdScalingFactors() {
PackWeightToTensor(&prod_scaling_factors_tensor_, prod_scaling_factors_,
prod_scaling_factors_size_);
quantization_extra_scratch_buffer_sizes_);
prod_scaling_factors_tensor_.data.f = prod_scaling_factors_.data();
return &prod_scaling_factors_tensor_;
}
@ -682,10 +694,23 @@ class HybridLstmParam : public BaseLstmParam {
cell_quantized_tensor_.data.int8 = cell_quantized_.data();
return &cell_quantized_tensor_;
}
TfLiteTensor* GetZeroPoints() {
PackWeightToTensor(&zero_points_tensor_, zero_points_, zero_points_size_);
zero_points_tensor_.data.i32 = zero_points_.data();
return &zero_points_tensor_;
TfLiteTensor* GetInputZeroPoints() {
PackWeightToTensor(&zero_points_tensor0_, input_zp_,
quantization_extra_scratch_buffer_sizes_);
zero_points_tensor0_.data.i32 = input_zp_.data();
return &zero_points_tensor0_;
}
TfLiteTensor* GetAuxInputZeroPoints() {
PackWeightToTensor(&zero_points_tensor1_, aux_input_zp_,
quantization_extra_scratch_buffer_sizes_);
zero_points_tensor1_.data.i32 = aux_input_zp_.data();
return &zero_points_tensor1_;
}
TfLiteTensor* GetOutputStateZeroPoints() {
PackWeightToTensor(&zero_points_tensor2_, output_state_zp_,
quantization_extra_scratch_buffer_sizes_);
zero_points_tensor2_.data.i32 = output_state_zp_.data();
return &zero_points_tensor2_;
}
TfLiteTensor* GetRowSums() {
PackWeightToTensor(&row_sums_tensor_, row_sums_, row_sums_size_);
@ -776,12 +801,16 @@ class HybridLstmParam : public BaseLstmParam {
~HybridLstmParam() {
TfLiteIntArrayFree(scratch_buffer_tensor_.dims);
TfLiteIntArrayFree(accum_scratch_tensor_.dims);
TfLiteIntArrayFree(scaling_factors_tensor_.dims);
TfLiteIntArrayFree(input_sf_tensor_.dims);
TfLiteIntArrayFree(aux_input_sf_tensor_.dims);
TfLiteIntArrayFree(output_state_sf_tensor_.dims);
TfLiteIntArrayFree(prod_scaling_factors_tensor_.dims);
TfLiteIntArrayFree(input_quantized_tensor_.dims);
TfLiteIntArrayFree(activation_quantized_tensor_.dims);
TfLiteIntArrayFree(cell_quantized_tensor_.dims);
TfLiteIntArrayFree(zero_points_tensor_.dims);
TfLiteIntArrayFree(zero_points_tensor0_.dims);
TfLiteIntArrayFree(zero_points_tensor1_.dims);
TfLiteIntArrayFree(zero_points_tensor2_.dims);
TfLiteIntArrayFree(row_sums_tensor_.dims);
}
@ -792,14 +821,24 @@ class HybridLstmParam : public BaseLstmParam {
std::vector<int32_t> scratch_buffer_size_ = {n_batch_, n_cell_ * 4};
TfLiteTensor scratch_buffer_tensor_;
std::vector<float> scaling_factors_;
std::vector<int32_t> scaling_factors_size_ = {n_batch_};
TfLiteTensor scaling_factors_tensor_;
std::vector<int32_t> quantization_extra_scratch_buffer_sizes_ = {n_batch_};
std::vector<float> input_sf_;
TfLiteTensor input_sf_tensor_;
std::vector<float> aux_input_sf_;
TfLiteTensor aux_input_sf_tensor_;
std::vector<float> output_state_sf_;
TfLiteTensor output_state_sf_tensor_;
std::vector<float> prod_scaling_factors_;
std::vector<int32_t> prod_scaling_factors_size_ = {n_batch_};
TfLiteTensor prod_scaling_factors_tensor_;
std::vector<int32_t> input_zp_;
TfLiteTensor zero_points_tensor0_;
std::vector<int32_t> aux_input_zp_;
TfLiteTensor zero_points_tensor1_;
std::vector<int32_t> output_state_zp_;
TfLiteTensor zero_points_tensor2_;
std::vector<int8_t> input_quantized_;
TfLiteTensor input_quantized_tensor_;
@ -813,10 +852,6 @@ class HybridLstmParam : public BaseLstmParam {
16, 4, 5, 6, 1, 1, 3, 4, -5, 6, 1, 14, 5, 6, 1, 1, 3, 4, -5, 6,
};
std::vector<int32_t> zero_points_;
std::vector<int32_t> zero_points_size_ = {n_batch_};
TfLiteTensor zero_points_tensor_;
std::vector<int32_t> row_sums_;
std::vector<int32_t> row_sums_size_ = {n_row_sums_, n_cell_};
TfLiteTensor row_sums_tensor_;
@ -896,13 +931,17 @@ void TestOneHybridAsymmLSTM() {
/*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, one_parameter.GetScratchBuffer(),
one_parameter.GetScalingFactors(), one_parameter.GetProdScalingFactors(),
one_parameter.GetInputScalingFactors(),
one_parameter.GetAuxInputScalingFactors(),
one_parameter.GetOutputStateScalingFactors(),
one_parameter.GetProdScalingFactors(),
/*recovered_cell_weights=*/nullptr, one_parameter.GetInputQuantized(),
/*aux_input_quantized=*/nullptr,
one_parameter.GetActivationStateQuantized(),
one_parameter.GetCellStateQuantized(), activation, cell,
one_parameter.GetAccumScratchBuffer(), output,
one_parameter.GetZeroPoints(), one_parameter.GetRowSums(),
one_parameter.GetInputZeroPoints(), one_parameter.GetAuxInputZeroPoints(),
one_parameter.GetOutputStateZeroPoints(), one_parameter.GetRowSums(),
one_parameter.GetNumRowSums(), &compute_row_sums, &context);
const std::vector<float> expected_cell = {
7.83134, 1.96158, 2.18285, 3.28739, 0.483214,

@ -45,13 +45,15 @@ enum TemporaryTensor {
kInputQuantized = 1,
kOutputStateQuantized = 2,
kCellStateQuantized = 3,
kScalingFactors = 4,
kProductScalingFactors = 5,
kRecoveredCellWeights = 6,
kAccumScratch = 7,
kZeroPoints = 8,
kRowSums = 9,
kNumTemporaryTensors = 10
kInputScalingFactors = 4,
kOutputStateScalingFactors = 5,
kProductScalingFactors = 6,
kRecoveredCellWeights = 7,
kAccumScratch = 8,
kInputZeroPoints = 9,
kOutputStateZeroPoints = 10,
kRowSums = 11,
kNumTemporaryTensors = 12,
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@ -416,18 +418,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// a vector once (which produces the scaling factors) and multiply it with
// different matrices (which requires multiplying the scaling factors with
// the scaling factor of the matrix).
node->temporaries->data[kScalingFactors] =
scratch_tensor_index + kScalingFactors;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, kScalingFactors);
scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
node->temporaries->data[kInputScalingFactors] =
op_data->scratch_tensor_index + kInputScalingFactors;
TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
input_sf->type = kTfLiteFloat32;
input_sf->allocation_type = kTfLiteArenaRw;
int scaling_dims[1] = {n_batch};
if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
scaling_factors_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
input_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, input_sf, input_sf_size));
}
node->temporaries->data[kOutputStateScalingFactors] =
op_data->scratch_tensor_index + kOutputStateScalingFactors;
TfLiteTensor* output_state_sf =
GetTemporary(context, node, kOutputStateScalingFactors);
output_state_sf->type = kTfLiteFloat32;
output_state_sf->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
output_state_sf_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
output_state_sf_size));
}
node->temporaries->data[kProductScalingFactors] =
scratch_tensor_index + kProductScalingFactors;
@ -477,15 +490,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, accum_scratch, accum_size));
}
node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints;
TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
zero_points->type = kTfLiteFloat32;
zero_points->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
zero_points_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
zero_points_size));
node->temporaries->data[kInputZeroPoints] =
op_data->scratch_tensor_index + kInputZeroPoints;
TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
input_zp->type = kTfLiteFloat32;
input_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
input_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, input_zp, input_zp_size));
}
node->temporaries->data[kOutputStateZeroPoints] =
op_data->scratch_tensor_index + kOutputStateZeroPoints;
TfLiteTensor* output_state_zp =
GetTemporary(context, node, kOutputStateZeroPoints);
output_state_zp->type = kTfLiteFloat32;
output_state_zp->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
output_state_zp_size->data[0] = n_batch;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
output_state_zp_size));
}
node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
@ -640,7 +666,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
projection_weights, projection_bias, &lstm_params,
/*forward_sequence=*/true, time_major,
/*output_offset=*/0, scratch_buffer,
GetTemporary(context, node, kScalingFactors),
GetTemporary(context, node, kInputScalingFactors),
/*aux_input_sf=*/nullptr,
GetTemporary(context, node, kOutputStateScalingFactors),
GetTemporary(context, node, kProductScalingFactors),
GetTemporary(context, node, kRecoveredCellWeights),
GetTemporary(context, node, kInputQuantized),
@ -648,8 +676,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, kOutputStateQuantized),
GetTemporary(context, node, kCellStateQuantized), output_state,
cell_state, GetTemporary(context, node, kAccumScratch), output,
GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size,
&op_data->compute_row_sums,
GetTemporary(context, node, kInputZeroPoints),
/*aux_input_zp=*/nullptr,
GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
row_sums_size, &op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
}
default: