Merge pull request #26361 from ANSHUMAN87:conv-test-improvement
PiperOrigin-RevId: 256386122
This commit is contained in:
commit
f565864b27
@ -45,7 +45,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
|
|||||||
const TensorData& filter, const TensorData& output, int stride_width = 2,
|
const TensorData& filter, const TensorData& output, int stride_width = 2,
|
||||||
int stride_height = 2, enum Padding padding = Padding_VALID,
|
int stride_height = 2, enum Padding padding = Padding_VALID,
|
||||||
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
|
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
|
||||||
int dilation_width_factor = 1, int dilation_height_factor = 1) {
|
int dilation_width_factor = 1, int dilation_height_factor = 1,
|
||||||
|
int num_threads = -1) {
|
||||||
input_ = AddInput(input);
|
input_ = AddInput(input);
|
||||||
filter_ = AddInput(filter);
|
filter_ = AddInput(filter);
|
||||||
|
|
||||||
@ -97,7 +98,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
|
resolver_ = absl::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
|
||||||
registration);
|
registration);
|
||||||
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
|
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)},
|
||||||
|
num_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -168,6 +170,37 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(ConvolutionOpTest, SimpleTestFloat32SingleThreaded) {
|
||||||
|
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
|
||||||
|
{TensorType_FLOAT32, {3, 2, 2, 1}},
|
||||||
|
{TensorType_FLOAT32, {}}, 2, 2, Padding_VALID,
|
||||||
|
ActivationFunctionType_NONE, 1, 1, /*num_threads=*/1);
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
// First batch
|
||||||
|
1, 1, 1, 1, // row = 1
|
||||||
|
2, 2, 2, 2, // row = 2
|
||||||
|
// Second batch
|
||||||
|
1, 2, 3, 4, // row = 1
|
||||||
|
1, 2, 3, 4, // row = 2
|
||||||
|
});
|
||||||
|
m.SetFilter({
|
||||||
|
1, 2, 3, 4, // first 2x2 filter
|
||||||
|
-1, 1, -1, 1, // second 2x2 filter
|
||||||
|
-1, -1, 1, 1, // third 2x2 filter
|
||||||
|
});
|
||||||
|
m.SetBias({1, 2, 3});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({
|
||||||
|
18, 2, 5, // first batch, left
|
||||||
|
18, 2, 5, // first batch, right
|
||||||
|
17, 4, 3, // second batch, left
|
||||||
|
37, 4, 3, // second batch, right
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
// This test's output is equivalent to the SimpleTestFloat32
|
// This test's output is equivalent to the SimpleTestFloat32
|
||||||
// because we break each input into two channels, each with half of the value,
|
// because we break each input into two channels, each with half of the value,
|
||||||
// while keeping the filters for each channel equivalent.
|
// while keeping the filters for each channel equivalent.
|
||||||
|
@ -115,6 +115,7 @@ void SingleOpModel::SetCustomOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
|
int num_threads,
|
||||||
bool allow_fp32_relax_to_fp16) {
|
bool allow_fp32_relax_to_fp16) {
|
||||||
auto opcodes = builder_.CreateVector(opcodes_);
|
auto opcodes = builder_.CreateVector(opcodes_);
|
||||||
auto operators = builder_.CreateVector(operators_);
|
auto operators = builder_.CreateVector(operators_);
|
||||||
@ -141,7 +142,8 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
|||||||
}
|
}
|
||||||
resolver_ = std::unique_ptr<OpResolver>(resolver);
|
resolver_ = std::unique_ptr<OpResolver>(resolver);
|
||||||
}
|
}
|
||||||
CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_) == kTfLiteOk);
|
CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_, num_threads) ==
|
||||||
|
kTfLiteOk);
|
||||||
|
|
||||||
CHECK(interpreter_ != nullptr);
|
CHECK(interpreter_ != nullptr);
|
||||||
|
|
||||||
@ -174,6 +176,23 @@ void SingleOpModel::Invoke() { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); }
|
|||||||
|
|
||||||
TfLiteStatus SingleOpModel::InvokeUnchecked() { return interpreter_->Invoke(); }
|
TfLiteStatus SingleOpModel::InvokeUnchecked() { return interpreter_->Invoke(); }
|
||||||
|
|
||||||
|
void SingleOpModel::BuildInterpreter(
|
||||||
|
std::vector<std::vector<int>> input_shapes) {
|
||||||
|
BuildInterpreter(input_shapes, /*num_threads=*/-1,
|
||||||
|
/*allow_fp32_relax_to_fp16=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
|
int num_threads) {
|
||||||
|
BuildInterpreter(input_shapes, num_threads,
|
||||||
|
/*allow_fp32_relax_to_fp16=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
|
bool allow_fp32_relax_to_fp16) {
|
||||||
|
BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16);
|
||||||
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
|
void SingleOpModel::SetForceUseNnapi(bool use_nnapi) {
|
||||||
force_use_nnapi = use_nnapi;
|
force_use_nnapi = use_nnapi;
|
||||||
|
@ -250,7 +250,15 @@ class SingleOpModel {
|
|||||||
// Build the interpreter for this model. Also, resize and allocate all
|
// Build the interpreter for this model. Also, resize and allocate all
|
||||||
// tensors given the shapes of the inputs.
|
// tensors given the shapes of the inputs.
|
||||||
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
bool allow_fp32_relax_to_fp16 = false);
|
int num_threads, bool allow_fp32_relax_to_fp16);
|
||||||
|
|
||||||
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
|
int num_threads);
|
||||||
|
|
||||||
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
|
||||||
|
bool allow_fp32_relax_to_fp16);
|
||||||
|
|
||||||
|
void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
|
||||||
|
|
||||||
// Executes inference, asserting success.
|
// Executes inference, asserting success.
|
||||||
void Invoke();
|
void Invoke();
|
||||||
|
Loading…
Reference in New Issue
Block a user