Add additional scratch buffers for Hybrid LSTM, storing the quantization temporaries (scaling factors and zero points) separately for each input.
This allows the LSTM gates to be processed independently for the different inputs, instead of strictly sequentially.

PiperOrigin-RevId: 318563361
Change-Id: I27b4d14ec9e93083a5ad48729260d7b4a1d43cde
Parent: 4d7d1a8c34    Commit: 61a6e22f5f
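The change in one picture: the hybrid LSTM kernels used to share a single scaling-factor buffer and a single zero-point buffer across the input, the auxiliary input, and the recurrent output state, which forced the quantize-and-multiply work for the gates to be ordered around those shared temporaries. Below is a minimal sketch of the new layout; it is illustrative only, and the struct and field names are invented here rather than taken from the TFLite sources.

// Minimal sketch of the buffer split (illustrative names, not the TFLite API).
#include <cstdint>
#include <vector>

struct QuantScratch {
  std::vector<float> scaling_factors;  // one scale per batch row
  std::vector<int32_t> zero_points;    // one zero point per batch row
  std::vector<int8_t> quantized;       // quantized copy of the float data
};

struct HybridLstmQuantScratch {
  // Previously a single scaling_factors/zero_points pair was shared by all
  // three streams, so each stream had to be quantized and consumed before
  // the next one could overwrite the shared buffers.
  QuantScratch input;         // quantization temporaries for the input
  QuantScratch aux_input;     // quantization temporaries for the auxiliary input
  QuantScratch output_state;  // quantization temporaries for the recurrent state
};

With one QuantScratch per stream, quantizing the input, the auxiliary input, and the output state can happen up front, and the gate matmuls that consume them no longer depend on a shared temporary.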
@@ -138,15 +138,19 @@ enum TemporaryTensor {
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
-  kScalingFactors = 7,
-  kProductScalingFactors = 8,
-  kRecoveredCellWeights = 9,
-  kAccumScratchBuffer = 10,
-  kZeroPoints = 11,
-  kFwRowSums = 12,
-  kBwRowSums = 13,
-  kAuxInputQuantized = 14,  // Optional, quantized tensor for auxiliary input.
-  kNumTemporaryTensors = 15
+  kInputScalingFactors = 7,
+  kAuxInputScalingFactors = 8,
+  kOutputStateScalingFactors = 9,
+  kProductScalingFactors = 10,
+  kRecoveredCellWeights = 11,
+  kAccumScratchBuffer = 12,
+  kInputZeroPoints = 13,
+  kAuxInputZeroPoints = 14,
+  kOutputStateZeroPoints = 15,
+  kFwRowSums = 16,
+  kBwRowSums = 17,
+  kAuxInputQuantized = 18,  // Optional, quantized tensor for auxiliary input.
+  kNumTemporaryTensors = 19,
};

struct OpData {
@@ -699,18 +703,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // a vector once (which produces the scaling factors) and multiply it with
  // different matrices (which requires multiplying the scaling factors with
  // the scaling factor of the matrix).
-  node->temporaries->data[kScalingFactors] =
-      op_data->scratch_tensor_index + kScalingFactors;
-  TfLiteTensor* scaling_factors =
-      GetTemporary(context, node, kScalingFactors);
-  scaling_factors->type = kTfLiteFloat32;
-  scaling_factors->allocation_type = kTfLiteArenaRw;
+  node->temporaries->data[kInputScalingFactors] =
+      op_data->scratch_tensor_index + kInputScalingFactors;
+  TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
+  input_sf->type = kTfLiteFloat32;
+  input_sf->allocation_type = kTfLiteArenaRw;
  int scaling_dims[1] = {n_batch};
-  if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
-    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
-    scaling_factors_size->data[0] = n_batch;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
-                                                     scaling_factors_size));
+  if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
+    input_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(context, input_sf, input_sf_size));
  }
+  node->temporaries->data[kAuxInputScalingFactors] =
+      op_data->scratch_tensor_index + kAuxInputScalingFactors;
+  TfLiteTensor* aux_input_sf =
+      GetTemporary(context, node, kAuxInputScalingFactors);
+  aux_input_sf->type = kTfLiteFloat32;
+  aux_input_sf->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1);
+    aux_input_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf,
+                                                     aux_input_sf_size));
+  }
+  node->temporaries->data[kOutputStateScalingFactors] =
+      op_data->scratch_tensor_index + kOutputStateScalingFactors;
+  TfLiteTensor* output_state_sf =
+      GetTemporary(context, node, kOutputStateScalingFactors);
+  output_state_sf->type = kTfLiteFloat32;
+  output_state_sf->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
+    output_state_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
+                                                     output_state_sf_size));
+  }
  node->temporaries->data[kProductScalingFactors] =
      op_data->scratch_tensor_index + kProductScalingFactors;
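The comment kept in the hunk above ("quantize a vector once ... multiply the scaling factors with the scaling factor of the matrix") is the reason each stream carries its own per-batch scaling factors. A rough stand-alone sketch of what one such scaling factor is, assuming simple symmetric int8 quantization (the real kernels live in tensor_utils and are not reproduced here):

// Symmetric per-row quantization: each batch row gets one scaling factor.
#include <algorithm>
#include <cmath>
#include <cstdint>

float QuantizeRowSymmetric(const float* row, int n, int8_t* out) {
  float max_abs = 0.f;
  for (int i = 0; i < n; ++i) max_abs = std::max(max_abs, std::fabs(row[i]));
  const float scale = max_abs > 0.f ? max_abs / 127.f : 1.f;
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<int8_t>(std::round(row[i] / scale));
  }
  return scale;  // the per-row "scaling factor" stored in the buffers above
}

The quantized row can then be reused against several weight matrices (input-to-input, input-to-forget, input-to-cell, input-to-output); only the effective scale, row_scale * weight_scale, changes per matrix.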
@@ -768,16 +795,40 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  }

  // Allocate temporary tensors for storing zero-points.
-  node->temporaries->data[kZeroPoints] =
-      op_data->scratch_tensor_index + kZeroPoints;
-  TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-  zero_points->type = kTfLiteFloat32;
-  zero_points->allocation_type = kTfLiteArenaRw;
-  if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
-    TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-    zero_points_size->data[0] = n_batch;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                     zero_points_size));
+  node->temporaries->data[kInputZeroPoints] =
+      op_data->scratch_tensor_index + kInputZeroPoints;
+  TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
+  input_zp->type = kTfLiteFloat32;
+  input_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
+    input_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(context, input_zp, input_zp_size));
  }
+  node->temporaries->data[kAuxInputZeroPoints] =
+      op_data->scratch_tensor_index + kAuxInputZeroPoints;
+  TfLiteTensor* aux_input_zp =
+      GetTemporary(context, node, kAuxInputZeroPoints);
+  aux_input_zp->type = kTfLiteFloat32;
+  aux_input_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1);
+    aux_input_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp,
+                                                     aux_input_zp_size));
+  }
+  node->temporaries->data[kOutputStateZeroPoints] =
+      op_data->scratch_tensor_index + kOutputStateZeroPoints;
+  TfLiteTensor* output_state_zp =
+      GetTemporary(context, node, kOutputStateZeroPoints);
+  output_state_zp->type = kTfLiteFloat32;
+  output_state_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
+    output_state_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
+                                                     output_state_zp_size));
+  }

  // Allocate temporary tensors for caching row sums for hybrid zero-point
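The zero-point temporaries above work together with the row-sum caches mentioned in the trailing comment: under asymmetric quantization the kernel never materializes (q - zero_point) explicitly, because the correction can be folded in through cached row sums of the weights. A small sketch of the identity it relies on (illustrative only, not the TFLite implementation):

// dot(w_row, q - zero_point) == dot(w_row, q) - zero_point * sum(w_row),
// so sum(w_row) can be computed once per weight row and reused every step.
#include <cstdint>

int32_t DotMinusZeroPoint(const int8_t* w_row, const int8_t* q, int n,
                          int32_t zero_point, int32_t cached_row_sum) {
  int32_t acc = 0;
  for (int j = 0; j < n; ++j) acc += static_cast<int32_t>(w_row[j]) * q[j];
  return acc - zero_point * cached_row_sum;  // equals dot(w_row, q - zero_point)
}

Because the row sums depend only on the weights, they are computed once (guarded by compute_row_sums in the kernels) and reused for every time step and batch row.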
@@ -1071,8 +1122,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      GetTemporary(context, node, kFwCellStateQuantized);
  TfLiteTensor* bw_cell_state_quantized =
      GetTemporary(context, node, kBwCellStateQuantized);
-  TfLiteTensor* scaling_factors =
-      GetTemporary(context, node, kScalingFactors);
  TfLiteTensor* prod_scaling_factors =
      GetTemporary(context, node, kProductScalingFactors);
  TfLiteTensor* recovered_cell_weights =
@@ -1082,7 +1131,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          : nullptr;
  TfLiteTensor* accum_scratch =
      GetTemporary(context, node, kAccumScratchBuffer);
-  TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
  TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums);
  TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums);
  const int fw_row_sums_size = fw_row_sums->dims->data[0];
@@ -1104,12 +1152,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
      &lstm_params,
      /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
-      fw_scratch_buffer, scaling_factors, prod_scaling_factors,
-      recovered_cell_weights, input_quantized, aux_input_quantized,
-      fw_activation_state_quantized, fw_cell_state_quantized,
-      fw_activation_state, fw_cell_state, accum_scratch, fw_output,
-      zero_points, fw_row_sums, fw_row_sums_size,
-      &op_data->compute_fw_row_sums,
+      fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
+      GetTemporary(context, node, kAuxInputScalingFactors),
+      GetTemporary(context, node, kOutputStateScalingFactors),
+      prod_scaling_factors, recovered_cell_weights, input_quantized,
+      aux_input_quantized, fw_activation_state_quantized,
+      fw_cell_state_quantized, fw_activation_state, fw_cell_state,
+      accum_scratch, fw_output,
+      GetTemporary(context, node, kInputZeroPoints),
+      GetTemporary(context, node, kAuxInputZeroPoints),
+      GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums,
+      fw_row_sums_size, &op_data->compute_fw_row_sums,
      CpuBackendContext::GetFromContext(context));
  TF_LITE_ENSURE_OK(context, fw_pass_status);

@@ -1130,12 +1183,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
      &lstm_params,
      /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
-      bw_scratch_buffer, scaling_factors, prod_scaling_factors,
-      recovered_cell_weights, input_quantized, aux_input_quantized,
-      bw_activation_state_quantized, bw_cell_state_quantized,
-      bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output,
-      zero_points, bw_row_sums, bw_row_sums_size,
-      &op_data->compute_bw_row_sums,
+      bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
+      GetTemporary(context, node, kAuxInputScalingFactors),
+      GetTemporary(context, node, kOutputStateScalingFactors),
+      prod_scaling_factors, recovered_cell_weights, input_quantized,
+      aux_input_quantized, bw_activation_state_quantized,
+      bw_cell_state_quantized, bw_activation_state, bw_cell_state,
+      accum_scratch, actual_bw_output,
+      GetTemporary(context, node, kInputZeroPoints),
+      GetTemporary(context, node, kAuxInputZeroPoints),
+      GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums,
+      bw_row_sums_size, &op_data->compute_bw_row_sums,
      CpuBackendContext::GetFromContext(context));
  TF_LITE_ENSURE_OK(context, bw_pass_status);
  return kTfLiteOk;
@@ -66,13 +66,15 @@ enum HybridTemporaryTensor {
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
-  kScalingFactors = 4,
-  kProductScalingFactors = 5,
-  kRecoveredCellWeights = 6,
-  kAccumScratch = 7,
-  kZeroPoints = 8,
-  kRowSums = 9,
-  kNumHybridTemporaryTensors = 10,
+  kInputScalingFactors = 4,
+  kOutputStateScalingFactors = 5,
+  kProductScalingFactors = 6,
+  kRecoveredCellWeights = 7,
+  kAccumScratch = 8,
+  kInputZeroPoints = 9,
+  kOutputStateZeroPoints = 10,
+  kRowSums = 11,
+  kNumHybridTemporaryTensors = 12,
};

TfLiteStatus PopulateQuantizedLstmParams8x8_16(
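Each enum value above serves both as the position in node->temporaries and as the offset added to the op's scratch_tensor_index, which is why renumbering the enum is the bulk of the bookkeeping change. The Prepare() hunks below then repeat one allocate-and-resize pattern per new temporary; here it is condensed into a hypothetical helper purely for illustration (no such helper exists in the tree; GetTemporary, TfLiteIntArrayCreate and ResizeTensor are the APIs the real code uses):

// Hypothetical helper: request one float32 arena temporary of shape {n_batch}.
// It only restates the pattern repeated in the Prepare() hunks below.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

TfLiteStatus PrepareBatchSizedTemporary(TfLiteContext* context,
                                        TfLiteNode* node,
                                        int scratch_tensor_index, int index,
                                        int n_batch) {
  node->temporaries->data[index] = scratch_tensor_index + index;
  TfLiteTensor* tensor = tflite::GetTemporary(context, node, index);
  tensor->type = kTfLiteFloat32;
  tensor->allocation_type = kTfLiteArenaRw;
  int dims[1] = {n_batch};
  if (!TfLiteIntArrayEqualsArray(tensor->dims, 1, dims)) {
    TfLiteIntArray* size = TfLiteIntArrayCreate(1);
    size->data[0] = n_batch;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, tensor, size));
  }
  return kTfLiteOk;
}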
@@ -1333,18 +1335,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // a vector once (which produces the scaling factors) and multiply it with
  // different matrices (which requires multiplying the scaling factors with
  // the scaling factor of the matrix).
-  node->temporaries->data[kScalingFactors] =
-      op_data->scratch_tensor_index + kScalingFactors;
-  TfLiteTensor* scaling_factors =
-      GetTemporary(context, node, kScalingFactors);
-  scaling_factors->type = kTfLiteFloat32;
-  scaling_factors->allocation_type = kTfLiteArenaRw;
+  node->temporaries->data[kInputScalingFactors] =
+      op_data->scratch_tensor_index + kInputScalingFactors;
+  TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
+  input_sf->type = kTfLiteFloat32;
+  input_sf->allocation_type = kTfLiteArenaRw;
  int scaling_dims[1] = {n_batch};
-  if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
-    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
-    scaling_factors_size->data[0] = n_batch;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
-                                                     scaling_factors_size));
+  if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
+    input_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(context, input_sf, input_sf_size));
  }
+  node->temporaries->data[kOutputStateScalingFactors] =
+      op_data->scratch_tensor_index + kOutputStateScalingFactors;
+  TfLiteTensor* output_state_sf =
+      GetTemporary(context, node, kOutputStateScalingFactors);
+  output_state_sf->type = kTfLiteFloat32;
+  output_state_sf->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
+    output_state_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
+                                                     output_state_sf_size));
+  }
  node->temporaries->data[kProductScalingFactors] =
      op_data->scratch_tensor_index + kProductScalingFactors;
@@ -1394,18 +1407,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, accum_scratch, accum_size));
  }

-  node->temporaries->data[kZeroPoints] =
-      op_data->scratch_tensor_index + kZeroPoints;
-  TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-  zero_points->type = kTfLiteFloat32;
-  zero_points->allocation_type = kTfLiteArenaRw;
-  int zero_points_dims[1] = {n_batch};
-  if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
-    TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-    zero_points_size->data[0] = n_batch;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                     zero_points_size));
+  node->temporaries->data[kInputZeroPoints] =
+      op_data->scratch_tensor_index + kInputZeroPoints;
+  TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
+  input_zp->type = kTfLiteFloat32;
+  input_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
+    input_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(context, input_zp, input_zp_size));
  }
+  node->temporaries->data[kOutputStateZeroPoints] =
+      op_data->scratch_tensor_index + kOutputStateZeroPoints;
+  TfLiteTensor* output_state_zp =
+      GetTemporary(context, node, kOutputStateZeroPoints);
+  output_state_zp->type = kTfLiteFloat32;
+  output_state_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
+    output_state_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
+                                                     output_state_zp_size));
+  }

  node->temporaries->data[kRowSums] =
@@ -1621,7 +1644,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          projection_weights, projection_bias, params,
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
          GetTemporary(context, node, kScratchBuffer),
-          GetTemporary(context, node, kScalingFactors),
+          GetTemporary(context, node, kInputScalingFactors),
+          /*aux_input_sf=*/nullptr,
+          GetTemporary(context, node, kOutputStateScalingFactors),
          GetTemporary(context, node, kProductScalingFactors),
          GetTemporary(context, node, kRecoveredCellWeights),
          GetTemporary(context, node, kInputQuantized),
@@ -1629,8 +1654,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          GetTemporary(context, node, kOutputStateQuantized),
          GetTemporary(context, node, kCellStateQuantized), output_state,
          cell_state, GetTemporary(context, node, kAccumScratch), output,
-          GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size,
-          &op_data->compute_row_sums,
+          GetTemporary(context, node, kInputZeroPoints),
+          /*aux_input_zp=*/nullptr,
+          GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
+          row_sums_size, &op_data->compute_row_sums,
          CpuBackendContext::GetFromContext(context));
    } else {
      const int num_intermediate_tensors = node->intermediates->size;
@@ -785,14 +785,15 @@ inline void LstmStepHybrid(
    const float* projection_bias_ptr, const TfLiteLSTMParams* params,
    int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
    int output_batch_leading_dim, float* scratch0, float* scratch1,
-    float* scratch2, float* scratch3, float* scaling_factors,
-    float* scaling_factors_scratch, float* recovered_cell_weights,
-    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
-    int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
-    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
-    float* output_ptr, int32_t* zero_points, int32_t* row_sums,
-    int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs,
-    CpuBackendContext* context) {
+    float* scratch2, float* scratch3, float* input_sf, float* aux_input_sf,
+    float* output_state_sf, float* scaling_factors_scratch,
+    float* recovered_cell_weights, int8_t* quantized_input_ptr,
+    int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
+    int8_t* quantized_output_scratch, float* output_state_ptr,
+    float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
+    int32_t* input_zp, int32_t* aux_input_zp, int32_t* output_state_zp,
+    int32_t* row_sums, int row_sums_size, bool* compute_row_sums,
+    bool asymmetric_quantize_inputs, CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("LstmStepHybrid");
  // Since we have already checked that weights are all there or none, we
  // can check the existence of only one to get the condition.
@@ -897,38 +898,37 @@ inline void LstmStepHybrid(

  if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) {
    tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input,
-                                      quantized_input_ptr, scaling_factors,
-                                      zero_points, asymmetric_quantize_inputs);
+                                      quantized_input_ptr, input_sf, input_zp,
+                                      asymmetric_quantize_inputs);
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
-          input_to_input_weights_scale, scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, input_to_input_row_sums, compute_row_sums,
-          scaling_factors_scratch, context);
+          input_to_input_weights_scale, input_sf, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
+          input_to_input_row_sums, compute_row_sums, scaling_factors_scratch,
+          context);
    }

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        input_to_forget_weights_scale, scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums,
-        scaling_factors_scratch, context);
+        input_to_forget_weights_scale, input_sf, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
+        input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        input_to_cell_weights_scale, scaling_factors, n_batch,
-        cell_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        input_to_cell_weights_scale, input_sf, n_batch, cell_gate_scratch,
+        /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
        input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
        context);

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        input_to_output_weights_scale, scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, input_to_output_row_sums, compute_row_sums,
-        scaling_factors_scratch, context);
+        input_to_output_weights_scale, input_sf, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, input_zp, accum_scratch_ptr,
+        input_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);
  }

  // For each batch and cell: compute aux_input_weight * aux_input.
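In the hunk above, BatchQuantizeFloats now writes the input's per-row scaling factors and zero points into input_sf and input_zp instead of the shared buffers; the aux-input and output-state hunks below do the same with their own buffers. When asymmetric_quantize_inputs is set, each batch row gets a scale and a zero point roughly as in this stand-alone sketch (illustrative only; the exact behavior of tensor_utils::BatchQuantizeFloats is not reproduced here):

// Per-row asymmetric quantization: one scale and one zero point per batch row.
#include <algorithm>
#include <cmath>
#include <cstdint>

void QuantizeRowAsymmetric(const float* row, int n, int8_t* out,
                           float* scaling_factor, int32_t* zero_point) {
  float mn = row[0], mx = row[0];
  for (int i = 1; i < n; ++i) {
    mn = std::min(mn, row[i]);
    mx = std::max(mx, row[i]);
  }
  const float range = std::max(mx - mn, 1e-6f);
  const float scale = range / 255.f;  // int8 spans 256 levels
  const int32_t zp = static_cast<int32_t>(std::round(-128.f - mn / scale));
  for (int i = 0; i < n; ++i) {
    const int32_t v = static_cast<int32_t>(std::round(row[i] / scale)) + zp;
    out[i] = static_cast<int8_t>(std::min(127, std::max(-128, v)));
  }
  *scaling_factor = scale;
  *zero_point = zp;
}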
@@ -936,15 +936,15 @@ inline void LstmStepHybrid(
  if (aux_input_ptr != nullptr &&
      !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) {
    tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input,
-                                      quantized_aux_input_ptr, scaling_factors,
-                                      zero_points, asymmetric_quantize_inputs);
+                                      quantized_aux_input_ptr, aux_input_sf,
+                                      aux_input_zp, asymmetric_quantize_inputs);

    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_to_input_weights_ptr, n_cell, n_aux_input,
          quantized_aux_input_ptr, aux_input_to_input_weights_scale,
-          scaling_factors, n_batch, input_gate_scratch,
-          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          aux_input_sf, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr,
          aux_input_to_input_row_sums, compute_row_sums,
          scaling_factors_scratch, context);
    }
@@ -952,24 +952,23 @@ inline void LstmStepHybrid(
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
        quantized_aux_input_ptr, aux_input_to_forget_weights_scale,
-        scaling_factors, n_batch, forget_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_sf, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr,
        aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
        context);

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
-        scaling_factors, n_batch, cell_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
-        context);
+        quantized_aux_input_ptr, aux_input_to_cell_weights_scale, aux_input_sf,
+        n_batch, cell_gate_scratch, /*per_channel_scale=*/nullptr, aux_input_zp,
+        accum_scratch_ptr, aux_input_to_cell_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_output_weights_ptr, n_cell, n_aux_input,
        quantized_aux_input_ptr, aux_input_to_output_weights_scale,
-        scaling_factors, n_batch, output_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_sf, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch_ptr,
        aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
        context);
  }
@@ -978,14 +977,14 @@ inline void LstmStepHybrid(
    // Save quantization and matmul computation for all zero input.
    tensor_utils::BatchQuantizeFloats(
        output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
-        scaling_factors, zero_points, asymmetric_quantize_inputs);
+        output_state_sf, output_state_zp, asymmetric_quantize_inputs);
    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output,
          quantized_output_state_ptr, recurrent_to_input_weights_scale,
-          scaling_factors, n_batch, input_gate_scratch,
-          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          output_state_sf, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
          recurrent_to_input_row_sums, compute_row_sums,
          scaling_factors_scratch, context);
    }
@@ -993,24 +992,24 @@ inline void LstmStepHybrid(
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, recurrent_to_forget_weights_scale,
-        scaling_factors, n_batch, forget_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        output_state_sf, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
        recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
        context);

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, recurrent_to_cell_weights_scale,
-        scaling_factors, n_batch, cell_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        output_state_sf, n_batch, cell_gate_scratch,
+        /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
        recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
        context);

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, recurrent_to_output_weights_scale,
-        scaling_factors, n_batch, output_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        output_state_sf, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch_ptr,
        recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
        context);
  }
@@ -1102,7 +1101,7 @@ inline void LstmStepHybrid(
      params->activation, projection_weights_ptr, projection_weights_scale,
      projection_bias_ptr, params->proj_clip, output_state_ptr,
      asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums,
-      context, scratch2, quantized_output_scratch, scaling_factors, zero_points,
+      context, scratch2, quantized_output_scratch, input_sf, input_zp,
      accum_scratch_ptr);

  // Copy output_state_ptr to the output. Note that the output batch rows may
@@ -1892,14 +1891,16 @@ TfLiteStatus EvalHybrid(
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
-    int output_offset, TfLiteTensor* scratch_buffer,
-    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
-    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
-    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
-    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-    TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
-    int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) {
+    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
+    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
+    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
+    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
+    bool* compute_row_sums, CpuBackendContext* context) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
@@ -1939,10 +1940,14 @@ TfLiteStatus EvalHybrid(
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

-  int32_t* zero_points_ptr = nullptr;
+  int32_t* input_zp_ptr = nullptr;
+  int32_t* aux_input_zp_ptr = nullptr;
+  int32_t* output_state_zp_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
-    zero_points_ptr = GetTensorData<int32_t>(zero_points);
+    input_zp_ptr = GetTensorData<int32_t>(input_zp);
+    aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
+    output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }

@@ -2005,7 +2010,9 @@ TfLiteStatus EvalHybrid(
          GetTensorData<float>(projection_bias), params, n_batch, n_cell,
          n_input, aux_input_size, n_output, output_batch_leading_dim,
          input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
-          output_gate_scratch, GetTensorData<float>(scaling_factors),
+          output_gate_scratch, GetTensorData<float>(input_sf),
+          GetTensorData<float>(aux_input_sf),
+          GetTensorData<float>(output_state_sf),
          GetTensorData<float>(prod_scaling_factors),
          GetTensorData<float>(recovered_cell_weights),
          GetTensorData<int8_t>(input_quantized),
@@ -2014,8 +2021,9 @@ TfLiteStatus EvalHybrid(
          GetTensorData<int8_t>(cell_state_quantized),
          GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
          GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
-          zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums,
-          params->asymmetric_quantize_inputs, context);
+          input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
+          row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
+          context);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
@@ -2092,7 +2100,9 @@ TfLiteStatus EvalHybrid(
          /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
          output_batch_leading_dim, input_gate_scratch_ptr,
          forget_gate_scratch_ptr, cell_gate_scratch_ptr,
-          output_gate_scratch_ptr, GetTensorData<float>(scaling_factors),
+          output_gate_scratch_ptr, GetTensorData<float>(input_sf),
+          GetTensorData<float>(aux_input_sf),
+          GetTensorData<float>(output_state_sf),
          GetTensorData<float>(prod_scaling_factors),
          GetTensorData<float>(recovered_cell_weights),
          GetTensorData<int8_t>(input_quantized),
@@ -2100,8 +2110,9 @@ TfLiteStatus EvalHybrid(
          GetTensorData<int8_t>(output_state_quantized),
          GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
          cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
-          output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size,
-          compute_row_sums, params->asymmetric_quantize_inputs, context);
+          output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
+          row_sums_ptr, row_sums_size, compute_row_sums,
+          params->asymmetric_quantize_inputs, context);
    }
  }
}
@@ -148,14 +148,16 @@ TfLiteStatus EvalHybrid(
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
-    int output_offset, TfLiteTensor* scratch_buffer,
-    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
-    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
-    TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
-    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
-    TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer,
-    TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums,
-    int row_sums_size, bool* compute_row_sums, CpuBackendContext* context);
+    int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
+    TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
+    TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
+    TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
+    bool* compute_row_sums, CpuBackendContext* context);

TfLiteStatus EvalInteger8x8_16(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
@@ -654,15 +654,27 @@ class HybridLstmParam : public BaseLstmParam {
    scratch_buffer_tensor_.data.f = scratch_buffer_.data();
    return &scratch_buffer_tensor_;
  }
-  TfLiteTensor* GetScalingFactors() {
-    PackWeightToTensor(&scaling_factors_tensor_, scaling_factors_,
-                       scaling_factors_size_);
-    scaling_factors_tensor_.data.f = scaling_factors_.data();
-    return &scaling_factors_tensor_;
+  TfLiteTensor* GetInputScalingFactors() {
+    PackWeightToTensor(&input_sf_tensor_, input_sf_,
+                       quantization_extra_scratch_buffer_sizes_);
+    input_sf_tensor_.data.f = input_sf_.data();
+    return &input_sf_tensor_;
  }
+  TfLiteTensor* GetAuxInputScalingFactors() {
+    PackWeightToTensor(&aux_input_sf_tensor_, aux_input_sf_,
+                       quantization_extra_scratch_buffer_sizes_);
+    aux_input_sf_tensor_.data.f = aux_input_sf_.data();
+    return &aux_input_sf_tensor_;
+  }
+  TfLiteTensor* GetOutputStateScalingFactors() {
+    PackWeightToTensor(&output_state_sf_tensor_, output_state_sf_,
+                       quantization_extra_scratch_buffer_sizes_);
+    output_state_sf_tensor_.data.f = output_state_sf_.data();
+    return &output_state_sf_tensor_;
+  }
  TfLiteTensor* GetProdScalingFactors() {
    PackWeightToTensor(&prod_scaling_factors_tensor_, prod_scaling_factors_,
-                       prod_scaling_factors_size_);
+                       quantization_extra_scratch_buffer_sizes_);
    prod_scaling_factors_tensor_.data.f = prod_scaling_factors_.data();
    return &prod_scaling_factors_tensor_;
  }
@@ -682,10 +694,23 @@ class HybridLstmParam : public BaseLstmParam {
    cell_quantized_tensor_.data.int8 = cell_quantized_.data();
    return &cell_quantized_tensor_;
  }
-  TfLiteTensor* GetZeroPoints() {
-    PackWeightToTensor(&zero_points_tensor_, zero_points_, zero_points_size_);
-    zero_points_tensor_.data.i32 = zero_points_.data();
-    return &zero_points_tensor_;
+  TfLiteTensor* GetInputZeroPoints() {
+    PackWeightToTensor(&zero_points_tensor0_, input_zp_,
+                       quantization_extra_scratch_buffer_sizes_);
+    zero_points_tensor0_.data.i32 = input_zp_.data();
+    return &zero_points_tensor0_;
  }
+  TfLiteTensor* GetAuxInputZeroPoints() {
+    PackWeightToTensor(&zero_points_tensor1_, aux_input_zp_,
+                       quantization_extra_scratch_buffer_sizes_);
+    zero_points_tensor1_.data.i32 = aux_input_zp_.data();
+    return &zero_points_tensor1_;
+  }
+  TfLiteTensor* GetOutputStateZeroPoints() {
+    PackWeightToTensor(&zero_points_tensor2_, output_state_zp_,
+                       quantization_extra_scratch_buffer_sizes_);
+    zero_points_tensor2_.data.i32 = output_state_zp_.data();
+    return &zero_points_tensor2_;
+  }
  TfLiteTensor* GetRowSums() {
    PackWeightToTensor(&row_sums_tensor_, row_sums_, row_sums_size_);
@@ -776,12 +801,16 @@ class HybridLstmParam : public BaseLstmParam {
  ~HybridLstmParam() {
    TfLiteIntArrayFree(scratch_buffer_tensor_.dims);
    TfLiteIntArrayFree(accum_scratch_tensor_.dims);
-    TfLiteIntArrayFree(scaling_factors_tensor_.dims);
+    TfLiteIntArrayFree(input_sf_tensor_.dims);
+    TfLiteIntArrayFree(aux_input_sf_tensor_.dims);
+    TfLiteIntArrayFree(output_state_sf_tensor_.dims);
    TfLiteIntArrayFree(prod_scaling_factors_tensor_.dims);
    TfLiteIntArrayFree(input_quantized_tensor_.dims);
    TfLiteIntArrayFree(activation_quantized_tensor_.dims);
    TfLiteIntArrayFree(cell_quantized_tensor_.dims);
-    TfLiteIntArrayFree(zero_points_tensor_.dims);
+    TfLiteIntArrayFree(zero_points_tensor0_.dims);
+    TfLiteIntArrayFree(zero_points_tensor1_.dims);
+    TfLiteIntArrayFree(zero_points_tensor2_.dims);
    TfLiteIntArrayFree(row_sums_tensor_.dims);
  }

@@ -792,14 +821,24 @@ class HybridLstmParam : public BaseLstmParam {
  std::vector<int32_t> scratch_buffer_size_ = {n_batch_, n_cell_ * 4};
  TfLiteTensor scratch_buffer_tensor_;

-  std::vector<float> scaling_factors_;
-  std::vector<int32_t> scaling_factors_size_ = {n_batch_};
-  TfLiteTensor scaling_factors_tensor_;
+  std::vector<int32_t> quantization_extra_scratch_buffer_sizes_ = {n_batch_};
+  std::vector<float> input_sf_;
+  TfLiteTensor input_sf_tensor_;
+  std::vector<float> aux_input_sf_;
+  TfLiteTensor aux_input_sf_tensor_;
+  std::vector<float> output_state_sf_;
+  TfLiteTensor output_state_sf_tensor_;

  std::vector<float> prod_scaling_factors_;
  std::vector<int32_t> prod_scaling_factors_size_ = {n_batch_};
  TfLiteTensor prod_scaling_factors_tensor_;

+  std::vector<int32_t> input_zp_;
+  TfLiteTensor zero_points_tensor0_;
+  std::vector<int32_t> aux_input_zp_;
+  TfLiteTensor zero_points_tensor1_;
+  std::vector<int32_t> output_state_zp_;
+  TfLiteTensor zero_points_tensor2_;
+
  std::vector<int8_t> input_quantized_;
  TfLiteTensor input_quantized_tensor_;

@@ -813,10 +852,6 @@ class HybridLstmParam : public BaseLstmParam {
      16, 4, 5, 6, 1, 1, 3, 4, -5, 6, 1, 14, 5, 6, 1, 1, 3, 4, -5, 6,
  };

-  std::vector<int32_t> zero_points_;
-  std::vector<int32_t> zero_points_size_ = {n_batch_};
-  TfLiteTensor zero_points_tensor_;
-
  std::vector<int32_t> row_sums_;
  std::vector<int32_t> row_sums_size_ = {n_row_sums_, n_cell_};
  TfLiteTensor row_sums_tensor_;
@@ -896,13 +931,17 @@ void TestOneHybridAsymmLSTM() {
      /*forward_sequence=*/true,
      /*time_major=*/true,
      /*output_offset=*/0, one_parameter.GetScratchBuffer(),
-      one_parameter.GetScalingFactors(), one_parameter.GetProdScalingFactors(),
+      one_parameter.GetInputScalingFactors(),
+      one_parameter.GetAuxInputScalingFactors(),
+      one_parameter.GetOutputStateScalingFactors(),
+      one_parameter.GetProdScalingFactors(),
      /*recovered_cell_weights=*/nullptr, one_parameter.GetInputQuantized(),
      /*aux_input_quantized=*/nullptr,
      one_parameter.GetActivationStateQuantized(),
      one_parameter.GetCellStateQuantized(), activation, cell,
      one_parameter.GetAccumScratchBuffer(), output,
-      one_parameter.GetZeroPoints(), one_parameter.GetRowSums(),
+      one_parameter.GetInputZeroPoints(), one_parameter.GetAuxInputZeroPoints(),
+      one_parameter.GetOutputStateZeroPoints(), one_parameter.GetRowSums(),
      one_parameter.GetNumRowSums(), &compute_row_sums, &context);
  const std::vector<float> expected_cell = {
      7.83134, 1.96158, 2.18285, 3.28739, 0.483214,
@@ -45,13 +45,15 @@ enum TemporaryTensor {
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
-  kScalingFactors = 4,
-  kProductScalingFactors = 5,
-  kRecoveredCellWeights = 6,
-  kAccumScratch = 7,
-  kZeroPoints = 8,
-  kRowSums = 9,
-  kNumTemporaryTensors = 10
+  kInputScalingFactors = 4,
+  kOutputStateScalingFactors = 5,
+  kProductScalingFactors = 6,
+  kRecoveredCellWeights = 7,
+  kAccumScratch = 8,
+  kInputZeroPoints = 9,
+  kOutputStateZeroPoints = 10,
+  kRowSums = 11,
+  kNumTemporaryTensors = 12,
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -416,18 +418,29 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // a vector once (which produces the scaling factors) and multiply it with
  // different matrices (which requires multiplying the scaling factors with
  // the scaling factor of the matrix).
-  node->temporaries->data[kScalingFactors] =
-      scratch_tensor_index + kScalingFactors;
-  TfLiteTensor* scaling_factors =
-      GetTemporary(context, node, kScalingFactors);
-  scaling_factors->type = kTfLiteFloat32;
-  scaling_factors->allocation_type = kTfLiteArenaRw;
+  node->temporaries->data[kInputScalingFactors] =
+      op_data->scratch_tensor_index + kInputScalingFactors;
+  TfLiteTensor* input_sf = GetTemporary(context, node, kInputScalingFactors);
+  input_sf->type = kTfLiteFloat32;
+  input_sf->allocation_type = kTfLiteArenaRw;
  int scaling_dims[1] = {n_batch};
-  if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
-    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
-    scaling_factors_size->data[0] = n_batch;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
-                                                     scaling_factors_size));
+  if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
+    input_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(context, input_sf, input_sf_size));
  }
+  node->temporaries->data[kOutputStateScalingFactors] =
+      op_data->scratch_tensor_index + kOutputStateScalingFactors;
+  TfLiteTensor* output_state_sf =
+      GetTemporary(context, node, kOutputStateScalingFactors);
+  output_state_sf->type = kTfLiteFloat32;
+  output_state_sf->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
+    TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
+    output_state_sf_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
+                                                     output_state_sf_size));
+  }
  node->temporaries->data[kProductScalingFactors] =
      scratch_tensor_index + kProductScalingFactors;
@@ -477,15 +490,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, accum_scratch, accum_size));
  }
-  node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints;
-  TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints);
-  zero_points->type = kTfLiteFloat32;
-  zero_points->allocation_type = kTfLiteArenaRw;
-  if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) {
-    TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
-    zero_points_size->data[0] = n_batch;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
-                                                     zero_points_size));
+  node->temporaries->data[kInputZeroPoints] =
+      op_data->scratch_tensor_index + kInputZeroPoints;
+  TfLiteTensor* input_zp = GetTemporary(context, node, kInputZeroPoints);
+  input_zp->type = kTfLiteFloat32;
+  input_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
+    input_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(
+        context, context->ResizeTensor(context, input_zp, input_zp_size));
  }
+  node->temporaries->data[kOutputStateZeroPoints] =
+      op_data->scratch_tensor_index + kOutputStateZeroPoints;
+  TfLiteTensor* output_state_zp =
+      GetTemporary(context, node, kOutputStateZeroPoints);
+  output_state_zp->type = kTfLiteFloat32;
+  output_state_zp->allocation_type = kTfLiteArenaRw;
+  if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
+    TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
+    output_state_zp_size->data[0] = n_batch;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
+                                                     output_state_zp_size));
+  }
  node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
  TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums);
@@ -640,7 +666,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer,
-          GetTemporary(context, node, kScalingFactors),
+          GetTemporary(context, node, kInputScalingFactors),
+          /*aux_input_sf=*/nullptr,
+          GetTemporary(context, node, kOutputStateScalingFactors),
          GetTemporary(context, node, kProductScalingFactors),
          GetTemporary(context, node, kRecoveredCellWeights),
          GetTemporary(context, node, kInputQuantized),
@@ -648,8 +676,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          GetTemporary(context, node, kOutputStateQuantized),
          GetTemporary(context, node, kCellStateQuantized), output_state,
          cell_state, GetTemporary(context, node, kAccumScratch), output,
-          GetTemporary(context, node, kZeroPoints), row_sums, row_sums_size,
-          &op_data->compute_row_sums,
+          GetTemporary(context, node, kInputZeroPoints),
+          /*aux_input_zp=*/nullptr,
+          GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
+          row_sums_size, &op_data->compute_row_sums,
          CpuBackendContext::GetFromContext(context));
    }
    default: