Adrian Kuegel 334a3b5ad9 Replace reproducer test case with a better one.
The original test case only is a reproducer when constant folding is enabled.
By default, tests disable constant folding, so instead replace the module with
the one created after the constant folding pass has run.

PiperOrigin-RevId: 335618559
Change-Id: Ib55a197aac53dc9207169df7c57c023c4b14c5bf
2020-10-06 05:23:59 -07:00

895 lines
35 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
using ConcatTest = ClientLibraryTestBase;
using ConcatTestHlo = HloTestBase;
using ::testing::HasSubstr;
// Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest, Concat_Nothing) {
XlaBuilder builder(TestName());
ConcatInDim(&builder, {}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
HasSubstr("Concatenate expects at least one argument"));
}
// Concatenate with one argument works.
XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0, 64.0});
ConcatInDim(&builder, {a}, 0);
std::vector<float> expected = {42, 64};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {});
ConcatInDim(&builder, {a}, 0);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
// Show that we can't concatenate R0 with R0 because we can't name the dimension
// to concatenate on.
XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
XlaBuilder builder(TestName());
auto a = ConstantR0<float>(&builder, 42.0);
auto b = ConstantR0<float>(&builder, 64.0);
ConcatInDim(&builder, {a, b}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
HasSubstr("out of bounds: 0"));
}
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {});
auto b = ConstantR1<float>(&builder, {});
ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {});
auto b = ConstantR1<float>(&builder, {256.0});
ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0, 64.0});
auto b = ConstantR1<float>(&builder, {});
ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {42, 64};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0, 64.0});
auto b = ConstantR1<float>(&builder, {256.0});
ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
std::vector<float> lhs(253);
std::vector<float> rhs(7);
std::vector<float> expected(253 + 7);
for (int i = 0; i < 253; ++i) {
expected[i] = lhs[i] = i + 1;
}
for (int i = 0; i < 7; ++i) {
expected[253 + i] = rhs[i] = 253 + i + 1;
}
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, lhs);
auto b = ConstantR1<float>(&builder, rhs);
ConcatInDim(&builder, {a, b}, 0);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
for (int dim : {0, 1}) {
XlaBuilder builder(TestName());
auto a = ConstantR2FromArray2D(&builder, Array2D<float>(0, 0));
auto b = ConstantR2FromArray2D(&builder, Array2D<float>(0, 0));
ConcatInDim(&builder, {a, b}, dim);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
ErrorSpec(0.0001));
}
}
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(1, 1);
auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
auto a = ConstantR2FromArray2D(&builder, *a_array);
auto b = ConstantR2FromArray2D(&builder, *b_array);
ConcatInDim(&builder, {a, b}, 0);
Array2D<float> expected({
{0},
{64},
});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(1, 1);
auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
auto a = ConstantR2FromArray2D(&builder, *a_array);
auto b = ConstantR2FromArray2D(&builder, *b_array);
ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected({
{0, 64},
});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
XlaBuilder builder(TestName());
auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
auto a = ConstantR2FromArray2D(&builder, Array2D<float>(2, 0));
auto b = ConstantR2FromArray2D(&builder, *b_array);
ConcatInDim(&builder, {a, b}, 1);
ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(2, 3);
auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
auto a = ConstantR2FromArray2D(&builder, *a_array);
auto b = ConstantR2FromArray2D(&builder, *b_array);
ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected({
{0, 1, 2, 64, 65, 66, 67, 68},
{1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(3, 2);
auto a = ConstantR2FromArray2D(&builder, *a_array);
auto b = ConstantR2FromArray2D(&builder, Array2D<float>(0, 2));
ConcatInDim(&builder, {a, b}, 0);
ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(3, 2);
auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
auto a = ConstantR2FromArray2D(&builder, *a_array);
auto b = ConstantR2FromArray2D(&builder, *b_array);
ConcatInDim(&builder, {a, b}, 0);
Array2D<float> expected({
{0, 1},
{1000, 1001},
{2000, 2001},
{64, 65},
{1064, 1065},
{2064, 2065},
{3064, 3065},
{4064, 4065},
});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
XlaBuilder builder(TestName());
auto a = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 2));
auto b = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 1));
ConcatInDim(&builder, {a, b}, 2);
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
XlaBuilder builder(TestName());
Array3D<float> a_array({
// 3x1x2
{{0, 1}},
{{2, 3}},
{{4, 5}},
});
Array3D<float> b_array({
// 3x1x1
{{6}},
{{7}},
{{8}},
});
auto a = ConstantR3FromArray3D(&builder, a_array);
auto b = ConstantR3FromArray3D(&builder, b_array);
ConcatInDim(&builder, {a, b}, 2);
Array3D<float> expected({
{{0, 1, 6}},
{{2, 3, 7}},
{{4, 5, 8}},
});
ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0});
auto b = ConstantR1<float>(&builder, {64.0});
auto c = ConstantR1<float>(&builder, {256.0});
ConcatInDim(&builder, {a, b, c}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
XlaBuilder builder(TestName());
Array3D<float> a_array({
// 3x1x2
{{0, 1}},
{{4, 5}},
{{8, 9}},
});
Array3D<float> b_array({
// 3x1x1
{{2}},
{{6}},
{{10}},
});
Array3D<float> c_array({
// 3x1x1
{{3}},
{{7}},
{{11}},
});
auto a = ConstantR3FromArray3D(&builder, a_array);
auto b = ConstantR3FromArray3D(&builder, b_array);
auto c = ConstantR3FromArray3D(&builder, c_array);
ConcatInDim(&builder, {a, b, c}, 2);
Array3D<float> expected({
{{0, 1, 2, 3}},
{{4, 5, 6, 7}},
{{8, 9, 10, 11}},
});
ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0});
auto b = ConstantR1<float>(&builder, {64.0});
auto c = ConstantR1<float>(&builder, {256.0});
// concatenated = (a concat b) concat c
ConcatInDim(&builder, {ConcatInDim(&builder, {a, b}, 0), c}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0});
auto b = ConstantR1<float>(&builder, {64.0});
auto c = ConstantR1<float>(&builder, {256.0});
// concatenated = a concat (b concat c)
ConcatInDim(&builder, {a, ConcatInDim(&builder, {b, c}, 0)}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
Array2D<float> lhs(1, 1024);
Array2D<float> rhs(1, 1024);
for (int i = 0; i < 1024; ++i) {
lhs(0, i) = i;
rhs(0, i) = i + 1024;
}
XlaBuilder builder(TestName());
auto a = ConstantR2FromArray2D<float>(&builder, lhs);
auto b = ConstantR2FromArray2D<float>(&builder, rhs);
ConcatInDim(&builder, {a, b}, 0);
Array2D<float> expected(2, 1024);
for (int i = 0; i < 1024; ++i) {
expected(0, i) = i;
expected(1, i) = i + 1024;
}
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
Array2D<float> lhs(1, 1024);
Array2D<float> rhs(1, 1024);
for (int i = 0; i < 1024; ++i) {
lhs(0, i) = i;
rhs(0, i) = i + 1024;
}
XlaBuilder builder(TestName());
auto a = ConstantR2FromArray2D<float>(&builder, lhs);
auto b = ConstantR2FromArray2D<float>(&builder, rhs);
ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected(1, 2048);
for (int i = 0; i < 1024; ++i) {
expected(0, i) = i;
expected(0, i + 1024) = i + 1024;
}
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
Array2D<float> lhs(64, 64);
Array2D<float> rhs(64, 2);
for (int i0 = 0; i0 < 64; ++i0) {
for (int i1 = 0; i1 < 64; ++i1) {
lhs(i0, i1) = (i0 << 10) | i1;
}
for (int i1 = 0; i1 < 2; ++i1) {
rhs(i0, i1) = (i0 << 10) | (i1 + 64);
}
}
XlaBuilder builder(TestName());
auto a = ConstantR2FromArray2D<float>(&builder, lhs);
auto b = ConstantR2FromArray2D<float>(&builder, rhs);
ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected(64, 66);
for (int i0 = 0; i0 < 64; ++i0) {
for (int i1 = 0; i1 < 66; ++i1) {
expected(i0, i1) = (i0 << 10) | i1;
}
}
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
// Show that we can't concatenate with an opaques.
XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
XlaBuilder builder(TestName());
auto opaque_shape = ShapeUtil::MakeOpaqueShape();
auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
auto x = Parameter(&builder, 0, r1f32, "x");
auto y = Parameter(&builder, 1, opaque_shape, "y");
ConcatInDim(&builder, {x, y}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(
computation_status.status().ToString(),
HasSubstr("Expected array argument for operand of concatenation"));
}
// Show that we can't concatenate with tokens.
XLA_TEST_F(ConcatTest, CannotConcatTokens) {
XlaBuilder builder(TestName());
auto token_shape = ShapeUtil::MakeTokenShape();
auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
auto x = Parameter(&builder, 0, r1f32, "x");
auto y = Parameter(&builder, 1, token_shape, "y");
ConcatInDim(&builder, {x, y}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(
computation_status.status().ToString(),
HasSubstr("Expected array argument for operand of concatenation"));
}
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
XlaBuilder builder(TestName());
auto p0 = ConstantR1<bool>(&builder, {true});
auto p1 = ConstantR1<bool>(&builder, {false});
auto p2 = ConstantR1<bool>(&builder, {true});
ConcatInDim(&builder, {p0, p1, p2}, 0);
bool expected[] = {true, false, true};
ComputeAndCompareR1<bool>(&builder, expected, {});
}
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
XlaBuilder builder(TestName());
auto a0 = ConstantR1<int32>(&builder, {1});
auto a1 = ConstantR1<int32>(&builder, {2, 3});
auto a2 = ConstantR1<int32>(&builder, {4, 5, 6});
auto a3 = ConstantR1<int32>(&builder, {7, 8, 9, 10});
ConcatInDim(&builder, {a0, a1, a2, a3}, 0);
std::vector<int32> expected(10);
std::iota(expected.begin(), expected.end(), 1);
ComputeAndCompareR1<int32>(&builder, expected, {});
}
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
XlaBuilder builder(TestName());
Array3D<float> arr0(9, 17, 1);
arr0.Fill(1);
Array3D<float> arr1(9, 17, 256);
arr1.Fill(2);
Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
for (int64 i = 0; i < expected.n1(); ++i) {
for (int64 j = 0; j < expected.n2(); ++j) {
int64 kk = 0;
for (const Array3D<float>& arr : {arr0, arr1}) {
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
expected(i, j, kk) = arr(i, j, k);
}
}
}
}
XlaOp h0;
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
&builder, &h0);
XlaOp h1;
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
&builder, &h1);
ConcatInDim(&builder, {h0, h1}, 2);
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
XlaBuilder builder(TestName());
auto a_literal = LiteralUtil::CreateR1<float>({256.0});
auto a = Parameter(&builder, 0, a_literal.shape(), "x");
auto b = ConcatInDim(&builder, {a, a}, 0);
auto c = ConcatInDim(&builder, {b, b}, 0);
auto d = ConcatInDim(&builder, {c, c}, 0);
auto e = ConcatInDim(&builder, {d, d}, 0);
auto f = ConcatInDim(&builder, {e, e}, 0);
auto g = ConcatInDim(&builder, {f, f}, 0);
auto h = ConcatInDim(&builder, {g, g}, 0);
auto i = ConcatInDim(&builder, {h, h}, 0);
auto j = ConcatInDim(&builder, {i, i}, 0);
auto k = ConcatInDim(&builder, {j, j}, 0);
auto l = ConcatInDim(&builder, {k, k}, 0);
auto m = ConcatInDim(&builder, {l, l}, 0);
auto n = ConcatInDim(&builder, {m, m}, 0);
auto o = ConcatInDim(&builder, {n, n}, 0);
auto p = ConcatInDim(&builder, {o, o}, 0);
auto q = ConcatInDim(&builder, {p, p}, 0);
ConcatInDim(&builder, {q, q}, 0);
std::vector<float> expected(131072, 256.0);
auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie();
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
}
// TODO(b/169314478): Enable the test when the slow compilation is fixed.
XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule jit_broken.874
primitive_computation_add.866 {
parameter.867 = f32[] parameter(0)
parameter.868 = f32[] parameter(1)
ROOT add.869 = f32[] add(parameter.867, parameter.868)
}
ENTRY jit_broken.874 {
parameter.38 = f32[4,2]{1,0} parameter(0)
reshape.723 = f32[4,2,1]{2,1,0} reshape(parameter.38)
reshape.724 = f32[4,2,1]{2,1,0} reshape(parameter.38)
concatenate.42 = f32[4,2,2]{2,1,0} concatenate(reshape.723, reshape.724), dimensions={2}
slice.351 = f32[4,1,2]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:2]}
reshape.1058 = f32[4,2]{1,0} reshape(slice.351)
slice.352 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [1:2]}
reshape.1059 = f32[4]{0} reshape(slice.352)
slice.353 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1060 = f32[4]{0} reshape(slice.353)
add.124 = f32[4]{0} add(reshape.1059, reshape.1060)
slice.354 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [0:1]}
reshape.1061 = f32[4]{0} reshape(slice.354)
slice.379 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1062 = f32[4]{0} reshape(slice.379)
add.89 = f32[4]{0} add(reshape.1061, reshape.1062)
subtract.126 = f32[4]{0} subtract(add.124, add.89)
is-finite.127 = pred[4]{0} is-finite(subtract.126)
not.128 = pred[4]{0} not(is-finite.127)
abs.129 = f32[4]{0} abs(subtract.126)
constant.130 = f32[] constant(inf)
broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={}
compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED
not.133 = pred[4]{0} not(compare.132)
and.134 = pred[4]{0} and(not.128, not.133)
add.135 = f32[4]{0} add(add.124, add.89)
maximum.125 = f32[4]{0} maximum(add.124, add.89)
abs.136 = f32[4]{0} abs(subtract.126)
negate.137 = f32[4]{0} negate(abs.136)
exponential.138 = f32[4]{0} exponential(negate.137)
log-plus-one.139 = f32[4]{0} log-plus-one(exponential.138)
add.140 = f32[4]{0} add(maximum.125, log-plus-one.139)
select.141 = f32[4]{0} select(and.134, add.135, add.140)
slice.356 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1064 = f32[4]{0} reshape(slice.356)
add.214 = f32[4]{0} add(select.141, reshape.1064)
slice.380 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1066 = f32[4]{0} reshape(slice.380)
add.179 = f32[4]{0} add(select.141, reshape.1066)
subtract.216 = f32[4]{0} subtract(add.214, add.179)
is-finite.217 = pred[4]{0} is-finite(subtract.216)
not.218 = pred[4]{0} not(is-finite.217)
abs.219 = f32[4]{0} abs(subtract.216)
constant.220 = f32[] constant(inf)
broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={}
compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED
not.223 = pred[4]{0} not(compare.222)
and.224 = pred[4]{0} and(not.218, not.223)
add.225 = f32[4]{0} add(add.214, add.179)
maximum.215 = f32[4]{0} maximum(add.214, add.179)
abs.226 = f32[4]{0} abs(subtract.216)
negate.227 = f32[4]{0} negate(abs.226)
exponential.228 = f32[4]{0} exponential(negate.227)
log-plus-one.229 = f32[4]{0} log-plus-one(exponential.228)
add.230 = f32[4]{0} add(maximum.215, log-plus-one.229)
select.231 = f32[4]{0} select(and.224, add.225, add.230)
slice.359 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1068 = f32[4]{0} reshape(slice.359)
add.304 = f32[4]{0} add(select.231, reshape.1068)
slice.381 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1070 = f32[4]{0} reshape(slice.381)
add.269 = f32[4]{0} add(select.231, reshape.1070)
subtract.306 = f32[4]{0} subtract(add.304, add.269)
is-finite.307 = pred[4]{0} is-finite(subtract.306)
not.308 = pred[4]{0} not(is-finite.307)
abs.309 = f32[4]{0} abs(subtract.306)
constant.310 = f32[] constant(inf)
broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={}
compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED
not.313 = pred[4]{0} not(compare.312)
and.314 = pred[4]{0} and(not.308, not.313)
add.315 = f32[4]{0} add(add.304, add.269)
maximum.305 = f32[4]{0} maximum(add.304, add.269)
abs.316 = f32[4]{0} abs(subtract.306)
negate.317 = f32[4]{0} negate(abs.316)
exponential.318 = f32[4]{0} exponential(negate.317)
log-plus-one.319 = f32[4]{0} log-plus-one(exponential.318)
add.320 = f32[4]{0} add(maximum.305, log-plus-one.319)
select.321 = f32[4]{0} select(and.314, add.315, add.320)
slice.362 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1072 = f32[4]{0} reshape(slice.362)
add.394 = f32[4]{0} add(select.321, reshape.1072)
slice.382 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1074 = f32[4]{0} reshape(slice.382)
add.359 = f32[4]{0} add(select.321, reshape.1074)
subtract.396 = f32[4]{0} subtract(add.394, add.359)
is-finite.397 = pred[4]{0} is-finite(subtract.396)
not.398 = pred[4]{0} not(is-finite.397)
abs.399 = f32[4]{0} abs(subtract.396)
constant.400 = f32[] constant(inf)
broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={}
compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED
not.403 = pred[4]{0} not(compare.402)
and.404 = pred[4]{0} and(not.398, not.403)
add.405 = f32[4]{0} add(add.394, add.359)
maximum.395 = f32[4]{0} maximum(add.394, add.359)
abs.406 = f32[4]{0} abs(subtract.396)
negate.407 = f32[4]{0} negate(abs.406)
exponential.408 = f32[4]{0} exponential(negate.407)
log-plus-one.409 = f32[4]{0} log-plus-one(exponential.408)
add.410 = f32[4]{0} add(maximum.395, log-plus-one.409)
select.411 = f32[4]{0} select(and.404, add.405, add.410)
slice.365 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1076 = f32[4]{0} reshape(slice.365)
add.484 = f32[4]{0} add(select.411, reshape.1076)
slice.383 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1078 = f32[4]{0} reshape(slice.383)
add.449 = f32[4]{0} add(select.411, reshape.1078)
subtract.486 = f32[4]{0} subtract(add.484, add.449)
is-finite.487 = pred[4]{0} is-finite(subtract.486)
not.488 = pred[4]{0} not(is-finite.487)
abs.489 = f32[4]{0} abs(subtract.486)
constant.490 = f32[] constant(inf)
broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={}
compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED
not.493 = pred[4]{0} not(compare.492)
and.494 = pred[4]{0} and(not.488, not.493)
add.495 = f32[4]{0} add(add.484, add.449)
maximum.485 = f32[4]{0} maximum(add.484, add.449)
abs.496 = f32[4]{0} abs(subtract.486)
negate.497 = f32[4]{0} negate(abs.496)
exponential.498 = f32[4]{0} exponential(negate.497)
log-plus-one.499 = f32[4]{0} log-plus-one(exponential.498)
add.500 = f32[4]{0} add(maximum.485, log-plus-one.499)
select.501 = f32[4]{0} select(and.494, add.495, add.500)
slice.368 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1080 = f32[4]{0} reshape(slice.368)
add.574 = f32[4]{0} add(select.501, reshape.1080)
slice.384 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1082 = f32[4]{0} reshape(slice.384)
add.539 = f32[4]{0} add(select.501, reshape.1082)
subtract.576 = f32[4]{0} subtract(add.574, add.539)
is-finite.577 = pred[4]{0} is-finite(subtract.576)
not.578 = pred[4]{0} not(is-finite.577)
abs.579 = f32[4]{0} abs(subtract.576)
constant.580 = f32[] constant(inf)
broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={}
compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED
not.583 = pred[4]{0} not(compare.582)
and.584 = pred[4]{0} and(not.578, not.583)
add.585 = f32[4]{0} add(add.574, add.539)
maximum.575 = f32[4]{0} maximum(add.574, add.539)
abs.586 = f32[4]{0} abs(subtract.576)
negate.587 = f32[4]{0} negate(abs.586)
exponential.588 = f32[4]{0} exponential(negate.587)
log-plus-one.589 = f32[4]{0} log-plus-one(exponential.588)
add.590 = f32[4]{0} add(maximum.575, log-plus-one.589)
select.591 = f32[4]{0} select(and.584, add.585, add.590)
slice.371 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1084 = f32[4]{0} reshape(slice.371)
add.664 = f32[4]{0} add(select.591, reshape.1084)
slice.385 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1086 = f32[4]{0} reshape(slice.385)
add.629 = f32[4]{0} add(select.591, reshape.1086)
subtract.666 = f32[4]{0} subtract(add.664, add.629)
is-finite.667 = pred[4]{0} is-finite(subtract.666)
not.668 = pred[4]{0} not(is-finite.667)
abs.669 = f32[4]{0} abs(subtract.666)
constant.670 = f32[] constant(inf)
broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={}
compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED
not.673 = pred[4]{0} not(compare.672)
and.674 = pred[4]{0} and(not.668, not.673)
add.675 = f32[4]{0} add(add.664, add.629)
maximum.665 = f32[4]{0} maximum(add.664, add.629)
abs.676 = f32[4]{0} abs(subtract.666)
negate.677 = f32[4]{0} negate(abs.676)
exponential.678 = f32[4]{0} exponential(negate.677)
log-plus-one.679 = f32[4]{0} log-plus-one(exponential.678)
add.680 = f32[4]{0} add(maximum.665, log-plus-one.679)
select.681 = f32[4]{0} select(and.674, add.675, add.680)
slice.374 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1088 = f32[4]{0} reshape(slice.374)
add.754 = f32[4]{0} add(select.681, reshape.1088)
slice.386 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1090 = f32[4]{0} reshape(slice.386)
add.719 = f32[4]{0} add(select.681, reshape.1090)
subtract.756 = f32[4]{0} subtract(add.754, add.719)
is-finite.757 = pred[4]{0} is-finite(subtract.756)
not.758 = pred[4]{0} not(is-finite.757)
abs.759 = f32[4]{0} abs(subtract.756)
constant.760 = f32[] constant(inf)
broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={}
compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED
not.763 = pred[4]{0} not(compare.762)
and.764 = pred[4]{0} and(not.758, not.763)
add.765 = f32[4]{0} add(add.754, add.719)
maximum.755 = f32[4]{0} maximum(add.754, add.719)
abs.766 = f32[4]{0} abs(subtract.756)
negate.767 = f32[4]{0} negate(abs.766)
exponential.768 = f32[4]{0} exponential(negate.767)
log-plus-one.769 = f32[4]{0} log-plus-one(exponential.768)
add.770 = f32[4]{0} add(maximum.755, log-plus-one.769)
select.771 = f32[4]{0} select(and.764, add.765, add.770)
slice.377 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]}
reshape.1092 = f32[4]{0} reshape(slice.377)
add.844 = f32[4]{0} add(select.771, reshape.1092)
slice.387 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]}
reshape.1094 = f32[4]{0} reshape(slice.387)
add.809 = f32[4]{0} add(select.771, reshape.1094)
subtract.846 = f32[4]{0} subtract(add.844, add.809)
is-finite.847 = pred[4]{0} is-finite(subtract.846)
not.848 = pred[4]{0} not(is-finite.847)
abs.849 = f32[4]{0} abs(subtract.846)
constant.850 = f32[] constant(inf)
broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={}
compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED
not.853 = pred[4]{0} not(compare.852)
and.854 = pred[4]{0} and(not.848, not.853)
add.855 = f32[4]{0} add(add.844, add.809)
maximum.845 = f32[4]{0} maximum(add.844, add.809)
abs.856 = f32[4]{0} abs(subtract.846)
negate.857 = f32[4]{0} negate(abs.856)
exponential.858 = f32[4]{0} exponential(negate.857)
log-plus-one.859 = f32[4]{0} log-plus-one(exponential.858)
add.860 = f32[4]{0} add(maximum.845, log-plus-one.859)
select.861 = f32[4]{0} select(and.854, add.855, add.860)
constant.865 = f32[] constant(0)
reduce.2 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866
reduce.3 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866
add.77 = f32[] add(reduce.2, reduce.3)
constant.719 = f32[] constant(0.125)
multiply = f32[] multiply(add.77, constant.719)
ROOT tuple.873 = (f32[]) tuple(multiply)
})")
.ConsumeValueOrDie();
auto input_array = absl::make_unique<Array2D<float>>(4, 2);
input_array->FillUnique(1.0f);
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt));
}
// Describes a binary rank-2 concatenation test.
struct R2BinarySpec {
int64 lhs_dim0;
int64 lhs_dim1;
int64 rhs_dim0;
int64 rhs_dim1;
int64 concat_dimension;
};
// TEST_P harness for binary rank-2 concatenation.
class ConcatR2BinaryTest : public ClientLibraryTestBase,
public ::testing::WithParamInterface<R2BinarySpec> {
};
TEST_P(ConcatR2BinaryTest, DoIt) {
const R2BinarySpec& spec = GetParam();
Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
lhs.FillUnique();
Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
rhs.FillUnique(1000);
XlaBuilder builder(TestName());
auto a0 = ConstantR2FromArray2D<int32>(&builder, lhs);
auto a1 = ConstantR2FromArray2D<int32>(&builder, rhs);
ConcatInDim(&builder, {a0, a1}, spec.concat_dimension);
std::unique_ptr<Array2D<int32>> expected =
ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
ComputeAndCompareR2<int32>(&builder, *expected, {});
}
// Regression test for b/31944287. x*y is used (at the same index) by all
// operands of the concat. We should emit x*y in three incoming basic blocks of
// the concat because these basic blocks are not control-equivalent.
//
// x*y
// / | \
// add1 add2 add3
// \ | /
// concat
XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR0<float>(2.f);
auto y_literal = LiteralUtil::CreateR0<float>(3.f);
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, f32_scalar, "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto mul = Mul(x, y);
auto add1 = Add(mul, ConstantR1<float>(&builder, {1.f, 2.f}));
auto add2 = Add(mul, ConstantR1<float>(&builder, {3.f, 4.f}));
auto add3 = Add(mul, ConstantR1<float>(&builder, {5.f, 6.f}));
ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);
ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
{x_data.get(), y_data.get()}, ErrorSpec(1e-4));
}
// Test that the HLO optimization to replace a concat of a broadcasted scalar
// produces the correct result in rank 1.
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "z");
auto bcast = Broadcast(y, {5});
auto bcast2 = Broadcast(z, {3});
auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0);
ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0);
ComputeAndCompareR1<float>(
&builder,
{1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 2.0f, 3.0f, 5.0f, 6.0f, 5.5f, 5.5f, 5.5f},
{x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4));
}
// Test that the HLO optimization to replace a concat of a broadcasted scalar
// produces the correct result in rank 3 with both high and low padding in
// different dimensions.
XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
Array3D<float> x3d(3, 5, 7, 3.14f);
auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, x_literal.shape(), "x");
auto y = Parameter(&builder, 1, f32_scalar, "y");
auto z = Parameter(&builder, 2, f32_scalar, "y");
auto y_bcast = Broadcast(y, {1, 5, 7});
auto z_bcast = Broadcast(z, {4, 1, 7});
auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0);
ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1);
Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
auto concat1 = ReferenceUtil::Concat3D(*concat0, z_bcast3d, 1);
ComputeAndCompareR3<float>(&builder, *concat1,
{x_data.get(), y_data.get(), z_data.get()},
ErrorSpec(1e-4));
}
INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
R2BinarySpec{1, 1, 1, 1, 1},
R2BinarySpec{4, 3, 4, 3, 0},
R2BinarySpec{4, 3, 4, 3, 1},
R2BinarySpec{7, 128, 1, 128, 0},
R2BinarySpec{8, 127, 8, 1, 1}));
} // namespace
} // namespace xla