diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc
index 12ed8e088f5..845b7ffe266 100644
--- a/tensorflow/lite/kernels/conv_test.cc
+++ b/tensorflow/lite/kernels/conv_test.cc
@@ -45,7 +45,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
       const TensorData& filter, const TensorData& output, int stride_width = 2,
       int stride_height = 2, enum Padding padding = Padding_VALID,
       enum ActivationFunctionType activation = ActivationFunctionType_NONE,
-      int dilation_width_factor = 1, int dilation_height_factor = 1) {
+      int dilation_width_factor = 1, int dilation_height_factor = 1,
+      int num_threads = -1) {
     input_ = AddInput(input);
     filter_ = AddInput(filter);
 
@@ -97,7 +98,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
 
     resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
                                                     registration);
-    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)},
+                     num_threads);
   }
 
  protected:
@@ -168,6 +170,37 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
                              }));
 }
 
+TEST_P(ConvolutionOpTest, SimpleTestFloat32SingleThreaded) {
+  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+                       {TensorType_FLOAT32, {3, 2, 2, 1}},
+                       {TensorType_FLOAT32, {}}, 2, 2, Padding_VALID,
+                       ActivationFunctionType_NONE, 1, 1, /*num_threads=*/1);
+
+  m.SetInput({
+      // First batch
+      1, 1, 1, 1,  // row = 1
+      2, 2, 2, 2,  // row = 2
+      // Second batch
+      1, 2, 3, 4,  // row = 1
+      1, 2, 3, 4,  // row = 2
+  });
+  m.SetFilter({
+      1, 2, 3, 4,    // first 2x2 filter
+      -1, 1, -1, 1,  // second 2x2 filter
+      -1, -1, 1, 1,  // third 2x2 filter
+  });
+  m.SetBias({1, 2, 3});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 18, 2, 5,  // first batch, left
+                                 18, 2, 5,  // first batch, right
+                                 17, 4, 3,  // second batch, left
+                                 37, 4, 3,  // second batch, right
+                             }));
+}
+
 // This test's output is equivalent to the SimpleTestFloat32
 // because we break each input into two channels, each with half of the value,
 // while keeping the filters for each channel equivalent.
diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc
index 379dc9b1970..b4d6eade65e 100644
--- a/tensorflow/lite/kernels/test_util.cc
+++ b/tensorflow/lite/kernels/test_util.cc
@@ -115,6 +115,7 @@ void SingleOpModel::SetCustomOp(
 }
 
 void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     int num_threads,
                                      bool allow_fp32_relax_to_fp16) {
   auto opcodes = builder_.CreateVector(opcodes_);
   auto operators = builder_.CreateVector(operators_);
@@ -141,7 +142,8 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
     }
     resolver_ = std::unique_ptr<OpResolver>(resolver);
   }
-  CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_) == kTfLiteOk);
+  CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_, num_threads) ==
+        kTfLiteOk);
 
   CHECK(interpreter_ != nullptr);
 
@@ -174,6 +176,23 @@ void SingleOpModel::Invoke() { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); }
 
 TfLiteStatus SingleOpModel::InvokeUnchecked() { return interpreter_->Invoke(); }
 
+void SingleOpModel::BuildInterpreter(
+    std::vector<std::vector<int>> input_shapes) {
+  BuildInterpreter(input_shapes, /*num_threads=*/-1,
+                   /*allow_fp32_relax_to_fp16=*/false);
+}
+
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     int num_threads) {
+  BuildInterpreter(input_shapes, num_threads,
+                   /*allow_fp32_relax_to_fp16=*/false);
+}
+
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     bool allow_fp32_relax_to_fp16) {
+  BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16);
+}
+
 // static
 void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
   force_use_nnapi = use_nnapi;
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 9bf21a819d5..5334e39082e 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -250,7 +250,15 @@ class SingleOpModel {
   // Build the interpreter for this model. Also, resize and allocate all
   // tensors given the shapes of the inputs.
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
-                        bool allow_fp32_relax_to_fp16 = false);
+                        int num_threads, bool allow_fp32_relax_to_fp16);
+
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                        int num_threads);
+
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                        bool allow_fp32_relax_to_fp16);
+
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
 
   // Executes inference, asserting success.
   void Invoke();
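For reference, a minimal sketch (not part of this change) of how the same BuildInterpreter overload could be exercised with more than one thread. The test name SimpleTestFloat32MultiThreaded and the thread count of 4 are illustrative assumptions; the model, inputs, and expected outputs simply mirror the single-threaded test added in conv_test.cc above, since the results should not depend on the thread count.

TEST_P(ConvolutionOpTest, SimpleTestFloat32MultiThreaded) {
  // Same 2x2x4x1 input and three 2x2 filters as the single-threaded test,
  // but requesting four interpreter threads via the new parameter.
  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
                       {TensorType_FLOAT32, {3, 2, 2, 1}},
                       {TensorType_FLOAT32, {}}, 2, 2, Padding_VALID,
                       ActivationFunctionType_NONE, 1, 1, /*num_threads=*/4);
  m.SetInput({1, 1, 1, 1, 2, 2, 2, 2,    // first batch
              1, 2, 3, 4, 1, 2, 3, 4});  // second batch
  m.SetFilter({1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1});
  m.SetBias({1, 2, 3});
  m.Invoke();
  // Expected values are identical to the single-threaded results.
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({18, 2, 5, 18, 2, 5, 17, 4, 3, 37, 4, 3}));
}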