Make the names of the "activation", "activation_state", and "output_state" variables consistent: use "output_state", the name already used by the float implementation (LstmStepFloat).

Remove a few variables in the implementation and tests that represented the same value (mostly the indices of these tensors).

Also rename "input cell state" to just "cell state".

PiperOrigin-RevId: 316908413
Change-Id: Icb64ecd31c90f45ef21cf7d48849fb2ec0975d3a
Robert David 2020-06-17 09:53:14 -07:00 committed by TensorFlower Gardener
parent c870b9f920
commit 406d9b5521
8 changed files with 226 additions and 245 deletions
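In terms of the shared tensor-index constants (changed later in this diff), the rename amounts to:

// Before: named as if they were separate input states.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;

// After: named to match the float implementation.
constexpr int kOutputStateTensor = 18;
constexpr int kCellStateTensor = 19;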

View File

@ -68,19 +68,19 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
const float cell_clip = params->cell_clip;
const float proj_clip = params->proj_clip;
const TfLiteTensor* cell_tensor =
GetVariableInput(context, node, kInputCellStateTensor);
TF_LITE_ENSURE(context, cell_tensor != nullptr);
const TfLiteTensor* cell_state =
GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
auto* cell_params =
static_cast<TfLiteAffineQuantization*>(cell_tensor->quantization.params);
auto* cell_state_params =
static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
auto* proj_params = static_cast<TfLiteAffineQuantization*>(
output_tensor->quantization.params);
if (cell_clip > 0.0) {
integer_lstm_param->quantized_cell_clip = static_cast<int32_t>(
std::min(std::max(cell_clip / cell_params->scale->data[0], -32768.0f),
32767.0f));
integer_lstm_param->quantized_cell_clip = static_cast<int32_t>(std::min(
std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
32767.0f));
} else {
integer_lstm_param->quantized_cell_clip = 0;
}
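As a worked example of the clip quantization above (illustrative numbers, not from the diff): with cell_clip = 10.0 and a cell-state scale of 2^-11 = 1/2048,

quantized_cell_clip = clamp(10.0 / (1/2048), -32768.0, 32767.0) = 20480

i.e. the float clip expressed in the cell state's int16 units.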
@ -134,9 +134,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
@ -187,7 +187,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
float layer_norm_forget_scale = default_scale;
float layer_norm_cell_scale = default_scale;
float layer_norm_output_scale = default_scale;
float activation_scale = default_scale;
float output_state_scale = default_scale;
int cell_scale = 1;
// Effective scales.
@ -231,7 +231,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
if (use_projection) {
proj_weight_scale = projection_weights->params.scale;
}
activation_scale = activation_state->params.scale;
output_state_scale = output_state->params.scale;
input_to_forget_weight_scale = input_to_forget_weights->params.scale;
input_to_cell_weight_scale = input_to_cell_weights->params.scale;
@ -240,12 +240,8 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
// Get cell state.
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
// Check cell state (already used above)
TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
TF_LITE_ENSURE(context, cell_scale <= -9);
integer_lstm_param->cell_scale = cell_scale;
input_scale = input->params.scale;
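CheckedLog2 succeeds only for exact power-of-two scales, so together with the cell_scale <= -9 check this requires the cell state to use a scale of 2^-9 or finer; a scale of 1/32768, for example, yields cell_scale = -15 and passes.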
@ -255,31 +251,32 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_16(
effective_input_to_input_scale =
input_to_input_weight_scale * input_scale / intermediate_scale[0];
effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
activation_scale /
output_state_scale /
intermediate_scale[0];
}
effective_input_to_forget_scale =
input_to_forget_weight_scale * input_scale / intermediate_scale[1];
effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
activation_scale /
output_state_scale /
intermediate_scale[1];
effective_input_to_cell_scale =
input_to_cell_weight_scale * input_scale / intermediate_scale[2];
effective_recurrent_to_cell_scale =
recurrent_to_cell_weight_scale * activation_scale / intermediate_scale[2];
effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
output_state_scale /
intermediate_scale[2];
effective_input_to_output_scale =
input_to_output_weight_scale * input_scale / intermediate_scale[3];
effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
activation_scale /
output_state_scale /
intermediate_scale[3];
effective_hidden_scale =
std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
effective_proj_scale =
proj_weight_scale * intermediate_scale[4] / activation_scale;
proj_weight_scale * intermediate_scale[4] / output_state_scale;
if (use_peephole) {
if (!use_cifg) {
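Each effective scale above is (weight scale) x (operand scale: input_scale or output_state_scale) / (target intermediate scale). Downstream these floating-point scales are typically decomposed into integer multiplier/shift pairs; a minimal sketch, assuming TFLite's QuantizeMultiplier helper from quantization_util.h:

int32_t multiplier;
int shift;
// Decomposes the scale into multiplier * 2^shift so the kernel can rescale
// using integer arithmetic only.
QuantizeMultiplier(effective_recurrent_to_forget_scale, &multiplier, &shift);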
@ -419,11 +416,10 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
// Since we have already checked that weights are all there or none, we can
@ -456,7 +452,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
int32_t* output_bias_ptr = nullptr;
int32_t* proj_bias_ptr = nullptr;
int16_t* cell_ptr = nullptr;
int8_t* activation_ptr = nullptr;
int8_t* output_state_ptr = nullptr;
// Scales.
const float default_scale = 1.0;
@ -477,7 +473,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
float layer_norm_forget_scale = default_scale;
float layer_norm_cell_scale = default_scale;
float layer_norm_output_scale = default_scale;
float activation_scale = default_scale;
float output_state_scale = default_scale;
// Effective scales.
float effective_input_to_input_scale = default_scale;
@ -495,7 +491,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
// Zero points
int input_zp = 0;
int activation_zp = 0;
int output_state_zp = 0;
// Populate all the values.
if (!use_cifg) {
@ -537,7 +533,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
proj_bias_ptr = projection_bias->data.i32;
}
}
activation_scale = activation_state->params.scale;
output_state_scale = output_state->params.scale;
input_to_forget_weight_ptr = input_to_forget_weights->data.int8;
input_to_forget_weight_scale = input_to_forget_weights->params.scale;
@ -554,11 +550,11 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
forget_bias_ptr = forget_gate_bias->data.i32;
cell_bias_ptr = cell_bias->data.i32;
output_bias_ptr = output_gate_bias->data.i32;
activation_ptr = activation_state->data.int8;
output_state_ptr = output_state->data.int8;
cell_ptr = cell_state->data.i16;
input_scale = input->params.scale;
input_zp = input->params.zero_point;
activation_zp = activation_state->params.zero_point;
output_state_zp = output_state->params.zero_point;
std::vector<float> intermediate_scale;
for (int i = 0; i < 12; ++i) {
@ -575,27 +571,28 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
effective_input_to_input_scale =
input_to_input_weight_scale * input_scale / intermediate_scale[1];
effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
activation_scale /
output_state_scale /
intermediate_scale[2];
}
effective_input_to_forget_scale =
input_to_forget_weight_scale * input_scale / intermediate_scale[4];
effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
activation_scale /
output_state_scale /
intermediate_scale[5];
effective_input_to_cell_scale =
input_to_cell_weight_scale * input_scale / intermediate_scale[7];
effective_recurrent_to_cell_scale =
recurrent_to_cell_weight_scale * activation_scale / intermediate_scale[8];
effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
output_state_scale /
intermediate_scale[8];
effective_input_to_output_scale =
input_to_output_weight_scale * input_scale / intermediate_scale[10];
effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
activation_scale /
output_state_scale /
intermediate_scale[11];
effective_proj_scale =
proj_weight_scale * std::pow(2, -15) / activation_scale;
proj_weight_scale * std::pow(2, -15) / output_state_scale;
if (use_peephole) {
if (!use_cifg) {
@ -698,18 +695,16 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
const float cell_clip = params->cell_clip;
const float proj_clip = params->proj_clip;
const TfLiteTensor* cell_tensor =
GetInput(context, node, kInputCellStateTensor);
const TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
auto* cell_params = reinterpret_cast<TfLiteAffineQuantization*>(
cell_tensor->quantization.params);
auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
cell_state->quantization.params);
auto* proj_params = reinterpret_cast<TfLiteAffineQuantization*>(
output_tensor->quantization.params);
TF_LITE_ENSURE_EQ(context, cell_params->scale->data[0], 1.0 / 32768);
TF_LITE_ENSURE_EQ(context, cell_state_params->scale->data[0], 1.0 / 32768);
if (cell_clip > 0.0 && cell_clip < 1.0) {
integer_lstm_param->quantized_cell_clip =
static_cast<int>(cell_clip / cell_params->scale->data[0]);
static_cast<int>(cell_clip / cell_state_params->scale->data[0]);
} else {
integer_lstm_param->quantized_cell_clip = 0;
}
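Concretely: with the cell-state scale pinned to 1/32768 above, an in-range cell_clip of 0.5 (illustrative) quantizes to 0.5 * 32768 = 16384; values outside (0.0, 1.0) fall through to the else branch and disable clipping.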
@ -1026,12 +1021,12 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
OpData* op_data,
TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
const TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
const int32_t input_zero_point = -input->params.zero_point;
const int32_t activation_zero_point = -activation_state->params.zero_point;
const int32_t output_state_zero_point = -output_state->params.zero_point;
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
@ -1083,8 +1078,8 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, activation_zero_point, recurrent_to_forget_weights, nullptr,
&(integer_lstm_params->recurrent_to_forget_effective_bias)));
context, output_state_zero_point, recurrent_to_forget_weights,
nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
// Modulation gate.
const TfLiteTensor* cell_gate_bias =
@ -1097,7 +1092,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, activation_zero_point, recurrent_to_cell_weights, nullptr,
context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
&(integer_lstm_params->recurrent_to_cell_effective_bias)));
// Output gate.
@ -1112,8 +1107,8 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, activation_zero_point, recurrent_to_output_weights, nullptr,
&(integer_lstm_params->recurrent_to_output_effective_bias)));
context, output_state_zero_point, recurrent_to_output_weights,
nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
// Input gate. The calculation is only meaningful for non-cifg case.
const TfLiteTensor* input_gate_bias =
@ -1126,7 +1121,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
TF_LITE_ENSURE_OK(
context,
PrecomputeZeroPointTimesWeightWithBias(
context, activation_zero_point, recurrent_to_input_weights, nullptr,
context, output_state_zero_point, recurrent_to_input_weights, nullptr,
&(integer_lstm_params->recurrent_to_input_effective_bias)));
// Projection bias. The calculation is only meaningful for with projection.
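All four calls fold the (already negated) output-state zero point into an effective bias, so the gate matmuls need no per-element zero-point correction: since (x - zp) * W = x * W - zp * sum(W) per output row, the correction is a constant that can be precomputed. A minimal sketch of the idea, with a hypothetical helper (not the actual PrecomputeZeroPointTimesWeightWithBias signature):

#include <cstdint>

// effective_bias[r] = bias[r] + zp * sum_c weights[r][c], where zp is the
// negated zero point (matching output_state_zero_point above).
void FoldZeroPointIntoBias(const int8_t* weights, const int32_t* bias,
                           int32_t zp, int rows, int cols,
                           int32_t* effective_bias) {
  for (int r = 0; r < rows; ++r) {
    int32_t row_sum = 0;
    for (int c = 0; c < cols; ++c) row_sum += weights[r * cols + c];
    effective_bias[r] = (bias != nullptr ? bias[r] : 0) + zp * row_sum;
  }
}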
@ -1198,20 +1193,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_output,
n_cell, use_layer_norm, is_integer));
// Get the pointer to output, activation_state and cell_state tensors.
// Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
// Check the shape of input state tensors.
// These tensor may be 1D or 2D. It's fine as long as the total size is
// correct.
TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
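For example, with n_batch = 2 and n_output = 4, an output_state of shape {2, 4} and one of shape {8} both pass the NumElements check.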
// Resize the output tensors.
@ -1275,7 +1269,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (is_hybrid_op) {
op_data->compute_row_sums = true;
// Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors.
// output_state and cell_state tensors.
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
input_quantized->type = input_to_output_weights->type;
@ -1286,17 +1280,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
TfLiteTensor* activation_state_quantized =
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, /*index=*/2);
activation_state_quantized->type = input_to_output_weights->type;
activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
activation_state->dims)) {
TfLiteIntArray* activation_state_quantized_size =
TfLiteIntArrayCopy(activation_state->dims);
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, activation_state_quantized,
activation_state_quantized_size));
output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
output_state->dims)) {
TfLiteIntArray* output_state_quantized_size =
TfLiteIntArrayCopy(output_state->dims);
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output_state_quantized,
output_state_quantized_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* cell_state_quantized =
@ -1540,11 +1534,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
TfLiteTensor* output_state =
GetVariableInput(context, node, kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@ -1569,7 +1562,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
/*output_offset=*/0, scratch_buffer, output_state, cell_state,
output);
}
case kTfLiteUInt8:
@ -1580,7 +1573,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* input_quantized =
GetTemporary(context, node, /*index=*/1);
TfLiteTensor* activation_state_quantized =
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@ -1614,8 +1607,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*time_major=*/true, /*output_offset=*/0, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state,
/*aux_input_quantized=*/nullptr, output_state_quantized,
cell_state_quantized, output_state, cell_state,
output_scratch_buffer, output, zero_points, row_sums, row_sums_size,
&op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
@ -1638,9 +1631,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, &op_data->integer_lstm_param, activation_state,
cell_state, output, scratch0, scratch1, scratch2, scratch3,
scratch4, scratch5, CpuBackendContext::GetFromContext(context));
params, &op_data->integer_lstm_param, output_state, cell_state,
output, scratch0, scratch1, scratch2, scratch3, scratch4,
scratch5, CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);
@ -1660,7 +1653,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, activation_state, cell_state, output,
params, output_state, cell_state, output,
&op_data->integer_lstm_param, scratch0, scratch1, scratch2,
scratch3, scratch4, scratch5, scratch6, scratch7);
return kTfLiteOk;

View File

@ -900,7 +900,7 @@ inline void LstmStepHybrid(
// Fully quantized lstm kernel for 16 bit gate matmul output.
//
// Input activation of size n_batch * n_input:
// Input tensor of size n_batch * n_input:
// input_ptr
//
// LSTM weights:
@ -972,7 +972,7 @@ inline void LstmStepHybrid(
// cell_scale: the power of two scale for cell state.
//
// Zero points:
// activation_zp: zero point of activation
// output_state_zp: zero point of output state
// hidden_zp: zero point for hidden state.
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
@ -1048,8 +1048,8 @@ inline void LstmStepInteger(
const int32_t* input_to_input_effective_bias,
const int32_t* recurrent_to_input_effective_bias,
const int32_t* projection_effective_bias, int32 n_batch, int32 n_cell,
int32 n_input, int32 n_output, int8_t* activation_ptr,
int32_t activation_zp, int16_t* cell_ptr, int8_t* output_ptr,
int32 n_input, int32 n_output, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
int16_t* scratch_0_ptr, int16_t* scratch_1_ptr, int16_t* scratch_2_ptr,
int16_t* scratch_3_ptr, int8_t* scratch_4_ptr, int32_t* scratch_5_ptr,
CpuBackendContext* context) {
@ -1088,7 +1088,7 @@ inline void LstmStepInteger(
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_1_ptr, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, recurrent_to_forget_effective_bias,
output_state_ptr, recurrent_to_forget_effective_bias,
recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_1_ptr, context);
@ -1115,7 +1115,7 @@ inline void LstmStepInteger(
n_input, n_cell, 0, scratch_5_ptr, scratch_2_ptr, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, recurrent_to_cell_effective_bias,
output_state_ptr, recurrent_to_cell_effective_bias,
recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a,
effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_2_ptr, context);
@ -1139,7 +1139,7 @@ inline void LstmStepInteger(
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_0_ptr, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, recurrent_to_input_effective_bias,
output_state_ptr, recurrent_to_input_effective_bias,
recurrent_to_input_weight_ptr, effective_recurrent_to_input_scale_a,
effective_recurrent_to_input_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_0_ptr, context);
@ -1180,7 +1180,7 @@ inline void LstmStepInteger(
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_3_ptr, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, recurrent_to_output_effective_bias,
output_state_ptr, recurrent_to_output_effective_bias,
recurrent_to_output_weight_ptr, effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_3_ptr, context);
@ -1213,7 +1213,7 @@ inline void LstmStepInteger(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
scratch_4_ptr, projection_effective_bias, proj_weight_ptr,
effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
n_output, activation_zp, scratch_5_ptr, output_ptr, context);
n_output, output_state_zp, scratch_5_ptr, output_ptr, context);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
n_output);
@ -1221,12 +1221,12 @@ inline void LstmStepInteger(
} else {
std::copy_n(scratch_4_ptr, n_batch * n_output, output_ptr);
}
std::copy_n(output_ptr, n_batch * n_output, activation_ptr);
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
}
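The final copy_n is the reason for the name: the clipped, projected output of step t is written back into output_state, where it serves as the recurrent input of step t+1, so "output state" describes the tensor's role better than "activation" did.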
// Fully quantized lstm kernel for 8 bit gate matmul output.
//
// Input activation of size n_batch * n_input:
// Input tensor of size n_batch * n_input:
// input_ptr
//
// LSTM weights:
@ -1298,7 +1298,7 @@ inline void LstmStepInteger(
// cell_scale: the power of two scale for cell state.
//
// Zero points:
// activation_zp: zero point of activation
// output_state_zp: zero point of output state.
// hidden_zp: zero point for hidden state.
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
@ -1367,8 +1367,8 @@ void LstmStepInteger(
const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
const int32_t* intermediate_zp, int32 quantized_cell_clip,
int32 quantized_proj_clip, int32 n_batch, int32 n_cell, int32 n_input,
int32 n_output, int32 output_batch_leading_dim, int8_t* activation_ptr,
int32_t activation_zp, int16_t* cell_ptr, int8_t* output_ptr,
int32 n_output, int32 output_batch_leading_dim, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
int16_t* scratch7) {
@ -1381,7 +1381,7 @@ void LstmStepInteger(
n_batch, n_input, n_cell, scratch0, intermediate_zp[4]);
tensor_utils::MatrixBatchVectorMultiply(
activation_ptr, activation_zp, recurrent_to_forget_weight_ptr,
output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell,
scratch1, intermediate_zp[5]);
@ -1408,7 +1408,7 @@ void LstmStepInteger(
n_input, n_cell, scratch0, intermediate_zp[7]);
tensor_utils::MatrixBatchVectorMultiply(
activation_ptr, activation_zp, recurrent_to_cell_weight_ptr,
output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
n_batch, n_output, n_cell, scratch1, intermediate_zp[8]);
@ -1434,7 +1434,7 @@ void LstmStepInteger(
n_batch, n_input, n_cell, scratch0, intermediate_zp[10]);
tensor_utils::MatrixBatchVectorMultiply(
activation_ptr, activation_zp, recurrent_to_output_weight_ptr,
output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell,
scratch1, intermediate_zp[11]);
@ -1478,7 +1478,7 @@ void LstmStepInteger(
// Projection.
tensor_utils::MatrixBatchVectorMultiply(
scratch3, proj_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
proj_bias_ptr, n_batch, n_cell, n_output, activation_zp, output_ptr);
proj_bias_ptr, n_batch, n_cell, n_output, output_state_zp, output_ptr);
// Projection clipping.
if (quantized_proj_clip > 0) {
@ -1486,8 +1486,8 @@ void LstmStepInteger(
n_output);
}
// Copy output to activation.
std::copy_n(output_ptr, n_batch * n_output, activation_ptr);
// Copy output to output state.
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
}
} // namespace
@ -1518,9 +1518,8 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
int max_time, n_batch;
if (input->dims->size == 3) {
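The hunk above is truncated right at the shape handling. A minimal sketch of how such a branch proceeds, assuming the time-major convention named in the signature (not the verbatim implementation):

int max_time, n_batch;
if (input->dims->size == 3) {
  // 3-D input: {max_time, n_batch, n_input} when time_major.
  max_time = time_major ? input->dims->data[0] : input->dims->data[1];
  n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
} else {
  // 2-D input: a single step of shape {n_batch, n_input}.
  max_time = 1;
  n_batch = input->dims->data[0];
}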
@ -1604,10 +1603,9 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
GetTensorData<float>(activation_state),
GetTensorData<float>(cell_state), input_gate_scratch,
forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
output_ptr);
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
output_gate_scratch, output_ptr);
}
} else {
for (int b = 0; b < n_batch; b++) {
@ -1628,9 +1626,9 @@ TfLiteStatus EvalFloat(
float* output_ptr = GetTensorData<float>(output) +
time_offset * output_step + output_offset;
// Offset the {activation,cell}_state pointers to the right batch.
float* activation_state_ptr = GetTensorData<float>(activation_state) +
b * output_batch_leading_dim;
// Offset the {output,cell}_state pointers to the right batch.
float* output_state_ptr =
GetTensorData<float>(output_state) + b * output_batch_leading_dim;
float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
// Offset the scratch pointers to the right batch.
float* input_gate_scratch_ptr =
@ -1666,7 +1664,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_gate_scratch_ptr,
output_gate_scratch_ptr, output_ptr);
}
@ -1939,10 +1937,10 @@ TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1,
TfLiteTensor* scratch2, TfLiteTensor* scratch3, TfLiteTensor* scratch4,
TfLiteTensor* scratch5, CpuBackendContext* context) {
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
CpuBackendContext* context) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;
@ -1959,7 +1957,7 @@ TfLiteStatus EvalInteger8x8_16(
const int n_output = recurrent_to_output_weights->dims->data[1];
// Activation zero point
int activation_zp = activation_state->params.zero_point;
int output_state_zp = output_state->params.zero_point;
// Get params for time/batch/sequence.
const int output_batch_leading_dim =
@ -2042,8 +2040,8 @@ TfLiteStatus EvalInteger8x8_16(
integer_lstm_param->input_to_input_effective_bias.get(),
integer_lstm_param->recurrent_to_input_effective_bias.get(),
integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
n_input, n_output, GetTensorData<int8_t>(activation_state),
activation_zp, GetTensorData<int16_t>(cell_state), output_ptr,
n_input, n_output, GetTensorData<int8_t>(output_state), output_state_zp,
GetTensorData<int16_t>(cell_state), output_ptr,
GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
@ -2072,7 +2070,7 @@ TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* activation_state,
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
@ -2131,11 +2129,11 @@ TfLiteStatus EvalInteger8x8_8(
const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
int8_t* activation_ptr = GetTensorData<int8_t>(activation_state);
int8_t* output_state_ptr = GetTensorData<int8_t>(output_state);
int8_t* output_ptr = nullptr;
const int32 input_zp = input->params.zero_point;
const int32 activation_zp = activation_state->params.zero_point;
const int32 output_state_zp = output_state->params.zero_point;
// Get params for time/batch/sequence.
const int output_batch_leading_dim =
@ -2222,7 +2220,7 @@ TfLiteStatus EvalInteger8x8_8(
integer_lstm_param->intermediate_zp,
integer_lstm_param->quantized_cell_clip,
integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
n_output, output_batch_leading_dim, activation_ptr, activation_zp,
n_output, output_batch_leading_dim, output_state_ptr, output_state_zp,
cell_ptr, output_ptr, GetTensorData<int8_t>(scratch0),
GetTensorData<int8_t>(scratch1), GetTensorData<int16_t>(scratch2),
GetTensorData<int16_t>(scratch3), GetTensorData<int16_t>(scratch4),

View File

@ -120,9 +120,8 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output);
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output);
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
@ -179,10 +178,10 @@ TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1,
TfLiteTensor* scratch2, TfLiteTensor* scratch3, TfLiteTensor* scratch4,
TfLiteTensor* scratch5, CpuBackendContext* context);
TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
CpuBackendContext* context);
TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
@ -203,7 +202,7 @@ TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* activation_state,
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,

View File

@ -57,8 +57,8 @@ constexpr int kProjectionBiasTensor = 17; // Optional
// These state tensors are defined as variable tensors, and will be modified by
// this op.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;
constexpr int kOutputStateTensor = 18;
constexpr int kCellStateTensor = 19;
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.

View File

@ -104,10 +104,10 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
// Adding the 2 input state tensors.
input_activation_state_ =
// Adding the 2 state tensors.
output_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
input_cell_state_ =
cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
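Both state tensors are added with is_variable=true (the trailing argument), so the interpreter persists their contents across invocations; that persistence is what makes them state tensors rather than ordinary inputs.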
// Layer norm weights.
@ -266,13 +266,11 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
int input_activation_state_;
int input_cell_state_;
int output_;
int output_state_;
int cell_state_;
int output_;
int n_batch_;
int n_input_;
int n_cell_;
@ -553,7 +551,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
@ -1697,7 +1695,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{n_cell}, // input_layer_norm_coefficient tensor
@ -1768,7 +1766,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{n_cell}, // input_layer_norm_coefficient tensor
@ -1841,7 +1839,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{n_cell}, // input_layer_norm_coefficient tensor
@ -1955,7 +1953,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
@ -2026,7 +2024,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
@ -2098,7 +2096,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
@ -2216,13 +2214,13 @@ class LSTMIntegerOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
// Adding the 2 input state tensors.
input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18],
ranges[18].first, ranges[18].second},
true);
input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
ranges[19].first, ranges[19].second},
true);
// Adding the 2 state tensors.
output_state_ = AddInput({TensorType_INT16, input_shapes[18],
ranges[18].first, ranges[18].second},
true);
cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
ranges[19].first, ranges[19].second},
true);
// Layer norm weights.
if (use_layer_norm) {
@ -2386,8 +2384,6 @@ class LSTMIntegerOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
int input_activation_state_;
int input_cell_state_;
int intermediates_[5];
@ -2483,7 +2479,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{n_cell}, // input_layer_norm_coefficient tensor
@ -2517,14 +2513,14 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
{-0.5, 0.5}, // projection_weight tensor
{-1, 1}, // projection_bias tensor
{-1.0, 32767.0 / 32768}, // activation_state tensor
{-1.0, 32767.0 / 32768}, // output_state tensor
{-1, 1}, // cell_state tensor
{-1.00001, 1.0}, // input_layer_norm_coefficient tensor
{-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
{-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
{-1.00001, 1.0}, // output_layer_norm_coefficient tensor
// Output scale is the same as input activation scale and only activation
// Output scale is the same as output_state scale and only output_state
// scale is used in the op, so this is only provided for clarity.
{-1.0, 32767.0 / 32768}, // output tensor.
};
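The {-1.0, 32767.0 / 32768} ranges map exactly onto a symmetric int16 tensor: scale = 1/32768 and zero point = 0, so the representable values run from -32768/32768 to 32767/32768, which is why only the output_state scale is needed by the op.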
@ -2685,7 +2681,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{n_cell}, // input_layer_norm_coefficient tensor
@ -2719,14 +2715,14 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
{-0.5, 0.5}, // projection_weight tensor
{-1, 1}, // projection_bias tensor
{-1.0, 32767.0 / 32768}, // activation_state tensor
{-1.0, 32767.0 / 32768}, // output_state tensor
{-1, 1}, // cell_state tensor
{-0.5, 0.5}, // input_layer_norm_coefficient tensor
{-0.5, 0.5}, // forget_layer_norm_coefficient tensor
{-1.0, 1.0}, // cell_layer_norm_coefficient tensor
{-1.0, 1.0}, // output_layer_norm_coefficient tensor
// Output scale is the same as input activation scale and only activation
// Output scale is the same as output_state scale and only output_state
// scale is used in the op, so this is only provided for clarity.
{-1.0, 32767.0 / 32768}, // output tensor.
};
@ -2892,13 +2888,13 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
projection_bias_ = AddNullInput();
}
// Adding the 2 input state tensors.
input_activation_state_ = AddInput({TensorType_INT16, input_shapes[18],
ranges[18].first, ranges[18].second},
true);
input_cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
ranges[19].first, ranges[19].second},
true);
// Adding the 2 state tensors.
output_state_ = AddInput({TensorType_INT16, input_shapes[18],
ranges[18].first, ranges[18].second},
true);
cell_state_ = AddInput({TensorType_INT16, input_shapes[19],
ranges[19].first, ranges[19].second},
true);
// Layer norm weights.
if (use_layer_norm) {
@ -3062,8 +3058,6 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
int projection_weights_;
int projection_bias_;
int input_activation_state_;
int input_cell_state_;
int intermediates_[12];
@ -3160,7 +3154,7 @@ TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
{n_output, n_cell}, // projection_weight tensor
{n_output}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
@ -3194,14 +3188,14 @@ TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
{-0.5, 0.5}, // projection_weight tensor
{-1, 1}, // projection_bias tensor
{-1.0, 32767.0 / 32768}, // activation_state tensor
{-1.0, 32767.0 / 32768}, // output_state tensor
{-1.0, 32767.0 / 32768}, // cell_state tensor
{-1.00001, 1.0}, // input_layer_norm_coefficient tensor
{-1.00001, 1.0}, // forget_layer_norm_coefficient tensor
{-1.00001, 1.0}, // cell_layer_norm_coefficient tensor
{-1.00001, 1.0}, // output_layer_norm_coefficient tensor
// Output scale is the same as input activation scale and only activation
// Output scale is the same as output_state scale and only output_state
// scale is used in the op, so this is only provided for clarity.
{-1.0, 32767.0 / 32768}, // output tensor.
};

View File

@ -317,20 +317,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
CheckInputTensorDimensions(context, node, n_input, n_output,
n_cell, is_layer_norm_lstm));
// Get the pointer to output, activation_state and cell_state buffer tensors.
// Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
GetVariableInput(context, node, lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
// Check the shape of input state tensors.
// These tensor may be 1D or 2D. It's fine as long as the total size is
// correct.
TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
// Resize the output tensors.
@ -370,7 +370,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (IsHybridOp(input, input_to_output_weights)) {
op_data->compute_row_sums = true;
// Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors.
// output_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
@ -384,17 +384,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
scratch_tensor_index + kOutputStateQuantized;
TfLiteTensor* activation_state_quantized =
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
activation_state_quantized->type = input_to_output_weights->type;
activation_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
activation_state->dims)) {
TfLiteIntArray* activation_state_quantized_size =
TfLiteIntArrayCopy(activation_state->dims);
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, activation_state_quantized,
activation_state_quantized_size));
output_state_quantized->type = input_to_output_weights->type;
output_state_quantized->allocation_type = kTfLiteArenaRw;
if (!TfLiteIntArrayEqual(output_state_quantized->dims,
output_state->dims)) {
TfLiteIntArray* output_state_quantized_size =
TfLiteIntArrayCopy(output_state->dims);
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output_state_quantized,
output_state_quantized_size));
}
node->temporaries->data[kCellStateQuantized] =
scratch_tensor_index + kCellStateQuantized;
@ -559,11 +559,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* activation_state =
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* output_state =
GetVariableInput(context, node, lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
GetVariableInput(context, node, lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
const TfLiteTensor* input_layer_norm_coefficients =
@ -613,14 +613,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
/*output_offset=*/0, scratch_buffer, output_state, cell_state,
output);
}
case kTfLiteUInt8:
case kTfLiteInt8: {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
TfLiteTensor* activation_state_quantized =
TfLiteTensor* output_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@ -652,10 +652,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
cell_state_quantized, activation_state, cell_state, accum_scratch,
output, zero_points, row_sums, row_sums_size,
&op_data->compute_row_sums,
/*aux_input_quantized=*/nullptr, output_state_quantized,
cell_state_quantized, output_state, cell_state, accum_scratch, output,
zero_points, row_sums, row_sums_size, &op_data->compute_row_sums,
CpuBackendContext::GetFromContext(context));
}
default:

View File

@ -100,13 +100,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
// Adding the 2 input state tensors.
input_activation_state_ =
// Adding the 2 state tensors.
output_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
/*is_variable=*/true);
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
/*is_variable=*/true);
cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
/*is_variable=*/true);
// Layer norm weights.
if (is_layer_norm) {
@ -256,8 +255,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
int input_activation_state_;
int input_cell_state_;
int output_state_;
int cell_state_;
int input_layer_norm_coefficients_;
int forget_layer_norm_coefficients_;
@ -537,7 +536,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
});
@ -599,7 +598,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
});
@ -665,7 +664,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
},
TensorType_UINT8, GetParam());
@ -728,7 +727,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
},
TensorType_INT8, GetParam());
@ -840,7 +839,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
});
@ -901,7 +900,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
},
TensorType_UINT8, GetParam());
@ -964,7 +963,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
},
TensorType_INT8, GetParam());
@ -1626,7 +1625,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
});
@ -1695,7 +1694,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
},
TensorType_UINT8, GetParam());
@ -1766,7 +1765,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
},
TensorType_INT8, GetParam());
@ -2437,7 +2436,7 @@ TEST_F(NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest,
{n_output, n_cell}, // projection_weight tensor
{n_output}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
});
@ -2643,7 +2642,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor
@ -2714,7 +2713,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // activation_state tensor
{n_batch, n_output}, // output_state tensor
{n_batch, n_cell}, // cell_state tensor
{0}, // input_layer_norm_coefficient tensor

View File

@ -302,9 +302,8 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output, Logger* logger,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output, Logger* logger,
const std::vector<int>& intermediate_tensor_indexes,
ErrorReporter* error_reporter) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
@ -390,10 +389,10 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
GetTensorData<float>(activation_state),
GetTensorData<float>(cell_state), input_gate_scratch,
forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
output_ptr_time, logger, intermediate_tensor_indexes, error_reporter);
GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
output_gate_scratch, output_ptr_time, logger,
intermediate_tensor_indexes, error_reporter);
}
} else {
for (int b = 0; b < n_batch; b++) {
@ -414,9 +413,9 @@ TfLiteStatus EvalFloat(
float* output_ptr = GetTensorData<float>(output) +
time_offset * output_step + output_offset;
// Offset the {activation,cell}_state pointers to the right batch.
float* activation_state_ptr = GetTensorData<float>(activation_state) +
b * output_batch_leading_dim;
// Offset the {output,cell}_state pointers to the right batch.
float* output_state_ptr =
GetTensorData<float>(output_state) + b * output_batch_leading_dim;
float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
// Offset the scratch pointers to the right batch.
float* input_gate_scratch_ptr =
@ -452,7 +451,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_gate_scratch_ptr,
output_gate_scratch_ptr, output_ptr, logger,
intermediate_tensor_indexes, error_reporter);
@ -541,11 +540,11 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* activation_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* output_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kOutputStateTensor);
TF_LITE_ENSURE(context, output_state != nullptr);
TfLiteTensor* cell_state = GetVariableInput(
context, node, ops::builtin::lstm::full::kInputCellStateTensor);
context, node, ops::builtin::lstm::full::kCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
TfLiteTensor* output =
@ -574,8 +573,8 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output, logger, intermediate_tensor_indexes, error_reporter);
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
logger, intermediate_tensor_indexes, error_reporter);
}
case kTfLiteUInt8:
case kTfLiteInt8: