STT-tensorflow/tensorflow/compiler/xla/tests/tuple_test.cc
Tim Shen 7ac4c1f85e [XLA/GPU] Remove TupleSelect implementation.
It is not used by major XLA/GPU users, and it adds a lot of implementation burden.

PiperOrigin-RevId: 341926708
Change-Id: I8291f11969b15f8439d2f390bb4e840e5cd70c80
2020-11-11 16:00:41 -08:00

614 lines
25 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <initializer_list>
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
// Client-library fixture for the tuple tests in this file. It only adds a
// shared floating-point tolerance on top of ClientLibraryTestBase.
class TupleTest : public ClientLibraryTestBase {
 public:
  // Tolerance handed to most of the ComputeAndCompare* calls below.
  ErrorSpec error_spec_ = ErrorSpec(0.0001);
};
// Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest, TupleConstant) {
  XlaBuilder builder(TestName());

  // Three elements of increasing rank: a scalar, a 3-vector and a 2x3 matrix.
  const float kScalar = 7.3f;
  const std::initializer_list<float> kVector = {1.1f, 2.0f, 3.3f};
  const std::initializer_list<std::initializer_list<float>> kMatrix = {
      {1.1f, 2.2f, 3.5f},  // row 0
      {4.8f, 5.0f, 6.7f},  // row 1
  };

  const auto tuple_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(kScalar),
       LiteralUtil::CreateR1<float>(kVector),
       LiteralUtil::CreateR2<float>(kMatrix)});

  // The computation is just the tuple constant, so it must evaluate to the
  // very literal it was built from.
  ConstantLiteral(&builder, tuple_literal);
  ComputeAndCompareTuple(&builder, tuple_literal, /*arguments=*/{},
                         error_spec_);
}
// Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest, TupleScalarConstant) {
  XlaBuilder builder(TestName());

  // A two-element tuple whose members are both f32 scalars.
  const float first = 7.3f;
  const float second = 1.2f;
  const auto tuple_literal =
      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(first),
                                        LiteralUtil::CreateR0<float>(second)});

  // Returning the constant unchanged must reproduce it exactly.
  ConstantLiteral(&builder, tuple_literal);
  ComputeAndCompareTuple(&builder, tuple_literal, /*arguments=*/{},
                         error_spec_);
}
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreate) {
  XlaBuilder builder(TestName());
  const float scalar_value = 7.3f;
  const std::initializer_list<float> vector_values = {1.1f, 2.0f, 3.3f};
  const std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.1f, 2.2f, 3.5f},  // row 0
      {4.8f, 5.0f, 6.7f},  // row 1
  };

  // Assemble the tuple inside the computation from three separate constants;
  // the Tuple op is the last op added and therefore the root.
  const XlaOp scalar_op = ConstantR0<float>(&builder, scalar_value);
  const XlaOp vector_op = ConstantR1<float>(&builder, vector_values);
  const XlaOp matrix_op = ConstantR2<float>(&builder, matrix_values);
  Tuple(&builder, {scalar_op, vector_op, matrix_op});

  // Host-side literal with the same three elements, in the same order.
  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(scalar_value),
       LiteralUtil::CreateR1<float>(vector_values),
       LiteralUtil::CreateR2<float>(matrix_values)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of a tuple whose second element is a zero-element
// (empty) array.
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
  XlaBuilder builder(TestName());

  // Pair a scalar with an empty f32 vector.
  const XlaOp scalar = ConstantR0<float>(&builder, 7.0);
  const XlaOp empty_vector = ConstantR1<float>(&builder, {});
  Tuple(&builder, {scalar, empty_vector});

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
  XlaBuilder builder(TestName());
  // A tuple with no operands is still a valid value: the empty tuple.
  Tuple(&builder, /*elements=*/{});
  ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple({}), {},
                         error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElement) {
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  const std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  const XlaOp pair =
      Tuple(&builder, {ConstantR1<float>(&builder, vector_values),
                       ConstantR2<float>(&builder, matrix_values)});
  // Element 1 is the matrix; extracting it makes it the computation's result.
  GetTupleElement(pair, /*index=*/1);

  const Array2D<float> expected(matrix_values);
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// Tests extracting, with GetTupleElement, a tuple element that contains zero
// values (a 0x101 array).
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
  XlaBuilder builder(TestName());

  // Both tuple members hold no values: an empty vector and a 0x101 matrix.
  const Array2D<float> empty_matrix(0, 101);
  const XlaOp operand =
      Tuple(&builder, {ConstantR1<float>(&builder, {}),
                       ConstantR2FromArray2D<float>(&builder, empty_matrix)});
  GetTupleElement(operand, 1);

  ComputeAndCompareR2<float>(&builder, empty_matrix, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
  XlaBuilder builder(TestName());

  // Asking for element 1 of a rank-1 array is invalid. The test expects the
  // problem to be reported as an error status from Build().
  GetTupleElement(ConstantR1<float>(&builder, {4.5f}), 1);

  const auto built = builder.Build();
  EXPECT_FALSE(built.ok());
  EXPECT_THAT(
      built.status().error_message(),
      ::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
}
// Extracts both elements from a tuple with GetTupleElement and then adds them
// together.
XLA_TEST_F(TupleTest, AddTupleElements) {
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  const std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  const XlaOp packed =
      Tuple(&builder, {ConstantR1<float>(&builder, vector_values),
                       ConstantR2<float>(&builder, matrix_values)});
  const XlaOp vec = GetTupleElement(packed, 0);
  const XlaOp mat = GetTupleElement(packed, 1);

  // The builder must infer the original element shapes for each extraction.
  const Shape vec_shape = builder.GetShape(vec).ConsumeValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(vec_shape, ShapeUtil::MakeShape(F32, {3})));
  const Shape mat_shape = builder.GetShape(mat).ConsumeValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(mat_shape,
                               ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3})));

  // Broadcast the vector across each matrix row and add.
  Add(mat, vec, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({
      {2.f, 4.f, 6.f},  // row 0
      {5.f, 7.f, 9.f},  // row 1
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// Extracts both elements from a tuple and then puts them into a new tuple in
// the opposite order.
XLA_TEST_F(TupleTest, TupleGTEToTuple) {
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vector_values = {1.f, 2.f, 3.f};
  const std::initializer_list<std::initializer_list<float>> matrix_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  const XlaOp source =
      Tuple(&builder, {ConstantR1<float>(&builder, vector_values),
                       ConstantR2<float>(&builder, matrix_values)});
  const XlaOp first = GetTupleElement(source, 0);
  const XlaOp second = GetTupleElement(source, 1);
  // Re-pack the two elements with their order reversed.
  Tuple(&builder, {second, first});

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>(matrix_values),
       LiteralUtil::CreateR1<float>(vector_values)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Selects between two tuples of predicates, once for each value of
// `direction`, and checks that the chosen tuple is {direction, !direction}.
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) {
  for (bool direction : {false, true}) {
    SCOPED_TRACE(direction ? "direction=true" : "direction=false");

    // Use a fresh builder per iteration so each computation is self-contained
    // and parameters 0 and 1 are registered exactly once on it, instead of
    // reusing one builder (and its parameter numbering) across iterations.
    XlaBuilder b(TestName());
    XlaOp v1, v2;
    std::unique_ptr<GlobalData> v1_data =
        CreateR0Parameter<float>(0.0f, /*parameter_number=*/0, /*name=*/"v1",
                                 /*builder=*/&b, /*data_handle=*/&v1);
    std::unique_ptr<GlobalData> v2_data =
        CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2",
                                 /*builder=*/&b, /*data_handle=*/&v2);

    auto v1_gt = Gt(v1, v2);                 // 0 > 1: false
    auto v2_gt = Gt(v2, v1);                 // 1 > 0: true
    auto v1_v2 = Tuple(&b, {v1_gt, v2_gt});  // {false, true}
    auto v2_v1 = Tuple(&b, {v2_gt, v1_gt});  // {true, false}

    // direction == false: predicate is v2_gt (true)  -> {false, true}.
    // direction == true:  predicate is v1_gt (false) -> {true, false}.
    Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);

    auto expected = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<bool>(direction),
         LiteralUtil::CreateR0<bool>(!direction)});
    ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
                           error_spec_);
  }
}
// Builds two new tuples from an existing tuple (by means of GetTupleElement),
// then adds up the components of the new tuples.
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
  // Data flow, with v the vector and m the matrix:
  //   source   = (v, m)
  //   in_order = (source[0], source[1]), i.e. (v, m)
  //   swapped  = (source[1], source[0]), i.e. (m, v)
  //   result   = (in_order[1] + swapped[0])
  //              + broadcast(in_order[0] + swapped[1]) along dimension 1,
  // which is 2*m plus 2*v added to every row.
  XlaBuilder builder(TestName());
  const std::initializer_list<float> v_values = {1.f, 2.f, 3.f};
  const std::initializer_list<std::initializer_list<float>> m_values = {
      {1.f, 2.f, 3.f},  // row 0
      {4.f, 5.f, 6.f},  // row 1
  };

  const XlaOp source = Tuple(&builder, {ConstantR1<float>(&builder, v_values),
                                        ConstantR2<float>(&builder, m_values)});
  const XlaOp in_order = Tuple(
      &builder, {GetTupleElement(source, 0), GetTupleElement(source, 1)});
  const XlaOp swapped = Tuple(
      &builder, {GetTupleElement(source, 1), GetTupleElement(source, 0)});

  const XlaOp vector_sum =
      Add(GetTupleElement(in_order, 0), GetTupleElement(swapped, 1));
  const XlaOp matrix_sum =
      Add(GetTupleElement(in_order, 1), GetTupleElement(swapped, 0));
  Add(matrix_sum, vector_sum, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({
      {4.f, 8.f, 12.f},    // row 0
      {10.f, 14.f, 18.f},  // row 1
  });
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesOnFalse)) {
  // A false predicate must pick the second tuple operand, (vec2, vec1).
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  const std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  const XlaOp on_true = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                         ConstantR1<float>(&builder, vec2)});
  const XlaOp on_false = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                          ConstantR1<float>(&builder, vec1)});
  const XlaOp predicate = ConstantR0<bool>(&builder, false);
  Select(predicate, on_true, on_false);

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(TuplesInAMap)) {
  // Scalar computation f(x) = 100 * min(x, x^2) + max(x, x^2). The pair
  // (x, x^2) is ordered through a tuple-valued select, which keeps HLO-level
  // optimizations from optimizing the tuples away.
  XlaComputation tuple_computation;
  {
    XlaBuilder sub_builder("sort_square");
    const XlaOp x =
        Parameter(&sub_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
    const XlaOp x_squared = Mul(x, x);
    const XlaOp ascending = Tuple(&sub_builder, {x, x_squared});
    const XlaOp descending = Tuple(&sub_builder, {x_squared, x});
    const XlaOp ordered = Select(Lt(x, x_squared), ascending, descending);
    const XlaOp low = GetTupleElement(ordered, 0);
    const XlaOp high = GetTupleElement(ordered, 1);
    Add(high, Mul(ConstantR0<float>(&sub_builder, 100.0f), low));

    auto built = sub_builder.Build();
    ASSERT_IS_OK(built.status());
    tuple_computation = built.ConsumeValueOrDie();
  }

  XlaBuilder builder(TestName());
  const XlaOp inputs = ConstantR1<float>(&builder, {-1.0f, 1.0f, 2.1f});
  Map(&builder, {inputs}, tuple_computation, /*dimensions=*/{0});
  // f(-1) = -99, f(1) = 101, f(2.1) = 214.41.
  ComputeAndCompareR1<float>(&builder, {-99.0f, 101.0f, 214.41f}, {},
                             error_spec_);
}
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesOnTrue)) {
  // A true predicate must pick the first tuple operand, (vec1, vec2).
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  const std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  const XlaOp on_true = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                         ConstantR1<float>(&builder, vec2)});
  const XlaOp on_false = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                          ConstantR1<float>(&builder, vec1)});
  const XlaOp predicate = ConstantR0<bool>(&builder, true);
  Select(predicate, on_true, on_false);

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesElementResult)) {
  // Selects between two tuples, but the computation's result is a single
  // element of the selected tuple rather than the whole tuple.
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  const std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  const XlaOp tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                         ConstantR1<float>(&builder, vec2)});
  const XlaOp tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                         ConstantR1<float>(&builder, vec1)});
  const XlaOp chosen =
      Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
  // A false predicate selects tuple21, whose element 0 is vec2.
  GetTupleElement(chosen, 0);

  ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
}
// Cascaded selects between tuple types.
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesCascaded)) {
  // Two chained selects whose predicates come from a (true, false) tuple:
  //   select1 = pred[0] ? (vec1, vec2) : (vec2, vec1)   -> (vec1, vec2)
  //   select2 = pred[1] ? (vec2, vec1) : select1        -> (vec1, vec2)
  // The result is select2[0] + select2[1] == vec1 + vec2.
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  const std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  const XlaOp predicates =
      Tuple(&builder, {ConstantR0<bool>(&builder, true),
                       ConstantR0<bool>(&builder, false)});
  const XlaOp tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
                                         ConstantR1<float>(&builder, vec2)});
  const XlaOp tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
                                         ConstantR1<float>(&builder, vec1)});

  const XlaOp first_pred = GetTupleElement(predicates, 0);
  const XlaOp select1 = Select(first_pred, tuple12, tuple21);
  const XlaOp second_pred = GetTupleElement(predicates, 1);
  const XlaOp select2 = Select(second_pred, tuple21, select1);
  Add(GetTupleElement(select2, 0), GetTupleElement(select2, 1));

  ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesReuseConstants)) {
  // Like SelectBetweenTuplesOnFalse, except both tuples are built from the
  // same two constant ops instead of separate copies.
  XlaBuilder builder(TestName());
  const std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
  const std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};

  const XlaOp shared1 = ConstantR1<float>(&builder, vec1);
  const XlaOp shared2 = ConstantR1<float>(&builder, vec2);
  const XlaOp forward = Tuple(&builder, {shared1, shared2});
  const XlaOp backward = Tuple(&builder, {shared2, shared1});
  Select(ConstantR0<bool>(&builder, false), forward, backward);

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
  // Builds ((f32[2], f32[]), f32[2]) and checks it comes back unchanged.
  XlaBuilder builder(TestName());
  const XlaOp inner = Tuple(&builder, {ConstantR1<float>(&builder, {1.0, 2.0}),
                                       ConstantR0<float>(&builder, 42.0)});
  const XlaOp trailing = ConstantR1<float>(&builder, {22.0, 44.0});
  Tuple(&builder, {inner, trailing});

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<float>({1.0, 2.0}),
            LiteralUtil::CreateR0<float>(42.0)}),
       LiteralUtil::CreateR1<float>({22.0, 44.0})});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
  XlaBuilder builder(TestName());

  // Parameter shape: ((f32[3], f32[3]), f32[3]).
  const Shape vec3 = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape = ShapeUtil::MakeTupleShape({vec3, vec3});
  const Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, vec3});
  const XlaOp param = Parameter(&builder, 0, outer_shape, "input");

  // Result: param[0][1] + {10, 11, 12}.
  const XlaOp nested_element = GetTupleElement(GetTupleElement(param, 0), 1);
  Add(nested_element, ConstantR1<float>(&builder, {10.0, 11.0, 12.0}));

  const auto argument_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::MakeTupleFromSlices({
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
          LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
      }),
      LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
  });
  std::unique_ptr<GlobalData> argument =
      client_->TransferToServer(argument_literal).ConsumeValueOrDie();

  // param[0][1] is {4, 5, 6}.
  const std::vector<float> expected = {4.0f + 10.0f, 5.0f + 11.0f,
                                       6.0f + 12.0f};
  ComputeAndCompareR1<float>(&builder, expected, {argument.get()},
                             ErrorSpec(1e-5));
}
XLA_TEST_F(TupleTest, ComplexTuples) {
  // Computation over complex64 values:
  //   input0 = (s, (r1, r2)) with s: c64[], r1: c64[2], r2: c64[3,2]
  //   input1 = c64[2]
  //   sum    = r2 + r1 (broadcast along dim 1) + s      -> c64[3,2]
  //   prod   = sum * input1 (broadcast along dim 1)     -> c64[3,2]
  //   result = ((prod, sum), 123+456i)
  XlaBuilder builder(TestName());
  {
    const Shape scalar_c64 = ShapeUtil::MakeShape(C64, {});
    const Shape vector_c64 = ShapeUtil::MakeShape(C64, {2});
    const Shape matrix_c64 = ShapeUtil::MakeShape(C64, {3, 2});
    const Shape input0_shape = ShapeUtil::MakeTupleShape(
        {scalar_c64, ShapeUtil::MakeTupleShape({vector_c64, matrix_c64})});

    const XlaOp input0 = Parameter(&builder, 0, input0_shape, "input0");
    const XlaOp scalar = GetTupleElement(input0, 0);
    const XlaOp pair = GetTupleElement(input0, 1);
    const XlaOp vec = GetTupleElement(pair, 0);
    const XlaOp mat = GetTupleElement(pair, 1);
    const XlaOp sum = Add(Add(vec, mat, {1}), scalar);

    const XlaOp input1 = Parameter(&builder, 1, vector_c64, "input1");
    const XlaOp prod = Mul(input1, sum, {1});
    Tuple(&builder, {Tuple(&builder, {prod, sum}),
                     ConstantR0<complex64>(&builder, {123, 456})});
  }

  const auto input0_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<complex64>({1, 2}),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
            LiteralUtil::CreateR2<complex64>(
                {{{100, 200}, {300, 400}},
                 {{1000, 2000}, {3000, 4000}},
                 {{10000, 20000}, {30000, 40000}}})})});
  std::unique_ptr<GlobalData> input0_data =
      client_->TransferToServer(input0_literal).ConsumeValueOrDie();

  const auto input1_literal =
      LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}});
  std::unique_ptr<GlobalData> input1_data =
      client_->TransferToServer(input1_literal).ConsumeValueOrDie();

  const auto expected_sum =
      LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                        {{1011, 2022}, {3031, 4042}},
                                        {{10011, 20022}, {30031, 40042}}});

  // prod[i][j] = sum[i][j] * input1[j], where input1 = {1+2i, 1-2i}.
  Literal expected_prod(expected_sum.shape());
  const auto scale_by_column = [&expected_sum](absl::Span<const int64> idx) {
    const complex64 factor =
        idx.back() == 0 ? complex64(1, 2) : complex64(1, -2);
    return expected_sum.Get<complex64>(idx) * factor;
  };
  ASSERT_TRUE(expected_prod.Populate<complex64>(scale_by_column).ok());

  const auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices({expected_prod, expected_sum}),
       LiteralUtil::CreateR0<complex64>({123, 456})});
  ComputeAndCompareTuple(&builder, expected,
                         {input0_data.get(), input1_data.get()}, error_spec_);
}
// Fixture for the tests below that parse HLO text and run it through
// HloTestBase; it adds nothing beyond the base class.
class TupleHloTest : public HloTestBase {
};
XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
  // Bitcasting a get-tuple-element result (f32[3] -> f32[1,3]) must carry the
  // element's values through unchanged. The module is executed via
  // ExecuteNoHloPasses, so it runs as written.
  const char* const kHloText = R"(
HloModule m, is_scheduled=true
ENTRY test {
name.1 = (f32[3]{0}) parameter(0)
get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0
bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1)
copy = f32[1,3]{1,0} copy(bitcast)
ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
}
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  auto input =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
  const auto actual = ExecuteNoHloPasses(std::move(hlo_module), {&input});

  const auto expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
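// Disabled on the interpreter because it lacks outfeed support, and on GPU
// because the module uses tuple-select, which the GPU backend does not
// implement.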
XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
                             NonAmbiguousTopLevelAllocation))) {
  // The module tuple-selects between (a, b) and (c, d) on a predicate
  // parameter, takes element 0 of the selected tuple and sends it out via
  // outfeed. With cond == false, the outfeed must carry c.
  const char* const kHloText = R"(
HloModule tuple
ENTRY main {
a = f32[2] parameter(0)
b = f32[2] parameter(1)
c = f32[2] parameter(2)
d = f32[2] parameter(3)
cond = pred[] parameter(4)
tup0 = (f32[2],f32[2]) tuple(a, b)
tup1 = (f32[2],f32[2]) tuple(c, d)
s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1)
gte = f32[2] get-tuple-element(s), index=0
tuple = (f32[2]) tuple(gte)
token0 = token[] after-all()
ROOT outfeed = token[] outfeed(tuple, token0)
}
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // (a, b, c, d, cond) = (first, second, second, first, false).
  auto first = LiteralUtil::CreateR1<float>({1, 2});
  auto second = LiteralUtil::CreateR1<float>({2, 3});
  auto predicate = LiteralUtil::CreateR0<bool>(false);

  // Run the module on its own thread so this thread can block reading the
  // outfeed. The thread object is declared after the literals it captures by
  // reference, so it is destroyed (and joined) before they are.
  std::unique_ptr<tensorflow::Thread> executor(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", [&] {
            TF_EXPECT_OK(
                Execute(std::move(hlo_module),
                        {&first, &second, &second, &first, &predicate})
                    .status());
          }));

  const auto expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
  auto received = Literal::CreateFromShape(expected.shape());
  TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
      backend().default_stream_executor(), expected.shape(), &received));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, received));
}
XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(TupleSelectOfSort)) {
  // A tuple-select with a constant true predicate must return the sort
  // result rather than the parameter. Keys and values are both iota {0, 1},
  // so the sorted tuple is ({0, 1}, {0, 1}).
  const char* const kHloText = R"(
HloModule sort
compare {
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
keys = f32[2]{0} iota(), iota_dimension=0
values = s32[2]{0} iota(), iota_dimension=0
preds = pred[] constant(true)
alt = (f32[2], s32[2]) parameter(0)
sorted = (f32[2]{0}, s32[2]{0}) sort(keys, values), dimensions={0},
to_apply=compare
ROOT selected = (f32[2], s32[2]) tuple-select(preds, sorted, alt)
}
)";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // The unselected alternative; its values must not appear in the result.
  auto alternative =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}),
                                  LiteralUtil::CreateR1<int>({3, 4}));
  const auto actual = ExecuteAndTransfer(std::move(hlo_module), {&alternative});

  const auto expected =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({0, 1}),
                                  LiteralUtil::CreateR1<int>({0, 1}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
} // namespace
} // namespace xla