Add tests for the slice op with const inputs
PiperOrigin-RevId: 256358275
This commit is contained in:
parent
2b96dc608e
commit
cc5b183064
@ -23,20 +23,39 @@ namespace {
|
|||||||
|
|
||||||
using ::testing::ElementsAreArray;
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
enum class TestType {
|
||||||
|
CONST = 0,
|
||||||
|
DYNAMIC = 1,
|
||||||
|
};
|
||||||
|
|
||||||
template <typename input_type, typename index_type>
|
template <typename input_type, typename index_type>
|
||||||
class SliceOpModel : public SingleOpModel {
|
class SliceOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
SliceOpModel(std::initializer_list<int> input_shape,
|
SliceOpModel(std::initializer_list<int> input_shape,
|
||||||
std::initializer_list<int> begin_shape,
|
std::initializer_list<int> begin_shape,
|
||||||
|
std::initializer_list<index_type> begin_data,
|
||||||
std::initializer_list<int> size_shape,
|
std::initializer_list<int> size_shape,
|
||||||
TensorType tensor_index_type, TensorType tensor_input_type) {
|
std::initializer_list<index_type> size_data,
|
||||||
|
TensorType tensor_index_type, TensorType tensor_input_type,
|
||||||
|
TestType input_tensor_types) {
|
||||||
input_ = AddInput(tensor_input_type);
|
input_ = AddInput(tensor_input_type);
|
||||||
begin_ = AddInput(tensor_index_type);
|
if (input_tensor_types == TestType::DYNAMIC) {
|
||||||
size_ = AddInput(tensor_index_type);
|
begin_ = AddInput(tensor_index_type);
|
||||||
|
size_ = AddInput(tensor_index_type);
|
||||||
|
} else {
|
||||||
|
begin_ =
|
||||||
|
AddConstInput(GetTensorType<index_type>(), begin_data, begin_shape);
|
||||||
|
size_ = AddConstInput(GetTensorType<index_type>(), size_data, size_shape);
|
||||||
|
}
|
||||||
output_ = AddOutput(tensor_input_type);
|
output_ = AddOutput(tensor_input_type);
|
||||||
SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions,
|
SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions,
|
||||||
CreateSliceOptions(builder_).Union());
|
CreateSliceOptions(builder_).Union());
|
||||||
BuildInterpreter({input_shape, begin_shape, size_shape});
|
BuildInterpreter({input_shape, begin_shape, size_shape});
|
||||||
|
|
||||||
|
if (input_tensor_types == TestType::DYNAMIC) {
|
||||||
|
PopulateTensor<index_type>(begin_, begin_data);
|
||||||
|
PopulateTensor<index_type>(size_, size_data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInput(std::initializer_list<input_type> data) {
|
void SetInput(std::initializer_list<input_type> data) {
|
||||||
@ -45,12 +64,6 @@ class SliceOpModel : public SingleOpModel {
|
|||||||
void SetStringInput(std::vector<string> data) {
|
void SetStringInput(std::vector<string> data) {
|
||||||
PopulateStringTensor(input_, data);
|
PopulateStringTensor(input_, data);
|
||||||
}
|
}
|
||||||
void SetBegin(std::initializer_list<index_type> data) {
|
|
||||||
PopulateTensor<index_type>(begin_, data);
|
|
||||||
}
|
|
||||||
void SetSize(std::initializer_list<index_type> data) {
|
|
||||||
PopulateTensor<index_type>(size_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<input_type> GetOutput() {
|
std::vector<input_type> GetOutput() {
|
||||||
return ExtractVector<input_type>(output_);
|
return ExtractVector<input_type>(output_);
|
||||||
@ -64,57 +77,53 @@ class SliceOpModel : public SingleOpModel {
|
|||||||
int output_;
|
int output_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(SliceOpTest, In1D) {
|
class SliceOpTest : public ::testing::TestWithParam<TestType> {};
|
||||||
SliceOpModel<float, int32_t> m({4}, {1}, {1}, TensorType_INT32,
|
|
||||||
TensorType_FLOAT32);
|
TEST_P(SliceOpTest, In1D) {
|
||||||
|
SliceOpModel<float, int32_t> m({4}, {1}, {1}, {1}, {2}, TensorType_INT32,
|
||||||
|
TensorType_FLOAT32, GetParam());
|
||||||
m.SetInput({1, 2, 3, 4});
|
m.SetInput({1, 2, 3, 4});
|
||||||
m.SetBegin({1});
|
|
||||||
m.SetSize({2});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, In2D) {
|
TEST_P(SliceOpTest, In2D) {
|
||||||
SliceOpModel<float, int32_t> m({2, 3}, {2}, {2}, TensorType_INT32,
|
SliceOpModel<float, int32_t> m({2, 3}, {2}, {1, 0}, {2}, {1, 2},
|
||||||
TensorType_FLOAT32);
|
TensorType_INT32, TensorType_FLOAT32,
|
||||||
|
GetParam());
|
||||||
m.SetInput({1, 2, 3, 4, 5, 6});
|
m.SetInput({1, 2, 3, 4, 5, 6});
|
||||||
m.SetBegin({1, 0});
|
|
||||||
m.SetSize({1, 2});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, In3D) {
|
TEST_P(SliceOpTest, In3D) {
|
||||||
SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {4}, TensorType_INT32,
|
SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 2},
|
||||||
TensorType_FLOAT32);
|
TensorType_INT32, TensorType_FLOAT32,
|
||||||
|
GetParam());
|
||||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
m.SetBegin({0, 0, 0});
|
|
||||||
m.SetSize({2, 3, 2});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
|
||||||
EXPECT_THAT(m.GetOutput(),
|
EXPECT_THAT(m.GetOutput(),
|
||||||
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
|
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, InputFloat) {
|
TEST_P(SliceOpTest, InputFloat) {
|
||||||
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_FLOAT32);
|
{3, 1, 1, 1}, TensorType_INT32,
|
||||||
|
TensorType_FLOAT32, GetParam());
|
||||||
m.SetInput({1, 2, 3, 4});
|
m.SetInput({1, 2, 3, 4});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({3, 1, 1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, IndexInt64) {
|
TEST_P(SliceOpTest, IndexInt64) {
|
||||||
SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT64,
|
SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_FLOAT32);
|
{3, 1, 1, 1}, TensorType_INT64,
|
||||||
|
TensorType_FLOAT32, GetParam());
|
||||||
m.SetInput({1, 2, 3, 4});
|
m.SetInput({1, 2, 3, 4});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({3, 1, 1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
|
||||||
@ -122,116 +131,106 @@ TEST(SliceOpTest, IndexInt64) {
|
|||||||
|
|
||||||
// See these test cases under:
|
// See these test cases under:
|
||||||
// https://www.tensorflow.org/versions/master/api_docs/python/tf/slice
|
// https://www.tensorflow.org/versions/master/api_docs/python/tf/slice
|
||||||
TEST(SliceOpTest, InputInteger1) {
|
TEST_P(SliceOpTest, InputInteger1) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_INT32);
|
{1, 1, 3, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({1, 1, 3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, InputInteger2) {
|
TEST_P(SliceOpTest, InputInteger2) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_INT32);
|
{1, 2, 3, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({1, 2, 3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 4, 4, 4}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 4, 4, 4}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, InputInteger3) {
|
TEST_P(SliceOpTest, InputInteger3) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_INT32);
|
{2, 1, 3, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({2, 1, 3, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, SizeMinus1) {
|
TEST_P(SliceOpTest, SizeMinus1) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_INT32);
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({2, 1, -1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis1) {
|
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis1) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 3, 2, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 3, 2, 1}, {4}, {1, 1, 0, 0}, {4},
|
||||||
TensorType_INT32);
|
{2, -1, 1, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
|
m.SetInput({1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
|
||||||
m.SetBegin({1, 1, 0, 0});
|
|
||||||
m.SetSize({2, -1, 1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 6, 8, 9}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 6, 8, 9}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis2) {
|
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis2) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 1, 0}, {4},
|
||||||
TensorType_INT32);
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 1, 0});
|
|
||||||
m.SetSize({2, 1, -1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis3) {
|
TEST_P(SliceOpTest, BeginNonZeroSizeMinus1Axis3) {
|
||||||
SliceOpModel<int32_t, int32_t> m({3, 1, 2, 3}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int32_t, int32_t> m({3, 1, 2, 3}, {4}, {1, 0, 0, 1}, {4},
|
||||||
TensorType_INT32);
|
{2, 1, 1, -1}, TensorType_INT32,
|
||||||
|
TensorType_INT32, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 1});
|
|
||||||
m.SetSize({2, 1, 1, -1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 2}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 2}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, SliceUint8) {
|
TEST_P(SliceOpTest, SliceUint8) {
|
||||||
SliceOpModel<uint8_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<uint8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_UINT8);
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
TensorType_UINT8, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({2, 1, -1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, SliceInt8) {
|
TEST_P(SliceOpTest, SliceInt8) {
|
||||||
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<int8_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_INT8);
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
TensorType_INT8, GetParam());
|
||||||
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({2, 1, -1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SliceOpTest, SliceString) {
|
TEST_P(SliceOpTest, SliceString) {
|
||||||
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
|
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
|
||||||
TensorType_STRING);
|
{2, 1, -1, 1}, TensorType_INT32,
|
||||||
|
TensorType_STRING, GetParam());
|
||||||
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
|
m.SetStringInput({"0,0,0,0", "0,0,1,0", "0,0,2,0", //
|
||||||
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
|
"0,1,0,0", "0,1,1,0", "0,1,2,0", //
|
||||||
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
|
"1,0,0,0", "1,0,1,0", "1,0,2,0", //
|
||||||
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
|
"1,1,0,0", "1,1,1,0", "1,1,2,0", //
|
||||||
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
|
"2,0,0,0", "2,0,1,0", "2,0,2,0", //
|
||||||
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
|
"2,1,0,0", "2,1,1,0", "2,1,2,0"});
|
||||||
m.SetBegin({1, 0, 0, 0});
|
|
||||||
m.SetSize({2, 1, -1, 1});
|
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
|
||||||
EXPECT_THAT(m.GetOutput(),
|
EXPECT_THAT(m.GetOutput(),
|
||||||
@ -239,5 +238,8 @@ TEST(SliceOpTest, SliceString) {
|
|||||||
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
|
"2,0,0,0", "2,0,1,0", "2,0,2,0"}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(SliceOpTest, SliceOpTest,
|
||||||
|
::testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
x
Reference in New Issue
Block a user