diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index 8c8783fa89f..69284578625 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -51,7 +51,7 @@ class SingleOpModelWithNNAPI : public SingleOpModel { } protected: - void SetData(int index, TensorType type, std::initializer_list data) { + void SetData(int index, TensorType type, const std::vector& data) { switch (type) { case TensorType_FLOAT32: PopulateTensor(index, data); @@ -1887,8 +1887,8 @@ static std::initializer_list rnn_bias = { class RNNOpModel : public SingleOpModelWithNNAPI { public: RNNOpModel(int batches, int units, int size, - const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType weights = TensorType_FLOAT32, + const TensorType recurrent_weights = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); weights_ = AddInput(weights); @@ -2246,11 +2246,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType& weight_type = TensorType_FLOAT32) + const TensorType weight_type) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), - n_output_(n_output) { + n_output_(n_output), + weight_type_(weight_type) { input_ = AddInput(TensorType_FLOAT32); if (use_cifg) { @@ -2324,47 +2325,47 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { } void SetInputToInputWeights(const std::vector& f) { - PopulateTensor(input_to_input_weights_, f); + SetData(input_to_input_weights_, weight_type_, f); } void SetInputToForgetWeights(const std::vector& f) { - PopulateTensor(input_to_forget_weights_, f); + SetData(input_to_forget_weights_, weight_type_, f); } void SetInputToCellWeights(const std::vector& f) { - PopulateTensor(input_to_cell_weights_, f); + SetData(input_to_cell_weights_, weight_type_, f); } void SetInputToOutputWeights(const std::vector& f) { - PopulateTensor(input_to_output_weights_, f); + SetData(input_to_output_weights_, weight_type_, f); } void SetRecurrentToInputWeights(const std::vector& f) { - PopulateTensor(recurrent_to_input_weights_, f); + SetData(recurrent_to_input_weights_, weight_type_, f); } void SetRecurrentToForgetWeights(const std::vector& f) { - PopulateTensor(recurrent_to_forget_weights_, f); + SetData(recurrent_to_forget_weights_, weight_type_, f); } void SetRecurrentToCellWeights(const std::vector& f) { - PopulateTensor(recurrent_to_cell_weights_, f); + SetData(recurrent_to_cell_weights_, weight_type_, f); } void SetRecurrentToOutputWeights(const std::vector& f) { - PopulateTensor(recurrent_to_output_weights_, f); + SetData(recurrent_to_output_weights_, weight_type_, f); } void SetCellToInputWeights(const std::vector& f) { - PopulateTensor(cell_to_input_weights_, f); + SetData(cell_to_input_weights_, weight_type_, f); } void SetCellToForgetWeights(const std::vector& f) { - PopulateTensor(cell_to_forget_weights_, f); + SetData(cell_to_forget_weights_, weight_type_, f); } void SetCellToOutputWeights(const std::vector& f) { - PopulateTensor(cell_to_output_weights_, f); + SetData(cell_to_output_weights_, weight_type_, f); } void SetInputGateBias(const std::vector& f) { @@ -2384,7 +2385,7 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { } void SetProjectionWeights(const std::vector& f) { - PopulateTensor(projection_weights_, f); + SetData(projection_weights_, weight_type_, f); } void SetProjectionBias(const std::vector& f) { @@ -2437,6 +2438,9 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { int n_input_; int n_cell_; int n_output_; + + private: + const TensorType weight_type_; }; class BaseLstmTest : public ::testing::Test { @@ -2582,7 +2586,8 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { {0, 0}, // projection_weight tensor {0}, // projection_bias tensor - }); + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -2685,7 +2690,8 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { {0, 0}, // projection_weight tensor {0}, // projection_bias tensor - }); + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -3339,7 +3345,8 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { {n_output, n_cell}, // projection_weight tensor {0}, // projection_bias tensor - }); + }, + /*weight_type=*/TensorType_FLOAT32); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_);