Ensure XLA non-python tests pass
This commit is contained in:
parent
db58c826ae
commit
4f513e0555
@ -1845,7 +1845,7 @@ TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
|
||||
|
||||
auto x = builder.Parameter(0, x_literal->shape(), "x");
|
||||
auto y = builder.Parameter(1, y_literal->shape(), "y");
|
||||
auto slice = builder.Slice(x, {1}, {2});
|
||||
auto slice = builder.Slice(x, {1}, {2}, {1});
|
||||
builder.Sub(slice, y);
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
|
||||
|
@ -367,9 +367,9 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) {
|
||||
std::vector<xla::ComputationDataHandle> out_slices;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
// Slice off individual matrices and reshape to 2D tensors.
|
||||
auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2});
|
||||
auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
|
||||
x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
|
||||
auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2});
|
||||
auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
|
||||
y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
|
||||
|
||||
auto out = builder.Dot(x_slice, y_slice);
|
||||
|
@ -204,7 +204,7 @@ XLA_TEST_F(FusionTest, Test) {
|
||||
HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
|
||||
HloOpcode::kSelect, const10, add8, const9));
|
||||
auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
|
||||
ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}));
|
||||
ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
|
||||
// CreateFusionInstruction needs the `instructions_to_fuse` argument in
|
||||
// reverse topological order, so the first element in `instructions_to_fuse`
|
||||
// must be the root.
|
||||
|
@ -37,7 +37,7 @@ XLA_TEST_F(SliceTest, Slice2D) {
|
||||
ComputationBuilder builder(client_, "slice_2d");
|
||||
auto original = builder.ConstantR2<float>(
|
||||
{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
|
||||
builder.Slice(original, {2, 1}, {4, 3});
|
||||
builder.Slice(original, {2, 1}, {4, 3}, {1, 1});
|
||||
|
||||
Array2D<float> expected({{8.0f, 9.0f}, {11.0f, 12.0f}});
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
|
||||
@ -48,7 +48,7 @@ XLA_TEST_F(SliceTest, Slice3D) {
|
||||
Array3D<float> array_3d(
|
||||
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
|
||||
auto original = builder.ConstantR3FromArray3D<float>(array_3d);
|
||||
builder.Slice(original, {0, 0, 1}, {2, 1, 2});
|
||||
builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1});
|
||||
|
||||
Array3D<float> expected_3d({{{2.0f}}, {{6.0f}}});
|
||||
ComputeAndCompareR3<float>(&builder, expected_3d, {}, ErrorSpec(0.000001));
|
||||
|
@ -328,7 +328,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto input = builder.Parameter(0, original, "input");
|
||||
// Use the slice operator to get an off-diagonal element.
|
||||
builder.Slice(input, {0, 1}, {1, 2});
|
||||
builder.Slice(input, {0, 1}, {1, 2}, {1, 1});
|
||||
|
||||
std::unique_ptr<GlobalData> data =
|
||||
client_->TransferToServer(*literal).ConsumeValueOrDie();
|
||||
|
@ -45,7 +45,7 @@ class SliceTest : public ClientLibraryTestBase {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR1<NativeT>(constant);
|
||||
builder.Slice(original, {2}, {4});
|
||||
builder.Slice(original, {2}, {4}, {1});
|
||||
|
||||
const std::vector<NativeT> expected = {static_cast<NativeT>(2),
|
||||
static_cast<NativeT>(3)};
|
||||
@ -56,7 +56,7 @@ class SliceTest : public ClientLibraryTestBase {
|
||||
XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR1<float>({});
|
||||
builder.Slice(original, {0}, {0});
|
||||
builder.Slice(original, {0}, {0}, {1});
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {}, {});
|
||||
}
|
||||
@ -65,7 +65,7 @@ XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
std::vector<float> constant(10, 0.3);
|
||||
auto original = builder.ConstantR1<float>(constant);
|
||||
builder.Slice(original, {7}, {7});
|
||||
builder.Slice(original, {7}, {7}, {1});
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {}, {});
|
||||
}
|
||||
@ -88,7 +88,7 @@ TEST_F(SliceTest, SliceTenToTen) {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR1<float>(values);
|
||||
builder.Slice(original, {0}, {10});
|
||||
builder.Slice(original, {0}, {10}, {1});
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
|
||||
}
|
||||
@ -99,7 +99,7 @@ TEST_F(SliceTest, SliceLastFourOf1024) {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR1<float>(values);
|
||||
builder.Slice(original, {1024 - 4}, {1024});
|
||||
builder.Slice(original, {1024 - 4}, {1024}, {1});
|
||||
|
||||
const std::vector<float> expected = {1020, 1021, 1022, 1023};
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
|
||||
@ -113,7 +113,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR1<float>(values);
|
||||
builder.Slice(original, {7}, {7 + 1024});
|
||||
builder.Slice(original, {7}, {7 + 1024}, {1});
|
||||
|
||||
std::vector<float> expected(1024);
|
||||
std::iota(values.begin(), values.end(), 7.0);
|
||||
@ -123,7 +123,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
|
||||
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
|
||||
builder.Slice(original, {0, 0}, {0, 0});
|
||||
builder.Slice(original, {0, 0}, {0, 0}, {1, 1});
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
|
||||
}
|
||||
@ -131,7 +131,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
|
||||
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
|
||||
builder.Slice(original, {0, 15}, {0, 20});
|
||||
builder.Slice(original, {0, 15}, {0, 20}, {1, 1});
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
|
||||
}
|
||||
@ -139,7 +139,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
|
||||
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
|
||||
builder.Slice(original, {1, 0}, {3, 0});
|
||||
builder.Slice(original, {1, 0}, {3, 0}, {1, 1});
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
|
||||
}
|
||||
@ -154,7 +154,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR2FromArray2D<float>(values);
|
||||
builder.Slice(original, {128, 128}, {256, 256});
|
||||
builder.Slice(original, {128, 128}, {256, 256}, {1, 1});
|
||||
|
||||
Array2D<float> expected(128, 128);
|
||||
for (int row = 0; row < 128; ++row) {
|
||||
@ -172,7 +172,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR2FromArray2D<float>(values);
|
||||
builder.Slice(original, {0, 3072}, {1, 4096});
|
||||
builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1});
|
||||
|
||||
Array2D<float> expected(1, 1024);
|
||||
std::iota(expected.data(), expected.data() + 1024, 3072.0);
|
||||
@ -193,7 +193,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
|
||||
}
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR2FromArray2D<float>(values);
|
||||
builder.Slice(original, {0, 0}, {16, 2});
|
||||
builder.Slice(original, {0, 0}, {16, 2}, {1, 1});
|
||||
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
|
||||
}
|
||||
|
||||
@ -205,7 +205,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
|
||||
ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto original = builder.ConstantR4FromArray4D(values);
|
||||
builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128});
|
||||
builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
|
||||
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
|
||||
}
|
||||
|
||||
@ -214,6 +214,7 @@ struct R2Spec {
|
||||
int64 input_dim1;
|
||||
std::array<int64, 2> slice_starts;
|
||||
std::array<int64, 2> slice_limits;
|
||||
std::array<int64, 2> slice_strides;
|
||||
Layout layout;
|
||||
};
|
||||
|
||||
@ -229,7 +230,7 @@ TEST_P(SliceR2Test, DoIt) {
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto a = builder.ConstantR2FromArray2D<int32>(input);
|
||||
builder.Slice(a, spec.slice_starts, spec.slice_limits);
|
||||
builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
|
||||
|
||||
std::unique_ptr<Array2D<int32>> expected =
|
||||
ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
|
||||
@ -240,19 +241,23 @@ TEST_P(SliceR2Test, DoIt) {
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
SliceR2TestInstantiation, SliceR2Test,
|
||||
::testing::Values(
|
||||
R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})},
|
||||
R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})},
|
||||
R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {256, 400, {{0, 300}}, {{256, 400}},
|
||||
R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({0, 1})},
|
||||
R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {500, 400, {{111, 123}}, {{300, 257}},
|
||||
R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({0, 1})},
|
||||
R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {500, 400, {{111, 123}}, {{300, 400}},
|
||||
R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {384, 512, {{128, 256}}, {{256, 384}},
|
||||
R2Spec {500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {357, 512, {{111, 256}}, {{301, 384}},
|
||||
R2Spec {500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})},
|
||||
R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}},
|
||||
LayoutUtil::MakeLayout({1, 0})}
|
||||
)
|
||||
);
|
||||
|
@ -556,7 +556,8 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
|
||||
auto build_condition = [this, v6s32](int count) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto prev = builder.Reshape(
|
||||
builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {});
|
||||
builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
|
||||
{});
|
||||
builder.Gt(builder.ConstantR0<int32>(count), prev);
|
||||
return builder.Build().ConsumeValueOrDie();
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user