[XLA/GPU] Remove TupleSelect implementation.

It is not used by major XLA/GPU users, and it adds a lot of implementation burden.

PiperOrigin-RevId: 341926708
Change-Id: I8291f11969b15f8439d2f390bb4e840e5cd70c80
Authored by Tim Shen on 2020-11-11 15:51:09 -08:00; committed by TensorFlower Gardener
parent 0131d1a7d0
commit 7ac4c1f85e
6 changed files with 14 additions and 70 deletions
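
Note for callers that still rely on tuple-shaped select on GPU: because the predicate of a tuple select is a scalar, the operation can be decomposed in the client into per-leaf selects followed by re-tupling. Below is a minimal sketch against the XlaBuilder client API; the helper name SelectTuplePair and the two-leaf tuple arity are illustrative and not part of this change.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch: emulate a tuple-shaped select by selecting each leaf explicitly and
// re-packing the results into a tuple. Valid because the predicate of a tuple
// select is a scalar. Assumes a two-leaf tuple; repeat per leaf for wider ones.
xla::XlaOp SelectTuplePair(xla::XlaBuilder* b, xla::XlaOp pred,
                           xla::XlaOp on_true, xla::XlaOp on_false) {
  xla::XlaOp leaf0 = xla::Select(pred, xla::GetTupleElement(on_true, 0),
                                 xla::GetTupleElement(on_false, 0));
  xla::XlaOp leaf1 = xla::Select(pred, xla::GetTupleElement(on_true, 1),
                                 xla::GetTupleElement(on_false, 1));
  return xla::Tuple(b, {leaf0, leaf1});
}

Unlike the removed kernel path, this decomposition keeps every buffer assignment static, so it stays within what the GPU backend supports after this change.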


@@ -143,46 +143,5 @@ TEST_F(CustomCallTest, SubBuffers) {
EXPECT_THAT(result.data<float>({1, 1}), ::testing::Each(2));
EXPECT_THAT(result.data<float>({2}), ::testing::Each(3));
}
-void Callback_TupleSelect(CUstream stream, void** buffers,
-const char* /*opaque*/, size_t /*opaque_len*/) {
-// Set the two output leaf buffers equal to the two input leaf buffers.
-cudaMemcpyAsync(buffers[2], buffers[0], 10 * sizeof(float),
-cudaMemcpyDeviceToDevice, stream);
-cudaMemcpyAsync(buffers[3], buffers[1], 10 * sizeof(float),
-cudaMemcpyDeviceToDevice, stream);
-}
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_TupleSelect, "CUDA");
-// Tuple-shaped select is a case where XLA can't know all buffer assignments
-// statically ahead of time and has to walk the on-device tuple sub-buffers.
-TEST_F(CustomCallTest, TupleSelect) {
-XlaBuilder b(TestName());
-auto tuple_shape = ShapeUtil::MakeTupleShape({
-ShapeUtil::MakeShape(F32, {10}),
-ShapeUtil::MakeShape(F32, {10}),
-});
-auto p0 = AddParam(LiteralUtil::CreateR0(false), &b);
-auto p1 =
-AddParam(LiteralUtil::MakeTupleOwned(
-LiteralUtil::CreateR1<float>(std::vector<float>(10, 1.0f)),
-LiteralUtil::CreateR1<float>(std::vector<float>(10, 2.0f))),
-&b);
-auto p2 =
-AddParam(LiteralUtil::MakeTupleOwned(
-LiteralUtil::CreateR1<float>(std::vector<float>(10, 10.0f)),
-LiteralUtil::CreateR1<float>(std::vector<float>(10, 20.0f))),
-&b);
-auto cc = CustomCall(&b, "Callback_TupleSelect",
-/*operands=*/{Select(p0, p1, p2)}, tuple_shape,
-/*opaque=*/"");
-// Do a tuple-select on the custom-call result to ensure that the custom-call
-// sets its output tuple index buffers.
-Select(p0, p1, cc);
-TF_ASSERT_OK_AND_ASSIGN(auto result, ComputeAndTransfer(&b, {}));
-EXPECT_THAT(result.data<float>({0}), ::testing::Each(10));
-EXPECT_THAT(result.data<float>({1}), ::testing::Each(20));
-}
} // anonymous namespace
} // namespace xla


@@ -531,17 +531,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
}
Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
-auto pred = tuple_select->operand(0);
-auto on_true = tuple_select->operand(1);
-auto on_false = tuple_select->operand(2);
-TF_RET_CHECK(pred->shape().element_type() == PRED);
-TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
-TF_RET_CHECK(tuple_select->shape().IsTuple());
-llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
-GetIrArray(*pred, *tuple_select),
-GetBasePointer(*on_true), GetBasePointer(*on_false),
-&b_);
-return Status::OK();
+return InternalError(
+"Dynamic selection of tuples is not supported. Please file a bug against "
+"XLA/GPU if you need it");
}
namespace {


@@ -2061,12 +2061,6 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
return Status::OK();
}
-Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
-AddThunkToThunkSequence(
-BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
-return IrEmitter::HandleTupleSelect(tuple_select);
-}
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>(
GetThunkInfo(hlo), GetAllocationSlice(*hlo)));


@@ -186,7 +186,6 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleSort(HloInstruction* sort) override;
Status EmitSortFromMlir(MlirEmitterInput mlir_input);
Status HandleTriangularSolve(HloInstruction* hlo) override;
-Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleAfterAll(HloInstruction* after_all) override;
Status HandleReplicaId(HloInstruction* hlo) override;


@@ -732,7 +732,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
ContainsRegex("stream is uninitialized or in an error state"));
}
-XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
+XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(SelectBetweenTuples)) {
XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};


@@ -202,7 +202,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) {
XlaBuilder b(TestName());
XlaOp v1, v2;
@@ -275,7 +275,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesOnFalse)) {
// Tests a selection between tuples with "false" path taken.
XlaBuilder builder(TestName());
@@ -292,7 +292,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, TuplesInAMap) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(TuplesInAMap)) {
XlaComputation tuple_computation;
{
// tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples.
@@ -319,7 +319,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesOnTrue)) {
// Tests a selection between tuples with "true" path taken.
XlaBuilder builder(TestName());
@@ -336,7 +336,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesElementResult)) {
// Tests a selection between tuples but the final result is an element of the
// tuple, not the whole tuple.
XlaBuilder builder(TestName());
@@ -355,7 +355,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
}
// Cascaded selects between tuple types.
-XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesCascaded)) {
//
// vec1 vec2 vec2 vec1
// | | | |
@@ -392,7 +392,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesReuseConstants)) {
// Similar to SelectBetweenTuples, but the constants are shared between the
// input tuples.
XlaBuilder builder(TestName());
@@ -535,8 +535,8 @@ XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
}
// Disabled on interpreter due to lack of outfeed.
-XLA_TEST_F(TupleHloTest,
-DISABLED_ON_INTERPRETER(NonAmbiguousTopLevelAllocation)) {
+XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+NonAmbiguousTopLevelAllocation))) {
const char* testcase = R"(
HloModule tuple
@@ -577,7 +577,7 @@ XLA_TEST_F(TupleHloTest,
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
-XLA_TEST_F(TupleHloTest, TupleSelectOfSort) {
+XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(TupleSelectOfSort)) {
const char* testcase = R"(
HloModule sort