Changes to use delegate after tensor scales are known.
PiperOrigin-RevId: 261393134
This commit is contained in:
parent
aa19d7f8e4
commit
fdcbcba902
@ -135,7 +135,10 @@ class LSTMOpModel : public SingleOpModel {
|
||||
cell_clip, proj_clip)
|
||||
.Union());
|
||||
|
||||
BuildInterpreter(input_shapes);
|
||||
// Do not apply delegate yet since tensor values are not known (and more
|
||||
// specifically scales in quantized tensors are not known).
|
||||
BuildInterpreter(input_shapes, /*allow_fp32_relax_to_fp16=*/false,
|
||||
/*apply_delegate=*/false);
|
||||
}
|
||||
|
||||
void SetInputToInputWeights(const std::vector<float>& f) {
|
||||
@ -183,22 +186,18 @@ class LSTMOpModel : public SingleOpModel {
|
||||
}
|
||||
|
||||
void SetInputLayerNormCoefficients(const std::vector<float>& f) {
|
||||
ASSERT_TRUE(is_layer_norm_);
|
||||
PopulateTensor(input_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
|
||||
ASSERT_TRUE(is_layer_norm_);
|
||||
PopulateTensor(forget_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetCellLayerNormCoefficients(const std::vector<float>& f) {
|
||||
ASSERT_TRUE(is_layer_norm_);
|
||||
PopulateTensor(cell_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
|
||||
ASSERT_TRUE(is_layer_norm_);
|
||||
PopulateTensor(output_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
@ -227,8 +226,8 @@ class LSTMOpModel : public SingleOpModel {
|
||||
}
|
||||
|
||||
void SetInput(int offset, const float* begin, const float* end) {
|
||||
PopulateTensor(input_, offset, const_cast<float*>(begin),
|
||||
const_cast<float*>(end));
|
||||
SingleOpModel::PopulateTensor(input_, offset, const_cast<float*>(begin),
|
||||
const_cast<float*>(end));
|
||||
}
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
@ -288,7 +287,16 @@ class LSTMOpModel : public SingleOpModel {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PopulateTensor(int index, const std::vector<T>& data) {
|
||||
// Nothing to do if tensor is an optional input or if data vector is empty.
|
||||
if ((index == kOptionalTensor) || data.empty()) return;
|
||||
SingleOpModel::PopulateTensor(index, data);
|
||||
}
|
||||
|
||||
void SetWeights(int index, const std::vector<float>& data) {
|
||||
if (data.empty()) return;
|
||||
if (index == kOptionalTensor) return;
|
||||
switch (weight_type_) {
|
||||
case TensorType_FLOAT32:
|
||||
PopulateTensor(index, data);
|
||||
@ -328,6 +336,10 @@ class BaseLstmTest : public ::testing::Test {
|
||||
std::vector<float> cell_to_forget_weights_;
|
||||
std::vector<float> cell_to_output_weights_;
|
||||
std::vector<float> projection_weights_;
|
||||
std::vector<float> input_layer_norm_coefficients_;
|
||||
std::vector<float> forget_layer_norm_coefficients_;
|
||||
std::vector<float> cell_layer_norm_coefficients_;
|
||||
std::vector<float> output_layer_norm_coefficients_;
|
||||
|
||||
// LSTM input is stored as num_batch x num_inputs vector.
|
||||
std::vector<std::vector<float>> lstm_input_;
|
||||
@ -338,6 +350,16 @@ class BaseLstmTest : public ::testing::Test {
|
||||
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
||||
const std::vector<std::vector<float>>& output,
|
||||
LSTMOpModel* lstm, float tolerance = 1e-5) {
|
||||
// Weights are set twice:
|
||||
// - The delegate, if used, needs to know the scales and zero-points of
|
||||
// quantized tensors, which are computed dynamically when weights are set,
|
||||
// so weights have to be set before applying the delegate.
|
||||
// - Applying a delegate will invalidate the tensor data so weights have to
|
||||
// be set a second time.
|
||||
SetAllWeightsAndBiases(lstm);
|
||||
lstm->ApplyDelegate();
|
||||
SetAllWeightsAndBiases(lstm);
|
||||
|
||||
const int num_batches = input.size();
|
||||
EXPECT_GT(num_batches, 0);
|
||||
const int num_inputs = lstm->num_inputs();
|
||||
@ -365,6 +387,37 @@ class BaseLstmTest : public ::testing::Test {
|
||||
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
||||
}
|
||||
}
|
||||
|
||||
// Sets all weights and biases that have been defined by test. The test can
|
||||
// define only a subset of all those vectors, and only the ones that have been
|
||||
// defined will be set.
|
||||
void SetAllWeightsAndBiases(LSTMOpModel* lstm) {
|
||||
lstm->SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm->SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm->SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm->SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm->SetInputGateBias(input_gate_bias_);
|
||||
lstm->SetCellBias(cell_gate_bias_);
|
||||
lstm->SetForgetGateBias(forget_gate_bias_);
|
||||
lstm->SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm->SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm->SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm->SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm->SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm->SetCellToInputWeights(cell_to_input_weights_);
|
||||
lstm->SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm->SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
lstm->SetProjectionWeights(projection_weights_);
|
||||
|
||||
lstm->SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||
lstm->SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
|
||||
lstm->SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
lstm->SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);
|
||||
}
|
||||
};
|
||||
|
||||
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
|
||||
@ -456,21 +509,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
||||
/*weight_type=*/TensorType_FLOAT32,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||
}
|
||||
|
||||
@ -526,21 +564,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_FLOAT32,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||
}
|
||||
|
||||
@ -585,21 +608,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||
/*weight_type=*/TensorType_UINT8,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
||||
/*tolerance=*/0.0157651);
|
||||
}
|
||||
@ -645,21 +653,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||
/*weight_type=*/TensorType_INT8,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
|
||||
/*tolerance=*/0.0157651);
|
||||
}
|
||||
@ -751,21 +744,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
||||
/*weight_type=*/TensorType_FLOAT32,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||
}
|
||||
|
||||
@ -810,21 +788,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||
/*weight_type=*/TensorType_UINT8,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||
}
|
||||
|
||||
@ -869,21 +832,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,
|
||||
/*weight_type=*/TensorType_INT8,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
|
||||
}
|
||||
|
||||
@ -1525,27 +1473,6 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
|
||||
/*weight_type=*/TensorType_FLOAT32,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
|
||||
}
|
||||
|
||||
@ -1588,27 +1515,6 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) {
|
||||
/*weight_type=*/TensorType_INT8,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||
}
|
||||
|
||||
@ -1652,94 +1558,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest,
|
||||
/*weight_type=*/TensorType_UINT8,
|
||||
/*is_layer_norm=*/false);
|
||||
|
||||
lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
lstm.SetInputGateBias(input_gate_bias_);
|
||||
lstm.SetCellBias(cell_gate_bias_);
|
||||
lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||
}
|
||||
|
||||
class BaseLayerNormLstmTest : public ::testing::Test {
|
||||
protected:
|
||||
// Weights of the Layer Norm LSTM model. Some are optional.
|
||||
std::vector<float> input_to_input_weights_;
|
||||
std::vector<float> input_to_cell_weights_;
|
||||
std::vector<float> input_to_forget_weights_;
|
||||
std::vector<float> input_to_output_weights_;
|
||||
std::vector<float> input_gate_bias_;
|
||||
std::vector<float> cell_gate_bias_;
|
||||
std::vector<float> forget_gate_bias_;
|
||||
std::vector<float> output_gate_bias_;
|
||||
std::vector<float> recurrent_to_input_weights_;
|
||||
std::vector<float> recurrent_to_cell_weights_;
|
||||
std::vector<float> recurrent_to_forget_weights_;
|
||||
std::vector<float> recurrent_to_output_weights_;
|
||||
std::vector<float> cell_to_input_weights_;
|
||||
std::vector<float> cell_to_forget_weights_;
|
||||
std::vector<float> cell_to_output_weights_;
|
||||
std::vector<float> projection_weights_;
|
||||
std::vector<float> input_layer_norm_coefficients_;
|
||||
std::vector<float> forget_layer_norm_coefficients_;
|
||||
std::vector<float> cell_layer_norm_coefficients_;
|
||||
std::vector<float> output_layer_norm_coefficients_;
|
||||
|
||||
// Layer Norm LSTM input is stored as num_batch x num_inputs vector.
|
||||
std::vector<std::vector<float>> layer_norm_lstm_input_;
|
||||
|
||||
// Compares output up to tolerance to the result of the layer_norm_lstm given
|
||||
// the input.
|
||||
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
||||
const std::vector<std::vector<float>>& output,
|
||||
LSTMOpModel* layer_norm_lstm, float tolerance = 1e-5) {
|
||||
const int num_batches = input.size();
|
||||
EXPECT_GT(num_batches, 0);
|
||||
const int num_inputs = layer_norm_lstm->num_inputs();
|
||||
EXPECT_GT(num_inputs, 0);
|
||||
const int input_sequence_size = input[0].size() / num_inputs;
|
||||
EXPECT_GT(input_sequence_size, 0);
|
||||
for (int i = 0; i < input_sequence_size; ++i) {
|
||||
for (int b = 0; b < num_batches; ++b) {
|
||||
const float* batch_start = input[b].data() + i * num_inputs;
|
||||
const float* batch_end = batch_start + num_inputs;
|
||||
|
||||
layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
|
||||
batch_start, batch_end);
|
||||
}
|
||||
|
||||
layer_norm_lstm->Invoke();
|
||||
|
||||
const int num_outputs = layer_norm_lstm->num_outputs();
|
||||
std::vector<float> expected;
|
||||
for (int b = 0; b < num_batches; ++b) {
|
||||
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
||||
const float* golden_end_batch = golden_start_batch + num_outputs;
|
||||
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
||||
}
|
||||
EXPECT_THAT(layer_norm_lstm->GetOutput(),
|
||||
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||
: public BaseLayerNormLstmTest {
|
||||
: public BaseLstmTest {
|
||||
void SetUp() override {
|
||||
input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
|
||||
0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
|
||||
@ -1791,7 +1614,7 @@ class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
||||
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
||||
|
||||
layer_norm_lstm_input_ = {
|
||||
lstm_input_ = {
|
||||
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
||||
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
||||
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
||||
@ -1855,51 +1678,21 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_FLOAT32,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
// Verify the final output.
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244077, 0.128027, -0.00170918, // seq 0
|
||||
0.0137642, 0.140751, 0.0395835, // seq 1
|
||||
-0.00459231, 0.155278, 0.0837377, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00692428, 0.0848741, 0.063445, // seq 0
|
||||
-0.00403912, 0.139963, 0.072681, // seq 1
|
||||
0.00752706, 0.161903, 0.0561371, // seq 2
|
||||
}};
|
||||
lstm_golden_output_ = {{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244077, 0.128027, -0.00170918, // seq 0
|
||||
0.0137642, 0.140751, 0.0395835, // seq 1
|
||||
-0.00459231, 0.155278, 0.0837377, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00692428, 0.0848741, 0.063445, // seq 0
|
||||
-0.00403912, 0.139963, 0.072681, // seq 1
|
||||
0.00752706, 0.161903, 0.0561371, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||
}
|
||||
|
||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
@ -1952,50 +1745,20 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_UINT8,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
lstm_golden_output_ = {{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||
-0.0048839, 0.155096, 0.0840309, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00728636, 0.0843957, 0.0634786, // seq 0
|
||||
-0.00448382, 0.139278, 0.0737372, // seq 1
|
||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||
}};
|
||||
|
||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||
-0.0048839, 0.155096, 0.0840309, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00728636, 0.0843957, 0.0634786, // seq 0
|
||||
-0.00448382, 0.139278, 0.0737372, // seq 1
|
||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||
}
|
||||
|
||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
@ -2048,54 +1811,23 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_INT8,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
lstm_golden_output_ = {{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||
-0.0048839, 0.155096, 0.0840309, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00728636, 0.0843957, 0.0634786, // seq 0
|
||||
-0.00448382, 0.139278, 0.0737372, // seq 1
|
||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||
}};
|
||||
|
||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||
-0.0048839, 0.155096, 0.0840309, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00728636, 0.0843957, 0.0634786, // seq 0
|
||||
-0.00448382, 0.139278, 0.0737372, // seq 1
|
||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||
}
|
||||
|
||||
class CifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||
: public BaseLayerNormLstmTest {
|
||||
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
|
||||
void SetUp() override {
|
||||
input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
|
||||
-0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
|
||||
@ -2127,7 +1859,7 @@ class CifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
||||
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
||||
|
||||
layer_norm_lstm_input_ = {
|
||||
lstm_input_ = {
|
||||
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
||||
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
||||
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
||||
@ -2191,31 +1923,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_FLOAT32,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
// Verify the final output.
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
lstm_golden_output_ = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.02129706, 0.140816242, 0.0112733059, // seq 0
|
||||
@ -2229,8 +1938,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
-0.0103429332, 0.173016444, 0.0720508844, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||
}
|
||||
|
||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
@ -2283,31 +1991,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_UINT8,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
// Verify the final output.
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
lstm_golden_output_ = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
||||
@ -2321,8 +2006,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||
}
|
||||
|
||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
@ -2375,31 +2059,8 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
/*weight_type=*/TensorType_INT8,
|
||||
/*is_layer_norm=*/true);
|
||||
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
// Verify the final output.
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
lstm_golden_output_ = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0212250091, 0.140474007, 0.0115012666, // seq 0
|
||||
@ -2413,8 +2074,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
-0.0103605557, 0.172605693, 0.0728750974, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm);
|
||||
}
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
|
@ -116,7 +116,8 @@ void SingleOpModel::SetCustomOp(
|
||||
|
||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
int num_threads,
|
||||
bool allow_fp32_relax_to_fp16) {
|
||||
bool allow_fp32_relax_to_fp16,
|
||||
bool apply_delegate) {
|
||||
auto opcodes = builder_.CreateVector(opcodes_);
|
||||
auto operators = builder_.CreateVector(operators_);
|
||||
auto tensors = builder_.CreateVector(tensors_);
|
||||
@ -161,6 +162,13 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
<< "Cannot allocate tensors";
|
||||
interpreter_->ResetVariableTensors();
|
||||
|
||||
// In some rare cases a test may need to postpone modifying the graph with
|
||||
// a delegate, e.g. if tensors are not fully specified. In such cases the
|
||||
// test has to explicitly call ApplyDelegate() when necessary.
|
||||
if (apply_delegate) ApplyDelegate();
|
||||
}
|
||||
|
||||
void SingleOpModel::ApplyDelegate() {
|
||||
if (force_use_nnapi) {
|
||||
// TODO(b/124505407): Check the result and fail accordingly.
|
||||
interpreter_->ModifyGraphWithDelegate(TestNnApiDelegate());
|
||||
@ -179,18 +187,22 @@ TfLiteStatus SingleOpModel::InvokeUnchecked() { return interpreter_->Invoke(); }
|
||||
void SingleOpModel::BuildInterpreter(
|
||||
std::vector<std::vector<int>> input_shapes) {
|
||||
BuildInterpreter(input_shapes, /*num_threads=*/-1,
|
||||
/*allow_fp32_relax_to_fp16=*/false);
|
||||
/*allow_fp32_relax_to_fp16=*/false,
|
||||
/*apply_delegate=*/true);
|
||||
}
|
||||
|
||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
bool allow_fp32_relax_to_fp16,
|
||||
bool apply_delegate) {
|
||||
BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16,
|
||||
apply_delegate);
|
||||
}
|
||||
|
||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
int num_threads) {
|
||||
BuildInterpreter(input_shapes, num_threads,
|
||||
/*allow_fp32_relax_to_fp16=*/false);
|
||||
}
|
||||
|
||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
bool allow_fp32_relax_to_fp16) {
|
||||
BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16);
|
||||
/*allow_fp32_relax_to_fp16=*/false,
|
||||
/*apply_delegate=*/true);
|
||||
}
|
||||
|
||||
// static
|
||||
|
@ -151,6 +151,8 @@ class SingleOpModel {
|
||||
apply_delegate_fn_ = apply_delegate_fn;
|
||||
}
|
||||
|
||||
void ApplyDelegate();
|
||||
|
||||
// Copying or assignment is disallowed to simplify ownership semantics.
|
||||
SingleOpModel(const SingleOpModel&) = delete;
|
||||
SingleOpModel& operator=(const SingleOpModel&) = delete;
|
||||
@ -255,13 +257,14 @@ class SingleOpModel {
|
||||
// Build the interpreter for this model. Also, resize and allocate all
|
||||
// tensors given the shapes of the inputs.
|
||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
int num_threads, bool allow_fp32_relax_to_fp16);
|
||||
int num_threads, bool allow_fp32_relax_to_fp16,
|
||||
bool apply_delegate = true);
|
||||
|
||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
int num_threads);
|
||||
|
||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||
bool allow_fp32_relax_to_fp16);
|
||||
bool allow_fp32_relax_to_fp16, bool apply_delegate);
|
||||
|
||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user