Added test cases on R4 slice.

PiperOrigin-RevId: 168482049
This commit is contained in:
A. Unique TensorFlower 2017-09-12 18:48:38 -07:00 committed by TensorFlower Gardener
parent 46a81b5c34
commit 3cd6bdef5f

View File

@ -25,12 +25,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrCat;
class SliceTest : public ClientLibraryTestBase {};
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
@ -161,6 +165,20 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}
XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
Array4D<float> values(2, 4, 6, 8);
values.FillRandom(3.14f);
auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
/*strides=*/{{1, 1, 2, 1}});
auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
*expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR4FromArray4D(values);
builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
&expected_literal->shape());
}
struct R1Spec {
int64 input_dim0;
int64 slice_start;
@ -193,29 +211,17 @@ class SliceR1Test : public ClientLibraryTestBase,
}
};
XLA_TEST_P(SliceR1Test, DoIt_F32) {
Run<float>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_F64) {
Run<double>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_U32) {
Run<uint32>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_U32) { Run<uint32>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_S32) {
Run<int32>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_S32) { Run<int32>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_U64) {
Run<uint64>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_S64) {
Run<int64>(GetParam());
}
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
INSTANTIATE_TEST_CASE_P( //
SliceR1TestInstantiation, //
@ -306,5 +312,149 @@ INSTANTIATE_TEST_CASE_P(
);
// clang-format on
struct R4Spec {
std::array<int64, 4> input_dims;
std::array<int64, 4> input_layout; // minor-to-major
std::array<int64, 4> slice_starts;
std::array<int64, 4> slice_limits;
std::array<int64, 4> slice_strides;
};
string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
const R4Spec& spec = data.param;
return StrCat( //
"input_", Join(spec.input_dims, "x"), //
"__layout_", Join(spec.input_layout, ""), //
"__starts_", Join(spec.slice_starts, "x"), //
"__limits_", Join(spec.slice_limits, "x"), //
"__strides_", Join(spec.slice_strides, "x") //
);
}
class SliceR4Test : public ClientLibraryTestBase,
public ::testing::WithParamInterface<R4Spec> {
protected:
void Run(const R4Spec& spec) {
Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
spec.input_dims[2], spec.input_dims[3]);
values.FillRandom(3.14f);
auto expected = ReferenceUtil::Slice4D(
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputationBuilder builder(client_, TestName());
auto literal = Literal::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
auto parameter = builder.Parameter(0, literal->shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
client_->TransferToServer(*literal));
builder.Slice(parameter, spec.slice_starts, spec.slice_limits,
spec.slice_strides);
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
}
};
XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); }
const R4Spec kR4SpecValues[] = {
R4Spec{{{2, 2, 2, 2}},
{{3, 2, 1, 0}},
{{0, 0, 0, 0}},
{{0, 0, 0, 0}},
{{1, 1, 1, 1}}}, //
R4Spec{{{3, 3, 4, 4}},
{{3, 2, 1, 0}},
{{0, 0, 0, 0}},
{{3, 3, 4, 4}},
{{1, 1, 2, 1}}}, //
R4Spec{{{2, 3, 16, 4}},
{{3, 2, 1, 0}},
{{0, 0, 0, 0}},
{{2, 3, 16, 4}},
{{1, 1, 3, 1}}}, //
// stride > 1 should be on the second-to-last dimension.
R4Spec{{{4, 16, 3, 2}},
{{0, 1, 2, 3}},
{{1, 4, 1, 1}},
{{3, 12, 3, 2}},
{{1, 1, 3, 1}}}, //
R4Spec{{{2, 2, 257, 129}},
{{3, 2, 1, 0}},
{{1, 1, 62, 64}},
{{2, 2, 195, 129}},
{{1, 1, 3, 1}}}, //
R4Spec{{{3, 5, 257, 129}},
{{3, 2, 1, 0}},
{{1, 2, 61, 64}},
{{3, 5, 199, 129}},
{{1, 1, 3, 1}}}, //
R4Spec{{{5, 8, 257, 129}},
{{3, 2, 1, 0}},
{{2, 3, 60, 64}},
{{3, 5, 200, 68}},
{{1, 1, 1, 1}}}, //
R4Spec{{{2, 2, 256, 130}},
{{3, 2, 1, 0}},
{{0, 0, 60, 127}},
{{2, 2, 166, 129}},
{{1, 1, 3, 1}}}, //
R4Spec{{{2, 4, 8, 4}},
{{3, 2, 1, 0}},
{{1, 2, 0, 1}},
{{2, 4, 8, 3}},
{{1, 1, 7, 1}}}, //
R4Spec{{{2, 4, 256, 130}},
{{3, 2, 1, 0}},
{{1, 2, 9, 127}},
{{2, 4, 82, 129}},
{{1, 1, 7, 1}}}, //
R4Spec{{{2, 4, 256, 130}},
{{3, 2, 1, 0}},
{{1, 2, 19, 127}},
{{2, 4, 89, 129}},
{{1, 1, 7, 1}}}, //
R4Spec{{{2, 4, 256, 130}},
{{3, 2, 1, 0}},
{{1, 2, 29, 127}},
{{2, 4, 159, 129}},
{{1, 1, 7, 1}}}, //
R4Spec{{{2, 4, 256, 130}},
{{3, 2, 1, 0}},
{{1, 2, 39, 127}},
{{2, 4, 158, 129}},
{{1, 1, 7, 1}}}, //
R4Spec{{{1, 1, 5, 512}},
{{3, 2, 1, 0}},
{{0, 0, 0, 0}},
{{1, 1, 5, 512}},
{{1, 1, 4, 1}}}, //
R4Spec{{{1, 1, 513, 512}},
{{3, 2, 1, 0}},
{{0, 0, 0, 0}},
{{1, 1, 513, 512}},
{{1, 1, 512, 1}}}, //
R4Spec{{{1, 1, 1024, 4}},
{{3, 2, 1, 0}},
{{0, 0, 15, 0}},
{{1, 1, 1022, 4}},
{{1, 1, 23, 1}}}, //
R4Spec{{{1, 1, 1024, 4}},
{{3, 2, 1, 0}},
{{0, 0, 14, 0}},
{{1, 1, 1023, 4}},
{{1, 1, 101, 1}}}, //
R4Spec{{{2, 2, 512, 1024}},
{{3, 2, 1, 0}},
{{0, 0, 0, 0}},
{{2, 2, 512, 1024}},
{{1, 1, 2, 1}}}, //
R4Spec{{{1, 1, 14, 2048}},
{{3, 2, 1, 0}},
{{0, 0, 2, 0}},
{{1, 1, 14, 2}},
{{1, 1, 1, 1}}}, //
};
INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
::testing::ValuesIn(kR4SpecValues), R4SpecToString);
} // namespace
} // namespace xla