Separate "is layer norm" and "has layer norm tensors" parameters of LSTMOpModel.

PiperOrigin-RevId: 321890450
Change-Id: Ie5a07786688bd1e3e2914362e78735fee29df093
Author: Robert David, 2020-07-17 18:06:19 -07:00; committed by TensorFlower Gardener.
Parent commit: 0721b70578
This commit: b13a153d07

View File

@ -40,8 +40,8 @@ class LSTMOpModel : public SingleOpModel {
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes,
const TensorType weight_type, bool is_layer_norm,
bool asymmetric_quantize_inputs)
const TensorType weight_type, bool model_has_legacy_20_inputs,
bool is_layer_norm, bool asymmetric_quantize_inputs)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@ -111,23 +111,19 @@ class LSTMOpModel : public SingleOpModel {
AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
// Layer norm weights.
if (is_layer_norm) {
const int kInputLayerNormCoeffsIndex = 20;
const int kForgetLayerNormCoeffsIndex = 21;
const int kCellLayerNormCoeffsIndex = 22;
const int kOutputLayerNormCoeffsIndex = 23;
if (!model_has_legacy_20_inputs) {
if (use_cifg) {
input_layer_norm_coefficients_ = AddNullInput();
} else {
input_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes);
is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
}
forget_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes);
is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
cell_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes);
is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
output_layer_norm_coefficients_ =
AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes);
is_layer_norm ? AddInput(TensorType_FLOAT32) : AddNullInput();
}
output_ = AddOutput(TensorType_FLOAT32);
@ -277,15 +273,6 @@ class LSTMOpModel : public SingleOpModel {
int n_output_;
private:
int AddLayerNormCoeffsTensor(
int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
if (input_shapes[tensor_index][0] != 0) {
return AddInput(TensorType_FLOAT32);
} else {
return AddNullInput();
}
}
template <typename T>
void PopulateTensor(int index, const std::vector<T>& data) {
// Nothing to do if tensor is an optional input or if data vector is empty.
@ -504,16 +491,17 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_FLOAT32,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest
class NoCifgNoPeepholeNoProjectionNoClippingNoLayerNormLstmTest
: public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingNoLayerNormLstmTest,
LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
@ -559,7 +547,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
{0}, // cell_layer_norm_coefficient tensor
{0}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
/*weight_type=*/TensorType_FLOAT32,
/*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@ -607,7 +597,8 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_UINT8,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
@ -658,7 +649,8 @@ TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_INT8,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
@ -749,7 +741,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_FLOAT32,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@ -797,7 +790,8 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_UINT8,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
@ -846,7 +840,8 @@ TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_INT8,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
@ -1487,7 +1482,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_FLOAT32,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false);
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
@ -1534,7 +1530,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_UINT8,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
@ -1583,7 +1580,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test,
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_INT8, /*is_layer_norm=*/false,
/*weight_type=*/TensorType_INT8,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/GetParam());
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015);
@ -1703,8 +1701,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
/*asymmetric_quantize_inputs=*/false);
/*weight_type=*/TensorType_FLOAT32, /*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false);
// Verify the final output.
lstm_golden_output_ = {{
@ -1774,8 +1772,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/true,
/*asymmetric_quantize_inputs=*/GetParam());
/*weight_type=*/TensorType_UINT8, /*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
lstm_golden_output_ = {{
// Batch0: 3 (input_sequence_size) * 3 (n_output)
@ -1847,8 +1845,8 @@ TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_INT8, /*is_layer_norm=*/true,
/*asymmetric_quantize_inputs=*/GetParam());
/*weight_type=*/TensorType_INT8, /*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
// Goldens are calculated from weight_type=TensorType_FLOAT32.
lstm_golden_output_ = {{
@ -1961,8 +1959,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_FLOAT32, /*is_layer_norm=*/true,
/*asymmetric_quantize_inputs=*/false);
/*weight_type=*/TensorType_FLOAT32, /*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/false);
// Verify the final output.
lstm_golden_output_ = {
@ -2032,8 +2030,8 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_UINT8, /*is_layer_norm=*/true,
/*asymmetric_quantize_inputs=*/GetParam());
/*weight_type=*/TensorType_UINT8, /*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
// Verify the final output.
lstm_golden_output_ = {
@ -2104,8 +2102,8 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test,
{n_cell}, // cell_layer_norm_coefficient tensor
{n_cell}, // output_layer_norm_coefficient tensor
},
/*weight_type=*/TensorType_INT8, /*is_layer_norm=*/true,
/*asymmetric_quantize_inputs=*/GetParam());
/*weight_type=*/TensorType_INT8, /*model_has_legacy_20_inputs=*/false,
/*is_layer_norm=*/true, /*asymmetric_quantize_inputs=*/GetParam());
// Goldens are results using FLOAT32 inference.
lstm_golden_output_ = {{
@ -3278,41 +3276,6 @@ TEST(LSTMOpModel, InvalidTypeTest) {
const int n_cell = 4;
const int n_output = 4;
EXPECT_DEATH(LSTMOpModel lstm(
n_batch, n_input, n_cell, n_output,
/*use_cifg=*/false, /*use_peephole=*/false,
/*use_projection_weights=*/false,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{n_batch, n_input}, // input tensor
{n_cell, n_input}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{n_cell, n_output}, // recurrent_to_input_weight_tensor
{n_cell, n_output}, // recurrent_to_forget_weight_tensor
{n_cell, n_output}, // recurrent_to_cell_weight_tensor
{n_cell, n_output}, // recurrent_to_output_weight_tensor
{0}, // cell_to_input_weight tensor
{0}, // cell_to_forget_weight tensor
{0}, // cell_to_output_weight tensor
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_INT32, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false),
"");
EXPECT_DEATH(
LSTMOpModel lstm(
n_batch, n_input, n_cell, n_output,
@ -3345,9 +3308,45 @@ TEST(LSTMOpModel, InvalidTypeTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_COMPLEX64, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false),
/*weight_type=*/TensorType_INT32, /*model_has_legacy_20_inputs=*/true,
/*is_layer_norm=*/false, /*asymmetric_quantize_inputs=*/false),
"");
EXPECT_DEATH(LSTMOpModel lstm(
n_batch, n_input, n_cell, n_output,
/*use_cifg=*/false, /*use_peephole=*/false,
/*use_projection_weights=*/false,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{n_batch, n_input}, // input tensor
{n_cell, n_input}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{n_cell, n_output}, // recurrent_to_input_weight_tensor
{n_cell, n_output}, // recurrent_to_forget_weight_tensor
{n_cell, n_output}, // recurrent_to_cell_weight_tensor
{n_cell, n_output}, // recurrent_to_output_weight_tensor
{0}, // cell_to_input_weight tensor
{0}, // cell_to_forget_weight tensor
{0}, // cell_to_output_weight tensor
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_gate_bias tensor
{n_cell}, // output_gate_bias tensor
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
},
/*weight_type=*/TensorType_COMPLEX64,
/*model_has_legacy_20_inputs=*/true, /*is_layer_norm=*/false,
/*asymmetric_quantize_inputs=*/false),
"");
}
#endif