Added test cases on R4 slice.
PiperOrigin-RevId: 168482049
This commit is contained in:
parent
46a81b5c34
commit
3cd6bdef5f
@ -25,12 +25,16 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::str_util::Join;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
class SliceTest : public ClientLibraryTestBase {};
|
||||
|
||||
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
|
||||
@ -161,6 +165,20 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
|
||||
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
|
||||
}
|
||||
|
||||
XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
|
||||
Array4D<float> values(2, 4, 6, 8);
|
||||
values.FillRandom(3.14f);
|
||||
auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
|
||||
/*strides=*/{{1, 1, 2, 1}});
|
||||
auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
|
||||
*expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR4FromArray4D(values);
|
||||
builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
|
||||
ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
|
||||
&expected_literal->shape());
|
||||
}
|
||||
|
||||
struct R1Spec {
|
||||
int64 input_dim0;
|
||||
int64 slice_start;
|
||||
@ -193,29 +211,17 @@ class SliceR1Test : public ClientLibraryTestBase,
|
||||
}
|
||||
};
|
||||
|
||||
XLA_TEST_P(SliceR1Test, DoIt_F32) {
|
||||
Run<float>(GetParam());
|
||||
}
|
||||
XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }
|
||||
|
||||
XLA_TEST_P(SliceR1Test, DoIt_F64) {
|
||||
Run<double>(GetParam());
|
||||
}
|
||||
XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); }
|
||||
|
||||
XLA_TEST_P(SliceR1Test, DoIt_U32) {
|
||||
Run<uint32>(GetParam());
|
||||
}
|
||||
XLA_TEST_P(SliceR1Test, DoIt_U32) { Run<uint32>(GetParam()); }
|
||||
|
||||
XLA_TEST_P(SliceR1Test, DoIt_S32) {
|
||||
Run<int32>(GetParam());
|
||||
}
|
||||
XLA_TEST_P(SliceR1Test, DoIt_S32) { Run<int32>(GetParam()); }
|
||||
|
||||
XLA_TEST_P(SliceR1Test, DoIt_U64) {
|
||||
Run<uint64>(GetParam());
|
||||
}
|
||||
XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
|
||||
|
||||
XLA_TEST_P(SliceR1Test, DoIt_S64) {
|
||||
Run<int64>(GetParam());
|
||||
}
|
||||
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
|
||||
|
||||
INSTANTIATE_TEST_CASE_P( //
|
||||
SliceR1TestInstantiation, //
|
||||
@ -306,5 +312,149 @@ INSTANTIATE_TEST_CASE_P(
|
||||
);
|
||||
// clang-format on
|
||||
|
||||
struct R4Spec {
|
||||
std::array<int64, 4> input_dims;
|
||||
std::array<int64, 4> input_layout; // minor-to-major
|
||||
std::array<int64, 4> slice_starts;
|
||||
std::array<int64, 4> slice_limits;
|
||||
std::array<int64, 4> slice_strides;
|
||||
};
|
||||
|
||||
string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
|
||||
const R4Spec& spec = data.param;
|
||||
return StrCat( //
|
||||
"input_", Join(spec.input_dims, "x"), //
|
||||
"__layout_", Join(spec.input_layout, ""), //
|
||||
"__starts_", Join(spec.slice_starts, "x"), //
|
||||
"__limits_", Join(spec.slice_limits, "x"), //
|
||||
"__strides_", Join(spec.slice_strides, "x") //
|
||||
);
|
||||
}
|
||||
|
||||
class SliceR4Test : public ClientLibraryTestBase,
|
||||
public ::testing::WithParamInterface<R4Spec> {
|
||||
protected:
|
||||
void Run(const R4Spec& spec) {
|
||||
Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
|
||||
spec.input_dims[2], spec.input_dims[3]);
|
||||
values.FillRandom(3.14f);
|
||||
auto expected = ReferenceUtil::Slice4D(
|
||||
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto literal = Literal::CreateR4FromArray4DWithLayout(
|
||||
values, LayoutUtil::MakeLayout(spec.input_layout));
|
||||
auto parameter = builder.Parameter(0, literal->shape(), "p0");
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
|
||||
client_->TransferToServer(*literal));
|
||||
builder.Slice(parameter, spec.slice_starts, spec.slice_limits,
|
||||
spec.slice_strides);
|
||||
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
|
||||
}
|
||||
};
|
||||
|
||||
XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); }
|
||||
|
||||
const R4Spec kR4SpecValues[] = {
|
||||
R4Spec{{{2, 2, 2, 2}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{1, 1, 1, 1}}}, //
|
||||
R4Spec{{{3, 3, 4, 4}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{3, 3, 4, 4}},
|
||||
{{1, 1, 2, 1}}}, //
|
||||
R4Spec{{{2, 3, 16, 4}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{2, 3, 16, 4}},
|
||||
{{1, 1, 3, 1}}}, //
|
||||
// stride > 1 should be on the second-to-last dimension.
|
||||
R4Spec{{{4, 16, 3, 2}},
|
||||
{{0, 1, 2, 3}},
|
||||
{{1, 4, 1, 1}},
|
||||
{{3, 12, 3, 2}},
|
||||
{{1, 1, 3, 1}}}, //
|
||||
R4Spec{{{2, 2, 257, 129}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 1, 62, 64}},
|
||||
{{2, 2, 195, 129}},
|
||||
{{1, 1, 3, 1}}}, //
|
||||
R4Spec{{{3, 5, 257, 129}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 2, 61, 64}},
|
||||
{{3, 5, 199, 129}},
|
||||
{{1, 1, 3, 1}}}, //
|
||||
R4Spec{{{5, 8, 257, 129}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{2, 3, 60, 64}},
|
||||
{{3, 5, 200, 68}},
|
||||
{{1, 1, 1, 1}}}, //
|
||||
R4Spec{{{2, 2, 256, 130}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 60, 127}},
|
||||
{{2, 2, 166, 129}},
|
||||
{{1, 1, 3, 1}}}, //
|
||||
R4Spec{{{2, 4, 8, 4}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 2, 0, 1}},
|
||||
{{2, 4, 8, 3}},
|
||||
{{1, 1, 7, 1}}}, //
|
||||
R4Spec{{{2, 4, 256, 130}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 2, 9, 127}},
|
||||
{{2, 4, 82, 129}},
|
||||
{{1, 1, 7, 1}}}, //
|
||||
R4Spec{{{2, 4, 256, 130}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 2, 19, 127}},
|
||||
{{2, 4, 89, 129}},
|
||||
{{1, 1, 7, 1}}}, //
|
||||
R4Spec{{{2, 4, 256, 130}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 2, 29, 127}},
|
||||
{{2, 4, 159, 129}},
|
||||
{{1, 1, 7, 1}}}, //
|
||||
R4Spec{{{2, 4, 256, 130}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{1, 2, 39, 127}},
|
||||
{{2, 4, 158, 129}},
|
||||
{{1, 1, 7, 1}}}, //
|
||||
R4Spec{{{1, 1, 5, 512}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{1, 1, 5, 512}},
|
||||
{{1, 1, 4, 1}}}, //
|
||||
R4Spec{{{1, 1, 513, 512}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{1, 1, 513, 512}},
|
||||
{{1, 1, 512, 1}}}, //
|
||||
R4Spec{{{1, 1, 1024, 4}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 15, 0}},
|
||||
{{1, 1, 1022, 4}},
|
||||
{{1, 1, 23, 1}}}, //
|
||||
R4Spec{{{1, 1, 1024, 4}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 14, 0}},
|
||||
{{1, 1, 1023, 4}},
|
||||
{{1, 1, 101, 1}}}, //
|
||||
R4Spec{{{2, 2, 512, 1024}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 0, 0}},
|
||||
{{2, 2, 512, 1024}},
|
||||
{{1, 1, 2, 1}}}, //
|
||||
R4Spec{{{1, 1, 14, 2048}},
|
||||
{{3, 2, 1, 0}},
|
||||
{{0, 0, 2, 0}},
|
||||
{{1, 1, 14, 2}},
|
||||
{{1, 1, 1, 1}}}, //
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
|
||||
::testing::ValuesIn(kR4SpecValues), R4SpecToString);
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user