Add tests for the slice op with const inputs

PiperOrigin-RevId: 256358275
This commit is contained in:
A. Unique TensorFlower 2019-07-03 06:49:59 -07:00 committed by TensorFlower Gardener
parent 2b96dc608e
commit cc5b183064

View File

@ -23,20 +23,39 @@ namespace {
using ::testing::ElementsAreArray;
enum class TestType {
CONST = 0,
DYNAMIC = 1,
};
template <typename input_type, typename index_type>
class SliceOpModel : public SingleOpModel {
public:
SliceOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> begin_shape,
std::initializer_list<index_type> begin_data,
std::initializer_list<int> size_shape,
TensorType tensor_index_type, TensorType tensor_input_type) {
std::initializer_list<index_type> size_data,
TensorType tensor_index_type, TensorType tensor_input_type,
TestType input_tensor_types) {
input_ = AddInput(tensor_input_type);
begin_ = AddInput(tensor_index_type);
size_ = AddInput(tensor_index_type);
if (input_tensor_types == TestType::DYNAMIC) {
begin_ = AddInput(tensor_index_type);
size_ = AddInput(tensor_index_type);
} else {
begin_ =
AddConstInput(GetTensorType<index_type>(), begin_data, begin_shape);
size_ = AddConstInput(GetTensorType<index_type>(), size_data, size_shape);
}
output_ = AddOutput(tensor_input_type);
SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions,
CreateSliceOptions(builder_).Union());
BuildInterpreter({input_shape, begin_shape, size_shape});
if (input_tensor_types == TestType::DYNAMIC) {
PopulateTensor<index_type>(begin_, begin_data);
PopulateTensor<index_type>(size_, size_data);
}
}
void SetInput(std::initializer_list<input_type> data) {
@ -45,12 +64,6 @@ class SliceOpModel : public SingleOpModel {
void SetStringInput(std::vector<string> data) {
PopulateStringTensor(input_, data);
}
void SetBegin(std::initializer_list<index_type> data) {
PopulateTensor<index_type>(begin_, data);
}
void SetSize(std::initializer_list<index_type> data) {
PopulateTensor<index_type>(size_, data);
}
std::vector<input_type> GetOutput() {
return ExtractVector<input_type>(output_);
@ -64,57 +77,53 @@ class SliceOpModel : public SingleOpModel {
int output_;
};
TEST(SliceOpTest, In1D) {
SliceOpModel<float, int32_t> m({4}, {1}, {1}, TensorType_INT32,
TensorType_FLOAT32);
class SliceOpTest : public ::testing::TestWithParam<TestType> {};
TEST_P(SliceOpTest, In1D) {
SliceOpModel<float, int32_t> m({4}, {1}, {1}, {1}, {2}, TensorType_INT32,
TensorType_FLOAT32, GetParam());
m.SetInput({1, 2, 3, 4});
m.SetBegin({1});
m.SetSize({2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
}
TEST(SliceOpTest, In2D) {
SliceOpModel<float, int32_t> m({2, 3}, {2}, {2}, TensorType_INT32,
TensorType_FLOAT32);
TEST_P(SliceOpTest, In2D) {
SliceOpModel<float, int32_t> m({2, 3}, {2}, {1, 0}, {2}, {1, 2},
TensorType_INT32, TensorType_FLOAT32,
GetParam());
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({1, 0});
m.SetSize({1, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
}
TEST(SliceOpTest, In3D) {
SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {4}, TensorType_INT32,
TensorType_FLOAT32);
TEST_P(SliceOpTest, In3D) {
SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 2},
TensorType_INT32, TensorType_FLOAT32,
GetParam());
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetSize({2, 3, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
}
TEST(SliceOpTest, InputFloat) {
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT32,
TensorType_FLOAT32);
TEST_P(SliceOpTest, InputFloat) {
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
{3, 1, 1, 1}, TensorType_INT32,
TensorType_FLOAT32, GetParam());
m.SetInput({1, 2, 3, 4});
m.SetBegin({1, 0, 0, 0});
m.SetSize({3, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
}
TEST(SliceOpTest, IndexInt64) {
SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT64,
TensorType_FLOAT32);
TEST_P(SliceOpTest, IndexInt64) {
SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
{3, 1, 1, 1}, TensorType_INT64,
TensorType_FLOAT32, GetParam());
m.SetInput({1, 2, 3, 4});
m.SetBegin({1, 0, 0, 0});
m.SetSize({3, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
@ -122,116 +131,106 @@ TEST(SliceOpTest, IndexInt64) {
// See these test cases under:
// https://www.tensorflow.org/versions/master/api_docs/python/tf/slice
TEST(SliceOpTest, InputInteger1) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, InputInteger1) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{1, 1, 3, 1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 0});
m.SetSize({1, 1, 3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3}));
}
TEST(SliceOpTest, InputInteger2) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, InputInteger2) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{1, 2, 3, 1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 0});
m.SetSize({1, 2, 3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 4, 4, 4}));
}
TEST(SliceOpTest, InputInteger3) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, InputInteger3) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, 3, 1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, 3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST(SliceOpTest, SizeMinus1) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, SizeMinus1) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis1) {
SliceOpModel<int32_t, int32_t> m({3, 3, 2, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis1) {
SliceOpModel<int32_t, int32_t> m({3, 3, 2, 1}, {4}, {1, 1, 0, 0}, {4},
{2, -1, 1, 1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
m.SetBegin({1, 1, 0, 0});
m.SetSize({2, -1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 6, 8, 9}));
}
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis2) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis2) {
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 1, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 1, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
}
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis3) {
SliceOpModel<int32_t, int32_t> m({3, 1, 2, 3}, {4}, {4}, TensorType_INT32,
TensorType_INT32);
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis3) {
SliceOpModel<int32_t, int32_t> m({3, 1, 2, 3}, {4}, {1, 0, 0, 1}, {4},
{2, 1, 1, -1}, TensorType_INT32,
TensorType_INT32, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 1});
m.SetSize({2, 1, 1, -1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
}
TEST(SliceOpTest, SliceUint8) {
SliceOpModel<uint8_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_UINT8);
TEST_P(SliceOpTest, SliceUint8) {
SliceOpModel<uint8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,
TensorType_UINT8, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST(SliceOpTest, SliceInt8) {
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_INT8);
TEST_P(SliceOpTest, SliceInt8) {
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,
TensorType_INT8, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST(SliceOpTest, SliceString) {
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
TensorType_STRING);
TEST_P(SliceOpTest, SliceString) {
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,
TensorType_STRING, GetParam());
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
m.SetBegin({1, 0, 0, 0});
m.SetSize({2, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(),
@ -239,5 +238,8 @@ TEST(SliceOpTest, SliceString) {
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
}
INSTANTIATE_TEST_SUITE_P(SliceOpTest, SliceOpTest,
::testing::Values(TestType::CONST, TestType::DYNAMIC));
} // namespace
} // namespace tflite