Merge pull request #26361 from ANSHUMAN87:conv-test-improvement

PiperOrigin-RevId: 256386122
commit f565864b27
TensorFlower Gardener 2019-07-03 13:31:42 -07:00
3 changed files with 64 additions and 4 deletions

tensorflow/lite/kernels/conv_test.cc

@@ -45,7 +45,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
       const TensorData& filter, const TensorData& output, int stride_width = 2,
       int stride_height = 2, enum Padding padding = Padding_VALID,
       enum ActivationFunctionType activation = ActivationFunctionType_NONE,
-      int dilation_width_factor = 1, int dilation_height_factor = 1) {
+      int dilation_width_factor = 1, int dilation_height_factor = 1,
+      int num_threads = -1) {
     input_ = AddInput(input);
     filter_ = AddInput(filter);
@@ -97,7 +98,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
     resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
                                                     registration);
-    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+    BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)},
+                     num_threads);
   }

 protected:
@@ -168,6 +170,37 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
                              }));
 }

+TEST_P(ConvolutionOpTest, SimpleTestFloat32SingleThreaded) {
+  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+                       {TensorType_FLOAT32, {3, 2, 2, 1}},
+                       {TensorType_FLOAT32, {}}, 2, 2, Padding_VALID,
+                       ActivationFunctionType_NONE, 1, 1, /*num_threads=*/1);
+  m.SetInput({
+      // First batch
+      1, 1, 1, 1,  // row = 1
+      2, 2, 2, 2,  // row = 2
+      // Second batch
+      1, 2, 3, 4,  // row = 1
+      1, 2, 3, 4,  // row = 2
+  });
+  m.SetFilter({
+      1, 2, 3, 4,    // first 2x2 filter
+      -1, 1, -1, 1,  // second 2x2 filter
+      -1, -1, 1, 1,  // third 2x2 filter
+  });
+  m.SetBias({1, 2, 3});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 18, 2, 5,  // first batch, left
+                                 18, 2, 5,  // first batch, right
+                                 17, 4, 3,  // second batch, left
+                                 37, 4, 3,  // second batch, right
+                             }));
+}
+
 // This test's output is equivalent to the SimpleTestFloat32
 // because we break each input into two channels, each with half of the value,
 // while keeping the filters for each channel equivalent.
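For comparison, the same graph could also be exercised with an explicit multi-threaded setting. The following companion test is a hypothetical sketch, not part of this commit; it reuses the helpers from the test file above, and any positive thread count (or -1 for the runtime default) would be accepted:

TEST_P(ConvolutionOpTest, SimpleTestFloat32MultiThreaded) {
  // Hypothetical: same model as SimpleTestFloat32SingleThreaded, but
  // requesting four interpreter threads. The thread count must not change
  // the results.
  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
                       {TensorType_FLOAT32, {3, 2, 2, 1}},
                       {TensorType_FLOAT32, {}}, 2, 2, Padding_VALID,
                       ActivationFunctionType_NONE, 1, 1, /*num_threads=*/4);
  m.SetInput({1, 1, 1, 1, 2, 2, 2, 2,    // first batch
              1, 2, 3, 4, 1, 2, 3, 4});  // second batch
  m.SetFilter({1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1});
  m.SetBias({1, 2, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({18, 2, 5, 18, 2, 5, 17, 4, 3, 37, 4, 3}));
}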

tensorflow/lite/kernels/test_util.cc

@@ -115,6 +115,7 @@ void SingleOpModel::SetCustomOp(
 }

 void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     int num_threads,
                                      bool allow_fp32_relax_to_fp16) {
   auto opcodes = builder_.CreateVector(opcodes_);
   auto operators = builder_.CreateVector(operators_);
@@ -141,7 +142,8 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
     }
     resolver_ = std::unique_ptr<OpResolver>(resolver);
   }
-  CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_) == kTfLiteOk);
+  CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_, num_threads) ==
+        kTfLiteOk);
   CHECK(interpreter_ != nullptr);
@@ -174,6 +176,23 @@ void SingleOpModel::Invoke() { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); }

 TfLiteStatus SingleOpModel::InvokeUnchecked() { return interpreter_->Invoke(); }

+void SingleOpModel::BuildInterpreter(
+    std::vector<std::vector<int>> input_shapes) {
+  BuildInterpreter(input_shapes, /*num_threads=*/-1,
+                   /*allow_fp32_relax_to_fp16=*/false);
+}
+
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     int num_threads) {
+  BuildInterpreter(input_shapes, num_threads,
+                   /*allow_fp32_relax_to_fp16=*/false);
+}
+
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     bool allow_fp32_relax_to_fp16) {
+  BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16);
+}
+
 // static
 void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
   force_use_nnapi = use_nnapi;
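All three convenience overloads above bottom out in InterpreterBuilder's two-argument call operator. As a minimal sketch of that call outside the test harness, assuming the caller has already loaded a FlatBufferModel (the helper name BuildWithThreads is illustrative, not part of TFLite):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Illustrative helper: build an interpreter for `model` with an explicit
// thread count. Passing -1 (the default used by the test harness above)
// lets the runtime pick its own thread count.
std::unique_ptr<tflite::Interpreter> BuildWithThreads(
    const tflite::FlatBufferModel& model, int num_threads) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(model, resolver)(&interpreter, num_threads) !=
      kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}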

tensorflow/lite/kernels/test_util.h

@@ -250,7 +250,15 @@ class SingleOpModel {
   // Build the interpreter for this model. Also, resize and allocate all
   // tensors given the shapes of the inputs.
   void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
-                        bool allow_fp32_relax_to_fp16 = false);
+                        int num_threads, bool allow_fp32_relax_to_fp16);
+
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                        int num_threads);
+
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                        bool allow_fp32_relax_to_fp16);
+
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes);

   // Executes inference, asserting success.
   void Invoke();
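Replacing the single default-argument signature with explicit overloads keeps existing call sites compiling unchanged; presumably a defaulted int num_threads next to a defaulted bool would also invite accidental bool-to-int mix-ups at call sites. A usage sketch for the overload set, assuming a SingleOpModel-derived test model m with one input (the {1, 2, 2, 1} shape is illustrative):

  // Defaults: num_threads = -1 (runtime-chosen), no fp16 relaxation.
  m.BuildInterpreter({{1, 2, 2, 1}});
  // Explicit thread count, no fp16 relaxation.
  m.BuildInterpreter({{1, 2, 2, 1}}, /*num_threads=*/2);
  // Default thread count, opting in to fp32->fp16 relaxation.
  m.BuildInterpreter({{1, 2, 2, 1}}, /*allow_fp32_relax_to_fp16=*/true);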