Add tests for the slice op with const inputs
PiperOrigin-RevId: 256358275
This commit is contained in:
parent
2b96dc608e
commit
cc5b183064
@ -23,20 +23,39 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
};
|
||||
|
||||
template <typename input_type, typename index_type>
|
||||
class SliceOpModel : public SingleOpModel {
|
||||
public:
|
||||
SliceOpModel(std::initializer_list<int> input_shape,
|
||||
std::initializer_list<int> begin_shape,
|
||||
std::initializer_list<index_type> begin_data,
|
||||
std::initializer_list<int> size_shape,
|
||||
TensorType tensor_index_type, TensorType tensor_input_type) {
|
||||
std::initializer_list<index_type> size_data,
|
||||
TensorType tensor_index_type, TensorType tensor_input_type,
|
||||
TestType input_tensor_types) {
|
||||
input_ = AddInput(tensor_input_type);
|
||||
begin_ = AddInput(tensor_index_type);
|
||||
size_ = AddInput(tensor_index_type);
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
begin_ = AddInput(tensor_index_type);
|
||||
size_ = AddInput(tensor_index_type);
|
||||
} else {
|
||||
begin_ =
|
||||
AddConstInput(GetTensorType<index_type>(), begin_data, begin_shape);
|
||||
size_ = AddConstInput(GetTensorType<index_type>(), size_data, size_shape);
|
||||
}
|
||||
output_ = AddOutput(tensor_input_type);
|
||||
SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions,
|
||||
CreateSliceOptions(builder_).Union());
|
||||
BuildInterpreter({input_shape, begin_shape, size_shape});
|
||||
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
PopulateTensor<index_type>(begin_, begin_data);
|
||||
PopulateTensor<index_type>(size_, size_data);
|
||||
}
|
||||
}
|
||||
|
||||
void SetInput(std::initializer_list<input_type> data) {
|
||||
@ -45,12 +64,6 @@ class SliceOpModel : public SingleOpModel {
|
||||
void SetStringInput(std::vector<string> data) {
|
||||
PopulateStringTensor(input_, data);
|
||||
}
|
||||
void SetBegin(std::initializer_list<index_type> data) {
|
||||
PopulateTensor<index_type>(begin_, data);
|
||||
}
|
||||
void SetSize(std::initializer_list<index_type> data) {
|
||||
PopulateTensor<index_type>(size_, data);
|
||||
}
|
||||
|
||||
std::vector<input_type> GetOutput() {
|
||||
return ExtractVector<input_type>(output_);
|
||||
@ -64,57 +77,53 @@ class SliceOpModel : public SingleOpModel {
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(SliceOpTest, In1D) {
|
||||
SliceOpModel<float, int32_t> m({4}, {1}, {1}, TensorType_INT32,
|
||||
TensorType_FLOAT32);
|
||||
class SliceOpTest : public ::testing::TestWithParam<TestType> {};
|
||||
|
||||
TEST_P(SliceOpTest, In1D) {
|
||||
SliceOpModel<float, int32_t> m({4}, {1}, {1}, {1}, {2}, TensorType_INT32,
|
||||
TensorType_FLOAT32, GetParam());
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1});
|
||||
m.SetSize({2});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, In2D) {
|
||||
SliceOpModel<float, int32_t> m({2, 3}, {2}, {2}, TensorType_INT32,
|
||||
TensorType_FLOAT32);
|
||||
TEST_P(SliceOpTest, In2D) {
|
||||
SliceOpModel<float, int32_t> m({2, 3}, {2}, {1, 0}, {2}, {1, 2},
|
||||
TensorType_INT32, TensorType_FLOAT32,
|
||||
GetParam());
|
||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||
m.SetBegin({1, 0});
|
||||
m.SetSize({1, 2});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, In3D) {
|
||||
SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {4}, TensorType_INT32,
|
||||
TensorType_FLOAT32);
|
||||
TEST_P(SliceOpTest, In3D) {
|
||||
SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 2},
|
||||
TensorType_INT32, TensorType_FLOAT32,
|
||||
GetParam());
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetSize({2, 3, 2});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, InputFloat) {
|
||||
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_FLOAT32);
|
||||
TEST_P(SliceOpTest, InputFloat) {
|
||||
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{3, 1, 1, 1}, TensorType_INT32,
|
||||
TensorType_FLOAT32, GetParam());
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({3, 1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, IndexInt64) {
|
||||
SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT64,
|
||||
TensorType_FLOAT32);
|
||||
TEST_P(SliceOpTest, IndexInt64) {
|
||||
SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{3, 1, 1, 1}, TensorType_INT64,
|
||||
TensorType_FLOAT32, GetParam());
|
||||
m.SetInput({1, 2, 3, 4});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({3, 1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
||||
@ -122,116 +131,106 @@ TEST(SliceOpTest, IndexInt64) {
|
||||
|
||||
// See these test cases under:
|
||||
// https://www.tensorflow.org/versions/master/api_docs/python/tf/slice
|
||||
TEST(SliceOpTest, InputInteger1) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, InputInteger1) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{1, 1, 3, 1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({1, 1, 3, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, InputInteger2) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, InputInteger2) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{1, 2, 3, 1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({1, 2, 3, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 4, 4, 4}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, InputInteger3) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, InputInteger3) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{2, 1, 3, 1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({2, 1, 3, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, SizeMinus1) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, SizeMinus1) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{2, 1, -1, 1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({2, 1, -1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis1) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 3, 2, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis1) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 3, 2, 1}, {4}, {1, 1, 0, 0}, {4},
|
||||
{2, -1, 1, 1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
|
||||
m.SetBegin({1, 1, 0, 0});
|
||||
m.SetSize({2, -1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 6, 8, 9}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis2) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis2) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 1, 0}, {4},
|
||||
{2, 1, -1, 1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 1, 0});
|
||||
m.SetSize({2, 1, -1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis3) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 1, 2, 3}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis3) {
|
||||
SliceOpModel<int32_t, int32_t> m({3, 1, 2, 3}, {4}, {1, 0, 0, 1}, {4},
|
||||
{2, 1, 1, -1}, TensorType_INT32,
|
||||
TensorType_INT32, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 1});
|
||||
m.SetSize({2, 1, 1, -1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, SliceUint8) {
|
||||
SliceOpModel<uint8_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_UINT8);
|
||||
TEST_P(SliceOpTest, SliceUint8) {
|
||||
SliceOpModel<uint8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{2, 1, -1, 1}, TensorType_INT32,
|
||||
TensorType_UINT8, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({2, 1, -1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, SliceInt8) {
|
||||
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_INT8);
|
||||
TEST_P(SliceOpTest, SliceInt8) {
|
||||
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{2, 1, -1, 1}, TensorType_INT32,
|
||||
TensorType_INT8, GetParam());
|
||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({2, 1, -1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||
}
|
||||
|
||||
TEST(SliceOpTest, SliceString) {
|
||||
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
||||
TensorType_STRING);
|
||||
TEST_P(SliceOpTest, SliceString) {
|
||||
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||
{2, 1, -1, 1}, TensorType_INT32,
|
||||
TensorType_STRING, GetParam());
|
||||
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
|
||||
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
|
||||
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
|
||||
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
|
||||
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
|
||||
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
|
||||
m.SetBegin({1, 0, 0, 0});
|
||||
m.SetSize({2, 1, -1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
@ -239,5 +238,8 @@ TEST(SliceOpTest, SliceString) {
|
||||
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(SliceOpTest, SliceOpTest,
|
||||
::testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
Loading…
x
Reference in New Issue
Block a user