Changes to use delegate after tensor scales are known.
PiperOrigin-RevId: 261393134
This commit is contained in:
parent
aa19d7f8e4
commit
fdcbcba902
@ -135,7 +135,10 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
cell_clip, proj_clip)
|
cell_clip, proj_clip)
|
||||||
.Union());
|
.Union());
|
||||||
|
|
||||||
BuildInterpreter(input_shapes);
|
// Do not apply delegate yet since tensor values are not known (and more
|
||||||
|
// specifically scales in quantized tensors are not known).
|
||||||
|
BuildInterpreter(input_shapes, /*allow_fp32_relax_to_fp16=*/false,
|
||||||
|
/*apply_delegate=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInputToInputWeights(const std::vector<float>& f) {
|
void SetInputToInputWeights(const std::vector<float>& f) {
|
||||||
@ -183,22 +186,18 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetInputLayerNormCoefficients(const std::vector<float>& f) {
|
void SetInputLayerNormCoefficients(const std::vector<float>& f) {
|
||||||
ASSERT_TRUE(is_layer_norm_);
|
|
||||||
PopulateTensor(input_layer_norm_coefficients_, f);
|
PopulateTensor(input_layer_norm_coefficients_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
|
void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
|
||||||
ASSERT_TRUE(is_layer_norm_);
|
|
||||||
PopulateTensor(forget_layer_norm_coefficients_, f);
|
PopulateTensor(forget_layer_norm_coefficients_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetCellLayerNormCoefficients(const std::vector<float>& f) {
|
void SetCellLayerNormCoefficients(const std::vector<float>& f) {
|
||||||
ASSERT_TRUE(is_layer_norm_);
|
|
||||||
PopulateTensor(cell_layer_norm_coefficients_, f);
|
PopulateTensor(cell_layer_norm_coefficients_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
|
void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
|
||||||
ASSERT_TRUE(is_layer_norm_);
|
|
||||||
PopulateTensor(output_layer_norm_coefficients_, f);
|
PopulateTensor(output_layer_norm_coefficients_, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,7 +226,7 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SetInput(int offset, const float* begin, const float* end) {
|
void SetInput(int offset, const float* begin, const float* end) {
|
||||||
PopulateTensor(input_, offset, const_cast<float*>(begin),
|
SingleOpModel::PopulateTensor(input_, offset, const_cast<float*>(begin),
|
||||||
const_cast<float*>(end));
|
const_cast<float*>(end));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,7 +287,16 @@ class LSTMOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void PopulateTensor(int index, const std::vector<T>& data) {
|
||||||
|
// Nothing to do if tensor is an optional input or if data vector is empty.
|
||||||
|
if ((index == kOptionalTensor) || data.empty()) return;
|
||||||
|
SingleOpModel::PopulateTensor(index, data);
|
||||||
|
}
|
||||||
|
|
||||||
void SetWeights(int index, const std::vector<float>& data) {
|
void SetWeights(int index, const std::vector<float>& data) {
|
||||||
|
if (data.empty()) return;
|
||||||
|
if (index == kOptionalTensor) return;
|
||||||
switch (weight_type_) {
|
switch (weight_type_) {
|
||||||
case TensorType_FLOAT32:
|
case TensorType_FLOAT32:
|
||||||
PopulateTensor(index, data);
|
PopulateTensor(index, data);
|
||||||
@ -328,6 +336,10 @@ class BaseLstmTest : public ::testing::Test {
|
|||||||
std::vector<float> cell_to_forget_weights_;
|
std::vector<float> cell_to_forget_weights_;
|
||||||
std::vector<float> cell_to_output_weights_;
|
std::vector<float> cell_to_output_weights_;
|
||||||
std::vector<float> projection_weights_;
|
std::vector<float> projection_weights_;
|
||||||
|
std::vector<float> input_layer_norm_coefficients_;
|
||||||
|
std::vector<float> forget_layer_norm_coefficients_;
|
||||||
|
std::vector<float> cell_layer_norm_coefficients_;
|
||||||
|
std::vector<float> output_layer_norm_coefficients_;
|
||||||
|
|
||||||
// LSTM input is stored as num_batch x num_inputs vector.
|
// LSTM input is stored as num_batch x num_inputs vector.
|
||||||
std::vector<std::vector<float>> lstm_input_;
|
std::vector<std::vector<float>> lstm_input_;
|
||||||
@ -338,6 +350,16 @@ class BaseLstmTest : public ::testing::Test {
|
|||||||
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
||||||
const std::vector<std::vector<float>>& output,
|
const std::vector<std::vector<float>>& output,
|
||||||
LSTMOpModel* lstm, float tolerance = 1e-5) {
|
LSTMOpModel* lstm, float tolerance = 1e-5) {
|
||||||
|
// Weights are set twice:
|
||||||
|
// - The delegate, if used, needs to know the scales and zero-points of
|
||||||
|
// quantized tensors, which are computed dynamically when weights are set,
|
||||||
|
// so weights have to be set before applying the delegate.
|
||||||
|
// - Applying a delegate will invalidate the tensor data so weights have to
|
||||||
|
// be set a second time.
|
||||||
|
SetAllWeightsAndBiases(lstm);
|
||||||
|
lstm->ApplyDelegate();
|
||||||
|
SetAllWeightsAndBiases(lstm);
|
||||||
|
|
||||||
const int num_batches = input.size();
|
const int num_batches = input.size();
|
||||||
EXPECT_GT(num_batches, 0);
|
EXPECT_GT(num_batches, 0);
|
||||||
const int num_inputs = lstm->num_inputs();
|
const int num_inputs = lstm->num_inputs();
|
||||||
@ -365,6 +387,37 @@ class BaseLstmTest : public ::testing::Test {
|
|||||||
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets all weights and biases that have been defined by test. The test can
|
||||||
|
// define only a subset of all those vectors, and only the ones that have been
|
||||||
|
// defined will be set.
|
||||||
|
void SetAllWeightsAndBiases(LSTMOpModel* lstm) {
|
||||||
|
lstm->SetInputToInputWeights(input_to_input_weights_);
|
||||||
|
lstm->SetInputToCellWeights(input_to_cell_weights_);
|
||||||
|
lstm->SetInputToForgetWeights(input_to_forget_weights_);
|
||||||
|
lstm->SetInputToOutputWeights(input_to_output_weights_);
|
||||||
|
|
||||||
|
lstm->SetInputGateBias(input_gate_bias_);
|
||||||
|
lstm->SetCellBias(cell_gate_bias_);
|
||||||
|
lstm->SetForgetGateBias(forget_gate_bias_);
|
||||||
|
lstm->SetOutputGateBias(output_gate_bias_);
|
||||||
|
|
||||||
|
lstm->SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||||
|
lstm->SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||||
|
lstm->SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||||
|
lstm->SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||||
|
|
||||||
|
lstm->SetCellToInputWeights(cell_to_input_weights_);
|
||||||
|
lstm->SetCellToForgetWeights(cell_to_forget_weights_);
|
||||||
|
lstm->SetCellToOutputWeights(cell_to_output_weights_);
|
||||||
|
|
||||||
|
lstm->SetProjectionWeights(projection_weights_);
|
||||||
|
|
||||||
|
lstm->SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||||
|
lstm->SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
|
||||||
|
lstm->SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||||
|
lstm->SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
||||||
@ -456,21 +509,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
/*weight_type=*/TensorType_FLOAT32,
|
/*weight_type=*/TensorType_FLOAT32,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -526,21 +564,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_FLOAT32,
|
/*weight_type=*/TensorType_FLOAT32,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -585,21 +608,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
}
|
}
|
||||||
@ -645,21 +653,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
||||||
/*tolerance=*/0.0157651);
|
/*tolerance=*/0.0157651);
|
||||||
}
|
}
|
||||||
@ -751,21 +744,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
/*weight_type=*/TensorType_FLOAT32,
|
/*weight_type=*/TensorType_FLOAT32,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -810,21 +788,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -869,21 +832,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
|||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1525,27 +1473,6 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
|||||||
/*weight_type=*/TensorType_FLOAT32,
|
/*weight_type=*/TensorType_FLOAT32,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellToInputWeights(cell_to_input_weights_);
|
|
||||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1588,27 +1515,6 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
|||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellToInputWeights(cell_to_input_weights_);
|
|
||||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1652,94 +1558,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
|
|||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/false);
|
/*is_layer_norm=*/false);
|
||||||
|
|
||||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetCellToInputWeights(cell_to_input_weights_);
|
|
||||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseLayerNormLstmTest : public ::testing::Test {
|
|
||||||
protected:
|
|
||||||
// Weights of the Layer Norm LSTM model. Some are optional.
|
|
||||||
std::vector<float> input_to_input_weights_;
|
|
||||||
std::vector<float> input_to_cell_weights_;
|
|
||||||
std::vector<float> input_to_forget_weights_;
|
|
||||||
std::vector<float> input_to_output_weights_;
|
|
||||||
std::vector<float> input_gate_bias_;
|
|
||||||
std::vector<float> cell_gate_bias_;
|
|
||||||
std::vector<float> forget_gate_bias_;
|
|
||||||
std::vector<float> output_gate_bias_;
|
|
||||||
std::vector<float> recurrent_to_input_weights_;
|
|
||||||
std::vector<float> recurrent_to_cell_weights_;
|
|
||||||
std::vector<float> recurrent_to_forget_weights_;
|
|
||||||
std::vector<float> recurrent_to_output_weights_;
|
|
||||||
std::vector<float> cell_to_input_weights_;
|
|
||||||
std::vector<float> cell_to_forget_weights_;
|
|
||||||
std::vector<float> cell_to_output_weights_;
|
|
||||||
std::vector<float> projection_weights_;
|
|
||||||
std::vector<float> input_layer_norm_coefficients_;
|
|
||||||
std::vector<float> forget_layer_norm_coefficients_;
|
|
||||||
std::vector<float> cell_layer_norm_coefficients_;
|
|
||||||
std::vector<float> output_layer_norm_coefficients_;
|
|
||||||
|
|
||||||
// Layer Norm LSTM input is stored as num_batch x num_inputs vector.
|
|
||||||
std::vector<std::vector<float>> layer_norm_lstm_input_;
|
|
||||||
|
|
||||||
// Compares output up to tolerance to the result of the layer_norm_lstm given
|
|
||||||
// the input.
|
|
||||||
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
|
||||||
const std::vector<std::vector<float>>& output,
|
|
||||||
LSTMOpModel* layer_norm_lstm, float tolerance = 1e-5) {
|
|
||||||
const int num_batches = input.size();
|
|
||||||
EXPECT_GT(num_batches, 0);
|
|
||||||
const int num_inputs = layer_norm_lstm->num_inputs();
|
|
||||||
EXPECT_GT(num_inputs, 0);
|
|
||||||
const int input_sequence_size = input[0].size() / num_inputs;
|
|
||||||
EXPECT_GT(input_sequence_size, 0);
|
|
||||||
for (int i = 0; i < input_sequence_size; ++i) {
|
|
||||||
for (int b = 0; b < num_batches; ++b) {
|
|
||||||
const float* batch_start = input[b].data() + i * num_inputs;
|
|
||||||
const float* batch_end = batch_start + num_inputs;
|
|
||||||
|
|
||||||
layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
|
|
||||||
batch_start, batch_end);
|
|
||||||
}
|
|
||||||
|
|
||||||
layer_norm_lstm->Invoke();
|
|
||||||
|
|
||||||
const int num_outputs = layer_norm_lstm->num_outputs();
|
|
||||||
std::vector<float> expected;
|
|
||||||
for (int b = 0; b < num_batches; ++b) {
|
|
||||||
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
|
||||||
const float* golden_end_batch = golden_start_batch + num_outputs;
|
|
||||||
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
|
||||||
}
|
|
||||||
EXPECT_THAT(layer_norm_lstm->GetOutput(),
|
|
||||||
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||||
: public BaseLayerNormLstmTest {
|
: public BaseLstmTest {
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
|
input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
|
||||||
0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
|
0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
|
||||||
@ -1791,7 +1614,7 @@ class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
|||||||
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
||||||
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
||||||
|
|
||||||
layer_norm_lstm_input_ = {
|
lstm_input_ = {
|
||||||
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
||||||
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
||||||
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
||||||
@ -1855,37 +1678,8 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_FLOAT32,
|
/*weight_type=*/TensorType_FLOAT32,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
|
||||||
forget_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
|
||||||
output_layer_norm_coefficients_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
// Verify the final output.
|
// Verify the final output.
|
||||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
lstm_golden_output_ = {{
|
||||||
{
|
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.0244077, 0.128027, -0.00170918, // seq 0
|
0.0244077, 0.128027, -0.00170918, // seq 0
|
||||||
0.0137642, 0.140751, 0.0395835, // seq 1
|
0.0137642, 0.140751, 0.0395835, // seq 1
|
||||||
@ -1898,8 +1692,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
0.00752706, 0.161903, 0.0561371, // seq 2
|
0.00752706, 0.161903, 0.0561371, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
&layer_norm_lstm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||||
@ -1952,36 +1745,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm_golden_output_ = {{
|
||||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
|
||||||
forget_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
|
||||||
output_layer_norm_coefficients_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
|
||||||
{
|
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||||
@ -1994,8 +1758,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
&layer_norm_lstm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||||
@ -2048,36 +1811,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
lstm_golden_output_ = {{
|
||||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
|
||||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
|
||||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
|
||||||
forget_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
|
||||||
output_layer_norm_coefficients_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
|
||||||
{
|
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||||
@ -2090,12 +1824,10 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
&layer_norm_lstm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class CifgPeepholeProjectionNoClippingLayerNormLstmTest
|
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
|
||||||
: public BaseLayerNormLstmTest {
|
|
||||||
void SetUp() override {
|
void SetUp() override {
|
||||||
input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
|
input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
|
||||||
-0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
|
-0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
|
||||||
@ -2127,7 +1859,7 @@ class CifgPeepholeProjectionNoClippingLayerNormLstmTest
|
|||||||
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
||||||
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
||||||
|
|
||||||
layer_norm_lstm_input_ = {
|
lstm_input_ = {
|
||||||
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
||||||
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
||||||
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
||||||
@ -2191,31 +1923,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_FLOAT32,
|
/*weight_type=*/TensorType_FLOAT32,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
|
||||||
forget_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
|
||||||
output_layer_norm_coefficients_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
// Verify the final output.
|
// Verify the final output.
|
||||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
lstm_golden_output_ = {
|
||||||
{
|
{
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.02129706, 0.140816242, 0.0112733059, // seq 0
|
0.02129706, 0.140816242, 0.0112733059, // seq 0
|
||||||
@ -2229,8 +1938,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
-0.0103429332, 0.173016444, 0.0720508844, // seq 2
|
-0.0103429332, 0.173016444, 0.0720508844, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
&layer_norm_lstm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||||
@ -2283,31 +1991,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_UINT8,
|
/*weight_type=*/TensorType_UINT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
|
||||||
forget_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
|
||||||
output_layer_norm_coefficients_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
// Verify the final output.
|
// Verify the final output.
|
||||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
lstm_golden_output_ = {
|
||||||
{
|
{
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
||||||
@ -2321,8 +2006,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
&layer_norm_lstm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||||
@ -2375,31 +2059,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
/*weight_type=*/TensorType_INT8,
|
/*weight_type=*/TensorType_INT8,
|
||||||
/*is_layer_norm=*/true);
|
/*is_layer_norm=*/true);
|
||||||
|
|
||||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
|
||||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
|
||||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
|
||||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
|
||||||
forget_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
|
||||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
|
||||||
output_layer_norm_coefficients_);
|
|
||||||
|
|
||||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
|
||||||
|
|
||||||
// Verify the final output.
|
// Verify the final output.
|
||||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
lstm_golden_output_ = {
|
||||||
{
|
{
|
||||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||||
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
||||||
@ -2413,8 +2074,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
|||||||
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
||||||
}};
|
}};
|
||||||
|
|
||||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||||
&layer_norm_lstm);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GTEST_HAS_DEATH_TEST
|
#ifdef GTEST_HAS_DEATH_TEST
|
||||||
|
@ -116,7 +116,8 @@ void SingleOpModel::SetCustomOp(
|
|||||||
|
|
||||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
int num_threads,
|
int num_threads,
|
||||||
bool allow_fp32_relax_to_fp16) {
|
bool allow_fp32_relax_to_fp16,
|
||||||
|
bool apply_delegate) {
|
||||||
auto opcodes = builder_.CreateVector(opcodes_);
|
auto opcodes = builder_.CreateVector(opcodes_);
|
||||||
auto operators = builder_.CreateVector(operators_);
|
auto operators = builder_.CreateVector(operators_);
|
||||||
auto tensors = builder_.CreateVector(tensors_);
|
auto tensors = builder_.CreateVector(tensors_);
|
||||||
@ -161,6 +162,13 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
|||||||
<< "Cannot allocate tensors";
|
<< "Cannot allocate tensors";
|
||||||
interpreter_->ResetVariableTensors();
|
interpreter_->ResetVariableTensors();
|
||||||
|
|
||||||
|
// In some rare cases a test may need to postpone modifying the graph with
|
||||||
|
// a delegate, e.g. if tensors are not fully specified. In such cases the
|
||||||
|
// test has to explicitly call ApplyDelegate() when necessary.
|
||||||
|
if (apply_delegate) ApplyDelegate();
|
||||||
|
}
|
||||||
|
|
||||||
|
void SingleOpModel::ApplyDelegate() {
|
||||||
if (force_use_nnapi) {
|
if (force_use_nnapi) {
|
||||||
// TODO(b/124505407): Check the result and fail accordingly.
|
// TODO(b/124505407): Check the result and fail accordingly.
|
||||||
interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate());
|
interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate());
|
||||||
@ -179,18 +187,22 @@ TfLiteStatus SingleOpModel::InvokeUnchecked() { return interpreter_->Invoke(); }
|
|||||||
void SingleOpModel::BuildInterpreter(
|
void SingleOpModel::BuildInterpreter(
|
||||||
std::vector<std::vector<int>> input_shapes) {
|
std::vector<std::vector<int>> input_shapes) {
|
||||||
BuildInterpreter(input_shapes, /*num_threads=*/-1,
|
BuildInterpreter(input_shapes, /*num_threads=*/-1,
|
||||||
/*allow_fp32_relax_to_fp16=*/false);
|
/*allow_fp32_relax_to_fp16=*/false,
|
||||||
|
/*apply_delegate=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
|
bool allow_fp32_relax_to_fp16,
|
||||||
|
bool apply_delegate) {
|
||||||
|
BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16,
|
||||||
|
apply_delegate);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
int num_threads) {
|
int num_threads) {
|
||||||
BuildInterpreter(input_shapes, num_threads,
|
BuildInterpreter(input_shapes, num_threads,
|
||||||
/*allow_fp32_relax_to_fp16=*/false);
|
/*allow_fp32_relax_to_fp16=*/false,
|
||||||
}
|
/*apply_delegate=*/true);
|
||||||
|
|
||||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
|
||||||
bool allow_fp32_relax_to_fp16) {
|
|
||||||
BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
|
@ -151,6 +151,8 @@ class SingleOpModel {
|
|||||||
apply_delegate_fn_ = apply_delegate_fn;
|
apply_delegate_fn_ = apply_delegate_fn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ApplyDelegate();
|
||||||
|
|
||||||
// Copying or assignment is disallowed to simplify ownership semantics.
|
// Copying or assignment is disallowed to simplify ownership semantics.
|
||||||
SingleOpModel(const SingleOpModel&) = delete;
|
SingleOpModel(const SingleOpModel&) = delete;
|
||||||
SingleOpModel& operator=(const SingleOpModel&) = delete;
|
SingleOpModel& operator=(const SingleOpModel&) = delete;
|
||||||
@ -255,13 +257,14 @@ class SingleOpModel {
|
|||||||
// Build the interpreter for this model. Also, resize and allocate all
|
// Build the interpreter for this model. Also, resize and allocate all
|
||||||
// tensors given the shapes of the inputs.
|
// tensors given the shapes of the inputs.
|
||||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
int num_threads, bool allow_fp32_relax_to_fp16);
|
int num_threads, bool allow_fp32_relax_to_fp16,
|
||||||
|
bool apply_delegate = true);
|
||||||
|
|
||||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
int num_threads);
|
int num_threads);
|
||||||
|
|
||||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
bool allow_fp32_relax_to_fp16);
|
bool allow_fp32_relax_to_fp16, bool apply_delegate);
|
||||||
|
|
||||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user