Consistently name LSTM's gate bias variables <gatename>_gate_bias.

PiperOrigin-RevId: 317343158
Change-Id: I385ddbad6c1283b84574b2ec0b523ce9f88a4cd3
This commit is contained in:
Robert David 2020-06-19 11:16:53 -07:00 committed by TensorFlower Gardener
parent 7a88f7fb5c
commit 1129f21360
11 changed files with 224 additions and 206 deletions

View File

@ -318,11 +318,11 @@ TfLiteStatus CheckLstmTensorDimensionsAndTypes(
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
const TfLiteTensor* cell_bias =
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, cell_gate_bias_tensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, output_gate_bias_tensor);
@ -886,7 +886,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
const TfLiteTensor* fw_forget_gate_bias =
GetInput(context, node, kFwForgetGateBiasTensor);
const TfLiteTensor* fw_cell_bias =
const TfLiteTensor* fw_cell_gate_bias =
GetInput(context, node, kFwCellGateBiasTensor);
const TfLiteTensor* fw_output_gate_bias =
GetInput(context, node, kFwOutputGateBiasTensor);
@ -934,7 +934,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
const TfLiteTensor* bw_forget_gate_bias =
GetInput(context, node, kBwForgetGateBiasTensor);
const TfLiteTensor* bw_cell_bias =
const TfLiteTensor* bw_cell_gate_bias =
GetInput(context, node, kBwCellGateBiasTensor);
const TfLiteTensor* bw_output_gate_bias =
GetInput(context, node, kBwOutputGateBiasTensor);
@ -1029,7 +1029,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
&lstm_params,
/*forward_sequence=*/true, time_major, /*output_offset=*/0,
@ -1049,7 +1049,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
&lstm_params,
/*forward_sequence=*/false, time_major, bw_output_offset,
@ -1099,7 +1099,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
&lstm_params,
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
@ -1125,7 +1125,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*output_layer_norm_coefficients=*/nullptr, real_aux_input,
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
&lstm_params,
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,

View File

@ -89,7 +89,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_input_gate_bias_ = AddInput(TensorType_FLOAT32);
}
fw_forget_gate_bias_ = AddInput(TensorType_FLOAT32);
fw_cell_bias_ = AddInput(TensorType_FLOAT32);
fw_cell_gate_bias_ = AddInput(TensorType_FLOAT32);
fw_output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
@ -144,7 +144,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_input_gate_bias_ = AddInput(TensorType_FLOAT32);
}
bw_forget_gate_bias_ = AddInput(TensorType_FLOAT32);
bw_cell_bias_ = AddInput(TensorType_FLOAT32);
bw_cell_gate_bias_ = AddInput(TensorType_FLOAT32);
bw_output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
@ -288,8 +288,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
}
void SetCellBias(const std::vector<float>& f) {
PopulateTensor(fw_cell_bias_, f);
PopulateTensor(bw_cell_bias_, f);
PopulateTensor(fw_cell_gate_bias_, f);
PopulateTensor(bw_cell_gate_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
@ -364,7 +364,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int fw_input_gate_bias_;
int fw_forget_gate_bias_;
int fw_cell_bias_;
int fw_cell_gate_bias_;
int fw_output_gate_bias_;
int fw_projection_weights_;
@ -386,7 +386,7 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int bw_input_gate_bias_;
int bw_forget_gate_bias_;
int bw_cell_bias_;
int bw_cell_gate_bias_;
int bw_output_gate_bias_;
int bw_projection_weights_;
@ -467,7 +467,7 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -490,7 +490,7 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -633,7 +633,7 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -656,7 +656,7 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -796,7 +796,7 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -819,7 +819,7 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -956,7 +956,7 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -978,7 +978,7 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -1107,7 +1107,7 @@ TEST(LSTMOpTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -1129,7 +1129,7 @@ TEST(LSTMOpTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -1258,7 +1258,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1280,7 +1280,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1961,7 +1961,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1983,7 +1983,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2667,7 +2667,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -2690,7 +2690,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -2841,7 +2841,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -2864,7 +2864,7 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor

View File

@ -407,7 +407,8 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
@ -446,10 +447,10 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
int16_t* layer_norm_forget_weight_ptr = nullptr;
int16_t* layer_norm_cell_weight_ptr = nullptr;
int16_t* layer_norm_output_weight_ptr = nullptr;
int32_t* input_bias_ptr = nullptr;
int32_t* forget_bias_ptr = nullptr;
int32_t* cell_bias_ptr = nullptr;
int32_t* output_bias_ptr = nullptr;
int32_t* input_gate_bias_ptr = nullptr;
int32_t* forget_gate_bias_ptr = nullptr;
int32_t* cell_gate_bias_ptr = nullptr;
int32_t* output_gate_bias_ptr = nullptr;
int32_t* proj_bias_ptr = nullptr;
int16_t* cell_ptr = nullptr;
int8_t* output_state_ptr = nullptr;
@ -497,7 +498,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
if (!use_cifg) {
input_to_input_weight_ptr = input_to_input_weights->data.int8;
recurrent_to_input_weight_ptr = recurrent_to_input_weights->data.int8;
input_bias_ptr = input_gate_bias->data.i32;
input_gate_bias_ptr = input_gate_bias->data.i32;
input_to_input_weight_scale = input_to_input_weights->params.scale;
recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
}
@ -547,9 +548,9 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8(
recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
recurrent_to_output_weight_ptr = recurrent_to_output_weights->data.int8;
recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
forget_bias_ptr = forget_gate_bias->data.i32;
cell_bias_ptr = cell_bias->data.i32;
output_bias_ptr = output_gate_bias->data.i32;
forget_gate_bias_ptr = forget_gate_bias->data.i32;
cell_gate_bias_ptr = cell_gate_bias->data.i32;
output_gate_bias_ptr = output_gate_bias->data.i32;
output_state_ptr = output_state->data.int8;
cell_ptr = cell_state->data.i16;
input_scale = input->params.scale;
@ -875,13 +876,14 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
if (is_integer) {
TF_LITE_ENSURE_TYPES_EQ(context, cell_bias->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
} else {
TF_LITE_ENSURE_TYPES_EQ(context, cell_bias->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
}
const TfLiteTensor* output_gate_bias =
@ -1526,7 +1528,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
@ -1560,8 +1563,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, params,
/*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, scratch_buffer, output_state, cell_state,
output);
@ -1603,8 +1607,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, params,
/*forward_sequence=*/true,
/*time_major=*/true, /*output_offset=*/0, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
input_quantized,
@ -1631,10 +1636,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, &op_data->integer_lstm_param, output_state, cell_state,
output, scratch0, scratch1, scratch2, scratch3, scratch4,
scratch5, CpuBackendContext::GetFromContext(context));
cell_gate_bias, output_gate_bias, projection_weights,
projection_bias, params, &op_data->integer_lstm_param,
output_state, cell_state, output, scratch0, scratch1, scratch2,
scratch3, scratch4, scratch5,
CpuBackendContext::GetFromContext(context));
} else {
TfLiteTensor* scratch0 = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* scratch1 = GetTemporary(context, node, /*index=*/1);
@ -1653,8 +1659,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_layer_norm_coefficients,
forget_layer_norm_coefficients, cell_layer_norm_coefficients,
output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
cell_bias, output_gate_bias, projection_weights, projection_bias,
params, output_state, cell_state, output,
cell_gate_bias, output_gate_bias, projection_weights,
projection_bias, params, output_state, cell_state, output,
&op_data->integer_lstm_param, scratch0, scratch1, scratch2,
scratch3, scratch4, scratch5, scratch6, scratch7);
return kTfLiteOk;

View File

@ -942,10 +942,10 @@ inline void LstmStepHybrid(
// effective_proj_scale_b - optional
//
// Gate biases of size 'n_cell':
// input_bias_ptr - optional
// forget_bias_ptr
// input_gate_bias_ptr - optional
// forget_gate_bias_ptr
// cell_gate_bias_ptr
// output_bias_ptr
// output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
// layer_norm_input_weight_ptr - optional
@ -1031,8 +1031,8 @@ inline void LstmStepInteger(
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
int16_t quantized_cell_clip, int8_t quantized_proj_clip, int32_t cell_scale,
int32_t input_variance_guard, int32_t forget_variance_guard,
int32_t cell_variance_guard, int32_t output_variance_guard,
@ -1098,7 +1098,7 @@ inline void LstmStepInteger(
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_1_ptr, layer_norm_forget_weight_ptr, forget_bias_ptr,
scratch_1_ptr, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
layer_norm_forget_scale_a, layer_norm_forget_scale_b,
forget_variance_guard, n_batch, n_cell, scratch_1_ptr);
}
@ -1149,7 +1149,7 @@ inline void LstmStepInteger(
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_0_ptr, layer_norm_input_weight_ptr, input_bias_ptr,
scratch_0_ptr, layer_norm_input_weight_ptr, input_gate_bias_ptr,
layer_norm_input_scale_a, layer_norm_input_scale_b,
input_variance_guard, n_batch, n_cell, scratch_0_ptr);
}
@ -1190,7 +1190,7 @@ inline void LstmStepInteger(
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_3_ptr, layer_norm_output_weight_ptr, output_bias_ptr,
scratch_3_ptr, layer_norm_output_weight_ptr, output_gate_bias_ptr,
layer_norm_output_scale_a, layer_norm_output_scale_b,
output_variance_guard, n_batch, n_cell, scratch_3_ptr);
}
@ -1268,10 +1268,10 @@ inline void LstmStepInteger(
// effective_proj_scale_b - optional
//
// Gate biases of size 'n_cell':
// input_bias_ptr - optional
// forget_bias_ptr
// input_gate_bias_ptr - optional
// forget_gate_bias_ptr
// cell_gate_bias_ptr
// output_bias_ptr
// output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
// layer_norm_input_weight_ptr - optional
@ -1358,8 +1358,8 @@ void LstmStepInteger(
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
const int32_t* proj_bias_ptr, const TfLiteLSTMParams* params,
const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
const int32_t* intermediate_zp, int16_t quantized_cell_clip,
@ -1391,7 +1391,8 @@ void LstmStepInteger(
// Forget gate layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch2, layer_norm_forget_weight_ptr, layer_norm_forget_scale_a,
layer_norm_forget_scale_b, forget_bias_ptr, n_batch, n_cell, scratch2);
layer_norm_forget_scale_b, forget_gate_bias_ptr, n_batch, n_cell,
scratch2);
// Forget gate sigmoid.
tensor_utils::ApplySigmoidFloat(scratch2, n_batch, n_cell, scratch2);
@ -1444,7 +1445,8 @@ void LstmStepInteger(
// Output gate with layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch4, layer_norm_output_weight_ptr, layer_norm_output_scale_a,
layer_norm_output_scale_b, output_bias_ptr, n_batch, n_cell, scratch4);
layer_norm_output_scale_b, output_gate_bias_ptr, n_batch, n_cell,
scratch4);
// Output gate sigmoid.
tensor_utils::ApplySigmoidFloat(scratch4, n_batch, n_cell, scratch4);
@ -1512,7 +1514,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
@ -1595,7 +1597,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
@ -1656,7 +1658,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
@ -1693,7 +1695,7 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
@ -1802,7 +1804,7 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
@ -1888,7 +1890,7 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<int8_t>(projection_weights),
GetTensorScale(projection_weights),
@ -1930,7 +1932,7 @@ TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
@ -2020,7 +2022,7 @@ TfLiteStatus EvalInteger8x8_16(
integer_lstm_param->layer_norm_output_scale_b,
GetTensorData<int32_t>(input_gate_bias),
GetTensorData<int32_t>(forget_gate_bias),
GetTensorData<int32_t>(cell_bias),
GetTensorData<int32_t>(cell_gate_bias),
GetTensorData<int32_t>(output_gate_bias),
integer_lstm_param->quantized_cell_clip,
integer_lstm_param->quantized_proj_clip, integer_lstm_param->cell_scale,
@ -2065,7 +2067,7 @@ TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output,
@ -2120,10 +2122,12 @@ TfLiteStatus EvalInteger8x8_8(
GetTensorData<int16_t>(cell_layer_norm_coefficients);
const int16_t* layer_norm_output_weight_ptr =
GetTensorData<int16_t>(output_layer_norm_coefficients);
const int32_t* input_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
const int32_t* forget_bias_ptr = GetTensorData<int32_t>(forget_gate_bias);
const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_bias);
const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
const int32_t* input_gate_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
const int32_t* forget_gate_bias_ptr =
GetTensorData<int32_t>(forget_gate_bias);
const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_gate_bias);
const int32_t* output_gate_bias_ptr =
GetTensorData<int32_t>(output_gate_bias);
const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
int8_t* output_state_ptr = GetTensorData<int8_t>(output_state);
@ -2209,8 +2213,8 @@ TfLiteStatus EvalInteger8x8_8(
integer_lstm_param->layer_norm_output_scale_a,
integer_lstm_param->layer_norm_output_scale_b,
input_bias_ptr, forget_bias_ptr, cell_gate_bias_ptr, output_bias_ptr,
proj_bias_ptr,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_gate_bias_ptr,
output_gate_bias_ptr, proj_bias_ptr,
params, integer_lstm_param->intermediate_scale_a,
integer_lstm_param->intermediate_scale_b,

View File

@ -117,7 +117,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
@ -145,7 +145,7 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
@ -174,7 +174,7 @@ TfLiteStatus EvalInteger8x8_16(
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const lstm_eval::IntegerLstmParameter* integer_lstm_param,
@ -200,7 +200,7 @@ TfLiteStatus EvalInteger8x8_8(
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output,

View File

@ -113,10 +113,10 @@ class BaseLstmParam {
TfLiteIntArrayFree(layer_norm_forget_tensor_.dims);
TfLiteIntArrayFree(layer_norm_cell_tensor_.dims);
TfLiteIntArrayFree(layer_norm_output_tensor_.dims);
TfLiteIntArrayFree(input_bias_tensor_.dims);
TfLiteIntArrayFree(forget_bias_tensor_.dims);
TfLiteIntArrayFree(cell_bias_tensor_.dims);
TfLiteIntArrayFree(output_bias_tensor_.dims);
TfLiteIntArrayFree(input_gate_bias_tensor_.dims);
TfLiteIntArrayFree(forget_gate_bias_tensor_.dims);
TfLiteIntArrayFree(cell_gate_bias_tensor_.dims);
TfLiteIntArrayFree(output_gate_bias_tensor_.dims);
TfLiteIntArrayFree(projection_tensor_.dims);
TfLiteIntArrayFree(projection_bias_tensor_.dims);
TfLiteIntArrayFree(activation_tensor_.dims);
@ -275,17 +275,17 @@ class BaseLstmParam {
std::vector<int32_t> layer_norm_output_size_ = {n_cell_};
TfLiteTensor layer_norm_output_tensor_;
std::vector<int32_t> input_bias_size_ = {n_cell_};
TfLiteTensor input_bias_tensor_;
std::vector<int32_t> input_gate_bias_size_ = {n_cell_};
TfLiteTensor input_gate_bias_tensor_;
std::vector<int32_t> forget_bias_size_ = {n_cell_};
TfLiteTensor forget_bias_tensor_;
std::vector<int32_t> forget_gate_bias_size_ = {n_cell_};
TfLiteTensor forget_gate_bias_tensor_;
std::vector<int32_t> cell_bias_size_ = {n_cell_};
TfLiteTensor cell_bias_tensor_;
std::vector<int32_t> cell_gate_bias_size_ = {n_cell_};
TfLiteTensor cell_gate_bias_tensor_;
std::vector<int32_t> output_bias_size_ = {n_cell_};
TfLiteTensor output_bias_tensor_;
std::vector<int32_t> output_gate_bias_size_ = {n_cell_};
TfLiteTensor output_gate_bias_tensor_;
// projection_weights.
std::vector<int8_t> projection_ = {
@ -350,24 +350,28 @@ class QuantizedLstmParam : public BaseLstmParam {
return &layer_norm_output_tensor_;
}
TfLiteTensor* GetInputBias() {
PackWeightToTensor(&input_bias_tensor_, input_bias_, input_bias_size_);
input_bias_tensor_.data.i32 = input_bias_.data();
return &input_bias_tensor_;
PackWeightToTensor(&input_gate_bias_tensor_, input_gate_bias_,
input_gate_bias_size_);
input_gate_bias_tensor_.data.i32 = input_gate_bias_.data();
return &input_gate_bias_tensor_;
}
TfLiteTensor* GetForgetBias() {
PackWeightToTensor(&forget_bias_tensor_, forget_bias_, forget_bias_size_);
forget_bias_tensor_.data.i32 = forget_bias_.data();
return &forget_bias_tensor_;
PackWeightToTensor(&forget_gate_bias_tensor_, forget_gate_bias_,
forget_gate_bias_size_);
forget_gate_bias_tensor_.data.i32 = forget_gate_bias_.data();
return &forget_gate_bias_tensor_;
}
TfLiteTensor* GetCellBias() {
PackWeightToTensor(&cell_bias_tensor_, cell_bias_, cell_bias_size_);
cell_bias_tensor_.data.i32 = cell_bias_.data();
return &cell_bias_tensor_;
PackWeightToTensor(&cell_gate_bias_tensor_, cell_gate_bias_,
cell_gate_bias_size_);
cell_gate_bias_tensor_.data.i32 = cell_gate_bias_.data();
return &cell_gate_bias_tensor_;
}
TfLiteTensor* GetOutputBias() {
PackWeightToTensor(&output_bias_tensor_, output_bias_, output_bias_size_);
output_bias_tensor_.data.i32 = output_bias_.data();
return &output_bias_tensor_;
PackWeightToTensor(&output_gate_bias_tensor_, output_gate_bias_,
output_gate_bias_size_);
output_gate_bias_tensor_.data.i32 = output_gate_bias_.data();
return &output_gate_bias_tensor_;
}
TfLiteTensor* GetProjectionBias() {
PackWeightToTensor(&projection_bias_tensor_, projection_bias_,
@ -539,22 +543,22 @@ class QuantizedLstmParam : public BaseLstmParam {
};
// input_gate_bias.
std::vector<int32_t> input_bias_ = {
std::vector<int32_t> input_gate_bias_ = {
16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
};
// forget_gate_bias.
std::vector<int32_t> forget_bias_ = {
std::vector<int32_t> forget_gate_bias_ = {
16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
};
// cell_bias.
std::vector<int32_t> cell_bias_ = {
// cell_gate_bias.
std::vector<int32_t> cell_gate_bias_ = {
16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
};
// output_gate_bias.
std::vector<int32_t> output_bias_ = {
std::vector<int32_t> output_gate_bias_ = {
16, 4, 5, 6, 1, 1, 3, 4, -5, 6, //
};
@ -711,27 +715,28 @@ class HybridLstmParam : public BaseLstmParam {
return &accum_scratch_tensor_;
}
TfLiteTensor* GetInputBias() {
PackWeightToTensor(&input_bias_tensor_, input_float_bias_,
input_bias_size_);
input_bias_tensor_.data.f = input_float_bias_.data();
return &input_bias_tensor_;
PackWeightToTensor(&input_gate_bias_tensor_, input_float_bias_,
input_gate_bias_size_);
input_gate_bias_tensor_.data.f = input_float_bias_.data();
return &input_gate_bias_tensor_;
}
TfLiteTensor* GetForgetBias() {
PackWeightToTensor(&forget_bias_tensor_, forget_float_bias_,
forget_bias_size_);
forget_bias_tensor_.data.f = forget_float_bias_.data();
return &forget_bias_tensor_;
PackWeightToTensor(&forget_gate_bias_tensor_, forget_float_bias_,
forget_gate_bias_size_);
forget_gate_bias_tensor_.data.f = forget_float_bias_.data();
return &forget_gate_bias_tensor_;
}
TfLiteTensor* GetCellBias() {
PackWeightToTensor(&cell_bias_tensor_, cell_float_bias_, cell_bias_size_);
cell_bias_tensor_.data.f = cell_float_bias_.data();
return &cell_bias_tensor_;
PackWeightToTensor(&cell_gate_bias_tensor_, cell_float_bias_,
cell_gate_bias_size_);
cell_gate_bias_tensor_.data.f = cell_float_bias_.data();
return &cell_gate_bias_tensor_;
}
TfLiteTensor* GetOutputBias() {
PackWeightToTensor(&output_bias_tensor_, output_float_bias_,
output_bias_size_);
output_bias_tensor_.data.f = output_float_bias_.data();
return &output_bias_tensor_;
PackWeightToTensor(&output_gate_bias_tensor_, output_float_bias_,
output_gate_bias_size_);
output_gate_bias_tensor_.data.f = output_float_bias_.data();
return &output_gate_bias_tensor_;
}
TfLiteTensor* GetProjectionBias() {
PackWeightToTensor(&projection_bias_tensor_, projection_float_bias_,

View File

@ -89,7 +89,7 @@ class LSTMOpModel : public SingleOpModel {
input_gate_bias_ = AddInput(TensorType_FLOAT32);
}
forget_gate_bias_ = AddInput(TensorType_FLOAT32);
cell_bias_ = AddInput(TensorType_FLOAT32);
cell_gate_bias_ = AddInput(TensorType_FLOAT32);
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
@ -211,7 +211,7 @@ class LSTMOpModel : public SingleOpModel {
}
void SetCellBias(const std::vector<float>& f) {
PopulateTensor(cell_bias_, f);
PopulateTensor(cell_gate_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
@ -261,7 +261,7 @@ class LSTMOpModel : public SingleOpModel {
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int cell_gate_bias_;
int output_gate_bias_;
int projection_weights_;
@ -498,7 +498,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -545,7 +545,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -601,7 +601,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -652,7 +652,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -743,7 +743,7 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -791,7 +791,7 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -840,7 +840,7 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -1481,7 +1481,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1528,7 +1528,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1577,7 +1577,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1689,7 +1689,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1760,7 +1760,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1833,7 +1833,7 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1947,7 +1947,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2018,7 +2018,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2090,7 +2090,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2195,8 +2195,8 @@ class LSTMIntegerOpModel : public SingleOpModel {
}
forget_gate_bias_ = AddInput({TensorType_INT32, input_shapes[13],
ranges[13].first, ranges[13].second});
cell_bias_ = AddInput({TensorType_INT32, input_shapes[14], ranges[14].first,
ranges[14].second});
cell_gate_bias_ = AddInput({TensorType_INT32, input_shapes[14],
ranges[14].first, ranges[14].second});
output_gate_bias_ = AddInput({TensorType_INT32, input_shapes[15],
ranges[15].first, ranges[15].second});
@ -2330,7 +2330,7 @@ class LSTMIntegerOpModel : public SingleOpModel {
}
void SetCellBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(cell_bias_, f);
QuantizeAndPopulate<int32_t>(cell_gate_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
@ -2379,7 +2379,7 @@ class LSTMIntegerOpModel : public SingleOpModel {
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int cell_gate_bias_;
int output_gate_bias_;
int projection_weights_;
@ -2473,7 +2473,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2507,7 +2507,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionNoPeephole) {
{-100, 100}, // input_gate_bias tensor
{-100, 100}, // forget_gate_bias tensor
{-100, 100}, // cell_bias tensor
{-100, 100}, // cell_gate_bias tensor
{-100, 100}, // output_gate_bias tensor
{-0.5, 0.5}, // projection_weight tensor
@ -2675,7 +2675,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2709,7 +2709,7 @@ TEST(LSTMIntegerOpModel, NoCifgYesLayerNormNoYesProjectionYesPeephole) {
{-100, 100}, // input_gate_bias tensor
{-100, 80}, // forget_gate_bias tensor
{-100, 100}, // cell_bias tensor
{-100, 100}, // cell_gate_bias tensor
{-100, 100}, // output_gate_bias tensor
{-0.5, 0.5}, // projection_weight tensor
@ -2869,8 +2869,8 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
}
forget_gate_bias_ = AddInput({TensorType_INT32, input_shapes[13],
ranges[13].first, ranges[13].second});
cell_bias_ = AddInput({TensorType_INT32, input_shapes[14], ranges[14].first,
ranges[14].second});
cell_gate_bias_ = AddInput({TensorType_INT32, input_shapes[14],
ranges[14].first, ranges[14].second});
output_gate_bias_ = AddInput({TensorType_INT32, input_shapes[15],
ranges[15].first, ranges[15].second});
@ -3004,7 +3004,7 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
}
void SetCellBias(const std::vector<float>& f) {
QuantizeAndPopulate<int32_t>(cell_bias_, f);
QuantizeAndPopulate<int32_t>(cell_gate_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
@ -3053,7 +3053,7 @@ class LSTMIntegerOpModel8x8_8 : public SingleOpModel {
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int cell_gate_bias_;
int output_gate_bias_;
int projection_weights_;
@ -3148,7 +3148,7 @@ TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -3182,7 +3182,7 @@ TEST(LSTMIntegerOpModel8x8_8, CifgYesLayerNormNoYesProjectionNoPeephole) {
{-100, 100}, // input_gate_bias tensor
{-100, 100}, // forget_gate_bias tensor
{-100, 100}, // cell_bias tensor
{-100, 100}, // cell_gate_bias tensor
{-100, 100}, // output_gate_bias tensor
{-0.5, 0.5}, // projection_weight tensor
@ -3303,7 +3303,7 @@ TEST(LSTMOpModel, InvalidTypeTest) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -3338,7 +3338,7 @@ TEST(LSTMOpModel, InvalidTypeTest) {
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor

View File

@ -78,7 +78,7 @@ class LSTMOpModel : public SingleOpModel {
input_gate_bias_ = AddInput(TensorType_FLOAT32);
}
forget_gate_bias_ = AddInput(TensorType_FLOAT32);
cell_bias_ = AddInput(TensorType_FLOAT32);
cell_gate_bias_ = AddInput(TensorType_FLOAT32);
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
@ -161,7 +161,7 @@ class LSTMOpModel : public SingleOpModel {
}
void SetCellBias(std::initializer_list<float> f) {
PopulateTensor(cell_bias_, f);
PopulateTensor(cell_gate_bias_, f);
}
void SetOutputGateBias(std::initializer_list<float> f) {
@ -209,7 +209,7 @@ class LSTMOpModel : public SingleOpModel {
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int cell_gate_bias_;
int output_gate_bias_;
int projection_weights_;
@ -256,7 +256,7 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor

View File

@ -179,10 +179,10 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
const TfLiteTensor* cell_bias =
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
@ -546,7 +546,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_bias =
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
@ -611,8 +611,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, &lstm_params,
/*forward_sequence=*/true, time_major,
/*output_offset=*/0, scratch_buffer, output_state, cell_state,
output);
}
@ -648,8 +649,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, &lstm_params,
/*forward_sequence=*/true, time_major,
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, output_state_quantized,

View File

@ -85,7 +85,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
input_gate_bias_ = AddInput(TensorType_FLOAT32);
}
forget_gate_bias_ = AddInput(TensorType_FLOAT32);
cell_bias_ = AddInput(TensorType_FLOAT32);
cell_gate_bias_ = AddInput(TensorType_FLOAT32);
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
@ -187,7 +187,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
}
void SetCellBias(const std::vector<float>& f) {
PopulateTensor(cell_bias_, f);
PopulateTensor(cell_gate_bias_, f);
}
void SetOutputGateBias(const std::vector<float>& f) {
@ -249,7 +249,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
int cell_gate_bias_;
int output_gate_bias_;
int projection_weights_;
@ -530,7 +530,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -592,7 +592,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -658,7 +658,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -721,7 +721,7 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -833,7 +833,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -894,7 +894,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -957,7 +957,7 @@ TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -1619,7 +1619,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1688,7 +1688,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -1759,7 +1759,7 @@ TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2430,7 +2430,7 @@ TEST_F(NoCifgPeepholeProjectionAndBiasClippingUnidirectionalLstmTest,
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
@ -2636,7 +2636,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLayerNormUnidirectionalLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
@ -2707,7 +2707,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest,
{0}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor

View File

@ -299,7 +299,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
@ -384,7 +384,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
@ -446,7 +446,7 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(output_layer_norm_coefficients),
GetTensorData<float>(input_gate_bias),
GetTensorData<float>(forget_gate_bias),
GetTensorData<float>(cell_bias),
GetTensorData<float>(cell_gate_bias),
GetTensorData<float>(output_gate_bias),
GetTensorData<float>(projection_weights),
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
@ -527,7 +527,7 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_bias =
const TfLiteTensor* cell_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);
@ -570,8 +570,9 @@ TfLiteStatus lstm_eval(TfLiteContext* context, TfLiteNode* node, Logger* logger,
/*aux_input_to_forget_weights=*/nullptr,
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
projection_bias, params, /*forward_sequence=*/true,
forget_gate_bias, cell_gate_bias, output_gate_bias,
projection_weights, projection_bias, params,
/*forward_sequence=*/true,
/*time_major=*/true,
/*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
logger, intermediate_tensor_indexes, error_reporter);