From 3cd6bdef5fa44efbf2b16eeb5fe026be839e6898 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 Sep 2017 18:48:38 -0700 Subject: [PATCH] Added test cases on R4 slice. PiperOrigin-RevId: 168482049 --- tensorflow/compiler/xla/tests/slice_test.cc | 186 ++++++++++++++++++-- 1 file changed, 168 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 5da6104cfa7..3bf0f411a88 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -25,12 +25,16 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrCat; + class SliceTest : public ClientLibraryTestBase {}; TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { @@ -161,6 +165,20 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } +XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { + Array4D values(2, 4, 6, 8); + values.FillRandom(3.14f); + auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}}, + /*strides=*/{{1, 1, 2, 1}}); + auto expected_literal = Literal::CreateR4FromArray4DWithLayout( + *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); + ComputationBuilder builder(client_, TestName()); + auto original = builder.ConstantR4FromArray4D(values); + builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); + ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), + &expected_literal->shape()); +} + struct R1Spec { int64 input_dim0; int64 slice_start; @@ -193,29 +211,17 @@ class SliceR1Test : public ClientLibraryTestBase, } }; -XLA_TEST_P(SliceR1Test, DoIt_F32) { - Run(GetParam()); -} +XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } -XLA_TEST_P(SliceR1Test, DoIt_F64) { - Run(GetParam()); -} +XLA_TEST_P(SliceR1Test, DoIt_F64) { Run(GetParam()); } -XLA_TEST_P(SliceR1Test, DoIt_U32) { - Run(GetParam()); -} +XLA_TEST_P(SliceR1Test, DoIt_U32) { Run(GetParam()); } -XLA_TEST_P(SliceR1Test, DoIt_S32) { - Run(GetParam()); -} +XLA_TEST_P(SliceR1Test, DoIt_S32) { Run(GetParam()); } -XLA_TEST_P(SliceR1Test, DoIt_U64) { - Run(GetParam()); -} +XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } -XLA_TEST_P(SliceR1Test, DoIt_S64) { - Run(GetParam()); -} +XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } INSTANTIATE_TEST_CASE_P( // SliceR1TestInstantiation, // @@ -306,5 +312,149 @@ INSTANTIATE_TEST_CASE_P( ); // clang-format on +struct R4Spec { + std::array input_dims; + std::array input_layout; // minor-to-major + std::array slice_starts; + std::array slice_limits; + std::array slice_strides; +}; + +string R4SpecToString(const ::testing::TestParamInfo& data) { + const R4Spec& spec = data.param; + return StrCat( // + "input_", Join(spec.input_dims, "x"), // + "__layout_", Join(spec.input_layout, ""), // + "__starts_", Join(spec.slice_starts, "x"), // + "__limits_", Join(spec.slice_limits, "x"), // + "__strides_", Join(spec.slice_strides, "x") // + ); +} + +class SliceR4Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + protected: + void Run(const R4Spec& spec) { + Array4D values(spec.input_dims[0], spec.input_dims[1], + spec.input_dims[2], spec.input_dims[3]); + values.FillRandom(3.14f); + auto expected = ReferenceUtil::Slice4D( + values, spec.slice_starts, spec.slice_limits, spec.slice_strides); + ComputationBuilder builder(client_, TestName()); + auto literal = Literal::CreateR4FromArray4DWithLayout( + values, LayoutUtil::MakeLayout(spec.input_layout)); + auto parameter = builder.Parameter(0, literal->shape(), "p0"); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); + builder.Slice(parameter, spec.slice_starts, spec.slice_limits, + spec.slice_strides); + ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); + } +}; + +XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); } + +const R4Spec kR4SpecValues[] = { + R4Spec{{{2, 2, 2, 2}}, + {{3, 2, 1, 0}}, + {{0, 0, 0, 0}}, + {{0, 0, 0, 0}}, + {{1, 1, 1, 1}}}, // + R4Spec{{{3, 3, 4, 4}}, + {{3, 2, 1, 0}}, + {{0, 0, 0, 0}}, + {{3, 3, 4, 4}}, + {{1, 1, 2, 1}}}, // + R4Spec{{{2, 3, 16, 4}}, + {{3, 2, 1, 0}}, + {{0, 0, 0, 0}}, + {{2, 3, 16, 4}}, + {{1, 1, 3, 1}}}, // + // stride > 1 should be on the second-to-last dimension. + R4Spec{{{4, 16, 3, 2}}, + {{0, 1, 2, 3}}, + {{1, 4, 1, 1}}, + {{3, 12, 3, 2}}, + {{1, 1, 3, 1}}}, // + R4Spec{{{2, 2, 257, 129}}, + {{3, 2, 1, 0}}, + {{1, 1, 62, 64}}, + {{2, 2, 195, 129}}, + {{1, 1, 3, 1}}}, // + R4Spec{{{3, 5, 257, 129}}, + {{3, 2, 1, 0}}, + {{1, 2, 61, 64}}, + {{3, 5, 199, 129}}, + {{1, 1, 3, 1}}}, // + R4Spec{{{5, 8, 257, 129}}, + {{3, 2, 1, 0}}, + {{2, 3, 60, 64}}, + {{3, 5, 200, 68}}, + {{1, 1, 1, 1}}}, // + R4Spec{{{2, 2, 256, 130}}, + {{3, 2, 1, 0}}, + {{0, 0, 60, 127}}, + {{2, 2, 166, 129}}, + {{1, 1, 3, 1}}}, // + R4Spec{{{2, 4, 8, 4}}, + {{3, 2, 1, 0}}, + {{1, 2, 0, 1}}, + {{2, 4, 8, 3}}, + {{1, 1, 7, 1}}}, // + R4Spec{{{2, 4, 256, 130}}, + {{3, 2, 1, 0}}, + {{1, 2, 9, 127}}, + {{2, 4, 82, 129}}, + {{1, 1, 7, 1}}}, // + R4Spec{{{2, 4, 256, 130}}, + {{3, 2, 1, 0}}, + {{1, 2, 19, 127}}, + {{2, 4, 89, 129}}, + {{1, 1, 7, 1}}}, // + R4Spec{{{2, 4, 256, 130}}, + {{3, 2, 1, 0}}, + {{1, 2, 29, 127}}, + {{2, 4, 159, 129}}, + {{1, 1, 7, 1}}}, // + R4Spec{{{2, 4, 256, 130}}, + {{3, 2, 1, 0}}, + {{1, 2, 39, 127}}, + {{2, 4, 158, 129}}, + {{1, 1, 7, 1}}}, // + R4Spec{{{1, 1, 5, 512}}, + {{3, 2, 1, 0}}, + {{0, 0, 0, 0}}, + {{1, 1, 5, 512}}, + {{1, 1, 4, 1}}}, // + R4Spec{{{1, 1, 513, 512}}, + {{3, 2, 1, 0}}, + {{0, 0, 0, 0}}, + {{1, 1, 513, 512}}, + {{1, 1, 512, 1}}}, // + R4Spec{{{1, 1, 1024, 4}}, + {{3, 2, 1, 0}}, + {{0, 0, 15, 0}}, + {{1, 1, 1022, 4}}, + {{1, 1, 23, 1}}}, // + R4Spec{{{1, 1, 1024, 4}}, + {{3, 2, 1, 0}}, + {{0, 0, 14, 0}}, + {{1, 1, 1023, 4}}, + {{1, 1, 101, 1}}}, // + R4Spec{{{2, 2, 512, 1024}}, + {{3, 2, 1, 0}}, + {{0, 0, 0, 0}}, + {{2, 2, 512, 1024}}, + {{1, 1, 2, 1}}}, // + R4Spec{{{1, 1, 14, 2048}}, + {{3, 2, 1, 0}}, + {{0, 0, 2, 0}}, + {{1, 1, 14, 2}}, + {{1, 1, 1, 1}}}, // +}; + +INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test, + ::testing::ValuesIn(kR4SpecValues), R4SpecToString); + } // namespace } // namespace xla