Separate "is layer norm" and "has layer norm tensors" parameters of LSTMOpModel.
PiperOrigin-RevId: 321890450
Change-Id: Ie5a07786688bd1e3e2914362e78735fee29df093
commit b13a153d07 (parent 0721b70578)
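The split lets the test model builder distinguish two questions that the old single is_layer_norm flag conflated: model_has_legacy_20_inputs says whether the model uses the legacy 20-input LSTM signature (no layer-norm coefficient slots at inputs 20-23 at all), while is_layer_norm says whether those slots, when present, are wired to real tensors or left as null inputs. A minimal sketch of the resulting decision, using hypothetical names (LayerNormSlot, ResolveLayerNormSlot) that are not part of the SingleOpModel API:

// Sketch only: mirrors the constructor logic in the diff below, where each of
// the four layer-norm coefficient slots exists only for 24-input models and is
// a real input only when layer normalization is enabled.
enum class LayerNormSlot { kRealInput, kNullInput, kAbsent };

LayerNormSlot ResolveLayerNormSlot(bool model_has_legacy_20_inputs,
                                   bool is_layer_norm) {
  if (model_has_legacy_20_inputs) return LayerNormSlot::kAbsent;  // 20-input model.
  return is_layer_norm ? LayerNormSlot::kRealInput   // e.g. AddInput(TensorType_FLOAT32)
                       : LayerNormSlot::kNullInput;  // e.g. AddNullInput()
}

Accordingly, the existing tests now pass /*model_has_legacy_20_inputs=*/true alongside /*is_layer_norm=*/false, and the layer-norm tests pass /*model_has_legacy_20_inputs=*/false with /*is_layer_norm=*/true.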
@@ -40,8 +40,8 @@ class LSTMOpModel : public SingleOpModel {
               bool use_peephole, bool use_projection_weights,
               bool use_projection_bias, float cell_clip, float proj_clip,
               const std::vector<std::vector<int>>& input_shapes,
-              const TensorType weight_type, bool is_layer_norm,
-              bool asymmetric_quantize_inputs)
+              const TensorType weight_type, bool model_has_legacy_20_inputs,
+              bool is_layer_norm, bool asymmetric_quantize_inputs)
       : n_batch_(n_batch),
         n_input_(n_input),
         n_cell_(n_cell),
@@ -111,23 +111,19 @@ class LSTMOpModel : public SingleOpModel {
     AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
 
     // Layer norm weights.
-    if (is_layer_norm) {
-      const int kInputLayerNormCoeffsIndex = 20;
-      const int kForgetLayerNormCoeffsIndex = 21;
-      const int kCellLayerNormCoeffsIndex = 22;
-      const int kOutputLayerNormCoeffsIndex = 23;
+    if (!model_has_legacy_20_inputs) {
       if (use_cifg) {
         input_layer_norm_coefficients_ = AddNullInput();
       } else {
         input_layer_norm_coefficients_ =
-            AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes);
+            is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
       }
       forget_layer_norm_coefficients_ =
-          AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes);
+          is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
       cell_layer_norm_coefficients_ =
-          AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes);
+          is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
       output_layer_norm_coefficients_ =
-          AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes);
+          is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
     }
 
     output_ = AddOutput(TensorType_FLOAT32);
@@ -277,15 +273,6 @@ class LSTMOpModel : public SingleOpModel {
   int n_output_;
 
  private:
-  int AddLayerNormCoeffsTensor(
-      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
-    if (input_shapes[tensor_index][0] != 0) {
-      return AddInput(TensorType_FLOAT32);
-    } else {
-      return AddNullInput();
-    }
-  }
-
   template <typename T>
   void PopulateTensor(int index, const std::vector<T>& data) {
     // Nothing to do if tensor is an optional input or if data vector is empty.
@@ -504,16 +491,17 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
           {0, 0},  // projection_weight tensor
           {0},     // projection_bias tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_FLOAT32,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
 }
 
-class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest
+class NoCifgNoPeepholeNoProjectionNoClippingNoLayerNormLstmTest
     : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
 
-TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingNoLayerNormLstmTest,
        LstmBlackBoxTest) {
   const int n_batch = 1;
   const int n_input = 2;
@@ -559,7 +547,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
           {0},  // cell_layer_norm_coefficient tensor
           {0},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
+      /*weight_type=*/TensorType_FLOAT32,
+      /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@@ -607,7 +597,8 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
           {0, 0},  // projection_weight tensor
           {0},     // projection_bias tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_UINT8,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
@@ -658,7 +649,8 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
           {0, 0},  // projection_weight tensor
           {0},     // projection_bias tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_INT8,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
@@ -749,7 +741,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
           {0, 0},  // projection_weight tensor
           {0},     // projection_bias tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_FLOAT32,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@@ -797,7 +790,8 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
           {0, 0},  // projection_weight tensor
           {0},     // projection_bias tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_UINT8,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
@@ -846,7 +840,8 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
           {0, 0},  // projection_weight tensor
           {0},     // projection_bias tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_INT8,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
@@ -1487,7 +1482,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
           {n_output, n_cell},  // projection_weight tensor
           {0},                 // projection_bias tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_FLOAT32,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/false);
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@@ -1534,7 +1530,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
           {n_output, n_cell},  // projection_weight tensor
           {0},                 // projection_bias tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_UINT8,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
@@ -1583,7 +1580,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
           {n_output, n_cell},  // projection_weight tensor
           {0},                 // projection_bias tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
+      /*weight_type=*/TensorType_INT8,
+      /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
       /*asymmetric_quantize_inputs=*/GetParam());
 
   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015);
@@ -1703,8 +1701,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/false);
+      /*weight_type=*/TensorType_FLOAT32, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false);
 
   // Verify the final output.
   lstm_golden_output_ = {{
@@ -1774,8 +1772,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_UINT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   lstm_golden_output_ = {{
       // Batch0: 3 (input_sequence_size) * 3 (n_output)
@@ -1847,8 +1845,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_INT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   // Goldens are calculated from weight_type=TensorType_FLOAT32.
   lstm_golden_output_ = {{
@@ -1961,8 +1959,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/false);
+      /*weight_type=*/TensorType_FLOAT32, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false);
 
   // Verify the final output.
   lstm_golden_output_ = {
@@ -2032,8 +2030,8 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_UINT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   // Verify the final output.
   lstm_golden_output_ = {
@@ -2104,8 +2102,8 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
           {n_cell},  // cell_layer_norm_coefficient tensor
           {n_cell},  // output_layer_norm_coefficient tensor
       },
-      /*weight_type=*/TensorType_INT8, /*is_layer_norm=*/true,
-      /*asymmetric_quantize_inputs=*/GetParam());
+      /*weight_type=*/TensorType_INT8, /*model_has_legacy_20_inputs=*/false,
+      /*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
 
   // Goldens are results using FLOAT32 inference.
   lstm_golden_output_ = {{
@@ -3278,41 +3276,6 @@ TEST(LSTMOpModel, InvalidTypeTest) {
   const int n_cell = 4;
   const int n_output = 4;
 
-  EXPECT_DEATH(LSTMOpModel lstm(
-                   n_batch, n_input, n_cell, n_output,
-                   /*use_cifg=*/false, /*use_peephole=*/false,
-                   /*use_projection_weights=*/false,
-                   /*use_projection_bias=*/false,
-                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
-                   {
-                       {n_batch, n_input},  // input tensor
-
-                       {n_cell, n_input},  // input_to_input_weight tensor
-                       {n_cell, n_input},  // input_to_forget_weight tensor
-                       {n_cell, n_input},  // input_to_cell_weight tensor
-                       {n_cell, n_input},  // input_to_output_weight tensor
-
-                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
-                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
-                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
-                       {n_cell, n_output},  // recurrent_to_output_weight_tensor
-
-                       {0},  // cell_to_input_weight tensor
-                       {0},  // cell_to_forget_weight tensor
-                       {0},  // cell_to_output_weight tensor
-
-                       {n_cell},  // input_gate_bias tensor
-                       {n_cell},  // forget_gate_bias tensor
-                       {n_cell},  // cell_gate_bias tensor
-                       {n_cell},  // output_gate_bias tensor
-
-                       {0, 0},  // projection_weight tensor
-                       {0},     // projection_bias tensor
-                   },
-                   /*weight_type=*/TensorType_INT32, /*is_layer_norm=*/false,
-                   /*asymmetric_quantize_inputs=*/false),
-               "");
-
   EXPECT_DEATH(
       LSTMOpModel lstm(
           n_batch, n_input, n_cell, n_output,
@@ -3345,9 +3308,45 @@ TEST(LSTMOpModel, InvalidTypeTest) {
               {0, 0},  // projection_weight tensor
               {0},     // projection_bias tensor
           },
-          /*weight_type=*/TensorType_COMPLEX64, /*is_layer_norm=*/false,
-          /*asymmetric_quantize_inputs=*/false),
+          /*weight_type=*/TensorType_INT32, /*model_has_legacy_20_inputs=*/true,
+          /*is_layer_norm=*/false, /*asymmetric_quantize_inputs=*/false),
       "");
 
+  EXPECT_DEATH(LSTMOpModel lstm(
+                   n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/false, /*use_peephole=*/false,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
+
+                       {n_cell, n_input},  // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
+
+                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight_tensor
+
+                       {0},  // cell_to_input_weight tensor
+                       {0},  // cell_to_forget_weight tensor
+                       {0},  // cell_to_output_weight tensor
+
+                       {n_cell},  // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_gate_bias tensor
+                       {n_cell},  // output_gate_bias tensor
+
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   },
+                   /*weight_type=*/TensorType_COMPLEX64,
+                   /*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
+                   /*asymmetric_quantize_inputs=*/false),
+               "");
+
 }
 #endif