Added test cases on R4 slice.
PiperOrigin-RevId: 168482049
This commit is contained in:
parent
46a81b5c34
commit
3cd6bdef5f
@ -25,12 +25,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::tensorflow::str_util::Join;
|
||||||
|
using ::tensorflow::strings::StrCat;
|
||||||
|
|
||||||
class SliceTest : public ClientLibraryTestBase {};
|
class SliceTest : public ClientLibraryTestBase {};
|
||||||
|
|
||||||
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
|
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
|
||||||
@ -161,6 +165,20 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
|
|||||||
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
|
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
|
||||||
|
Array4D<float> values(2, 4, 6, 8);
|
||||||
|
values.FillRandom(3.14f);
|
||||||
|
auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
|
||||||
|
/*strides=*/{{1, 1, 2, 1}});
|
||||||
|
auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
|
||||||
|
*expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto original = builder.ConstantR4FromArray4D(values);
|
||||||
|
builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
|
||||||
|
ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
|
||||||
|
&expected_literal->shape());
|
||||||
|
}
|
||||||
|
|
||||||
struct R1Spec {
|
struct R1Spec {
|
||||||
int64 input_dim0;
|
int64 input_dim0;
|
||||||
int64 slice_start;
|
int64 slice_start;
|
||||||
@ -193,29 +211,17 @@ class SliceR1Test : public ClientLibraryTestBase,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_F32) {
|
XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }
|
||||||
Run<float>(GetParam());
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_F64) {
|
XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); }
|
||||||
Run<double>(GetParam());
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_U32) {
|
XLA_TEST_P(SliceR1Test, DoIt_U32) { Run<uint32>(GetParam()); }
|
||||||
Run<uint32>(GetParam());
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_S32) {
|
XLA_TEST_P(SliceR1Test, DoIt_S32) { Run<int32>(GetParam()); }
|
||||||
Run<int32>(GetParam());
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_U64) {
|
XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
|
||||||
Run<uint64>(GetParam());
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_P(SliceR1Test, DoIt_S64) {
|
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
|
||||||
Run<int64>(GetParam());
|
|
||||||
}
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P( //
|
INSTANTIATE_TEST_CASE_P( //
|
||||||
SliceR1TestInstantiation, //
|
SliceR1TestInstantiation, //
|
||||||
@ -306,5 +312,149 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
);
|
);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
struct R4Spec {
|
||||||
|
std::array<int64, 4> input_dims;
|
||||||
|
std::array<int64, 4> input_layout; // minor-to-major
|
||||||
|
std::array<int64, 4> slice_starts;
|
||||||
|
std::array<int64, 4> slice_limits;
|
||||||
|
std::array<int64, 4> slice_strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
|
||||||
|
const R4Spec& spec = data.param;
|
||||||
|
return StrCat( //
|
||||||
|
"input_", Join(spec.input_dims, "x"), //
|
||||||
|
"__layout_", Join(spec.input_layout, ""), //
|
||||||
|
"__starts_", Join(spec.slice_starts, "x"), //
|
||||||
|
"__limits_", Join(spec.slice_limits, "x"), //
|
||||||
|
"__strides_", Join(spec.slice_strides, "x") //
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
class SliceR4Test : public ClientLibraryTestBase,
|
||||||
|
public ::testing::WithParamInterface<R4Spec> {
|
||||||
|
protected:
|
||||||
|
void Run(const R4Spec& spec) {
|
||||||
|
Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
|
||||||
|
spec.input_dims[2], spec.input_dims[3]);
|
||||||
|
values.FillRandom(3.14f);
|
||||||
|
auto expected = ReferenceUtil::Slice4D(
|
||||||
|
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto literal = Literal::CreateR4FromArray4DWithLayout(
|
||||||
|
values, LayoutUtil::MakeLayout(spec.input_layout));
|
||||||
|
auto parameter = builder.Parameter(0, literal->shape(), "p0");
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
|
||||||
|
client_->TransferToServer(*literal));
|
||||||
|
builder.Slice(parameter, spec.slice_starts, spec.slice_limits,
|
||||||
|
spec.slice_strides);
|
||||||
|
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); }
|
||||||
|
|
||||||
|
const R4Spec kR4SpecValues[] = {
|
||||||
|
R4Spec{{{2, 2, 2, 2}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{1, 1, 1, 1}}}, //
|
||||||
|
R4Spec{{{3, 3, 4, 4}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{3, 3, 4, 4}},
|
||||||
|
{{1, 1, 2, 1}}}, //
|
||||||
|
R4Spec{{{2, 3, 16, 4}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{2, 3, 16, 4}},
|
||||||
|
{{1, 1, 3, 1}}}, //
|
||||||
|
// stride > 1 should be on the second-to-last dimension.
|
||||||
|
R4Spec{{{4, 16, 3, 2}},
|
||||||
|
{{0, 1, 2, 3}},
|
||||||
|
{{1, 4, 1, 1}},
|
||||||
|
{{3, 12, 3, 2}},
|
||||||
|
{{1, 1, 3, 1}}}, //
|
||||||
|
R4Spec{{{2, 2, 257, 129}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 1, 62, 64}},
|
||||||
|
{{2, 2, 195, 129}},
|
||||||
|
{{1, 1, 3, 1}}}, //
|
||||||
|
R4Spec{{{3, 5, 257, 129}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 2, 61, 64}},
|
||||||
|
{{3, 5, 199, 129}},
|
||||||
|
{{1, 1, 3, 1}}}, //
|
||||||
|
R4Spec{{{5, 8, 257, 129}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{2, 3, 60, 64}},
|
||||||
|
{{3, 5, 200, 68}},
|
||||||
|
{{1, 1, 1, 1}}}, //
|
||||||
|
R4Spec{{{2, 2, 256, 130}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 60, 127}},
|
||||||
|
{{2, 2, 166, 129}},
|
||||||
|
{{1, 1, 3, 1}}}, //
|
||||||
|
R4Spec{{{2, 4, 8, 4}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 2, 0, 1}},
|
||||||
|
{{2, 4, 8, 3}},
|
||||||
|
{{1, 1, 7, 1}}}, //
|
||||||
|
R4Spec{{{2, 4, 256, 130}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 2, 9, 127}},
|
||||||
|
{{2, 4, 82, 129}},
|
||||||
|
{{1, 1, 7, 1}}}, //
|
||||||
|
R4Spec{{{2, 4, 256, 130}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 2, 19, 127}},
|
||||||
|
{{2, 4, 89, 129}},
|
||||||
|
{{1, 1, 7, 1}}}, //
|
||||||
|
R4Spec{{{2, 4, 256, 130}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 2, 29, 127}},
|
||||||
|
{{2, 4, 159, 129}},
|
||||||
|
{{1, 1, 7, 1}}}, //
|
||||||
|
R4Spec{{{2, 4, 256, 130}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{1, 2, 39, 127}},
|
||||||
|
{{2, 4, 158, 129}},
|
||||||
|
{{1, 1, 7, 1}}}, //
|
||||||
|
R4Spec{{{1, 1, 5, 512}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{1, 1, 5, 512}},
|
||||||
|
{{1, 1, 4, 1}}}, //
|
||||||
|
R4Spec{{{1, 1, 513, 512}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{1, 1, 513, 512}},
|
||||||
|
{{1, 1, 512, 1}}}, //
|
||||||
|
R4Spec{{{1, 1, 1024, 4}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 15, 0}},
|
||||||
|
{{1, 1, 1022, 4}},
|
||||||
|
{{1, 1, 23, 1}}}, //
|
||||||
|
R4Spec{{{1, 1, 1024, 4}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 14, 0}},
|
||||||
|
{{1, 1, 1023, 4}},
|
||||||
|
{{1, 1, 101, 1}}}, //
|
||||||
|
R4Spec{{{2, 2, 512, 1024}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 0, 0}},
|
||||||
|
{{2, 2, 512, 1024}},
|
||||||
|
{{1, 1, 2, 1}}}, //
|
||||||
|
R4Spec{{{1, 1, 14, 2048}},
|
||||||
|
{{3, 2, 1, 0}},
|
||||||
|
{{0, 0, 2, 0}},
|
||||||
|
{{1, 1, 14, 2}},
|
||||||
|
{{1, 1, 1, 1}}}, //
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
|
||||||
|
::testing::ValuesIn(kR4SpecValues), R4SpecToString);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user