[XLA/GPU] Remove TupleSelect implementation.
It is not used by major XLA/GPU users, and it adds a lot of implementation burden.

PiperOrigin-RevId: 341926708
Change-Id: I8291f11969b15f8439d2f390bb4e840e5cd70c80
commit 7ac4c1f85e
parent 0131d1a7d0
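
Note: computations that still need a tuple-shaped select can usually be rewritten to select each tuple leaf individually and re-tuple the results, which keeps every buffer assignment static. A minimal sketch, assuming the two-leaf f32[10] tuple shape from the deleted test below (the helper name is illustrative, not part of this change):

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Emulates select-between-tuples without a tuple-select op: broadcast the
// scalar predicate to each leaf's shape, select leaf-wise, then rebuild the
// tuple. No on-device tuple-index walk is needed.
xla::XlaOp SelectTupleLeafwise(xla::XlaBuilder* b, xla::XlaOp pred,
                               xla::XlaOp on_true, xla::XlaOp on_false) {
  xla::XlaOp p = xla::Broadcast(pred, /*broadcast_sizes=*/{10});
  xla::XlaOp leaf0 = xla::Select(p, xla::GetTupleElement(on_true, 0),
                                 xla::GetTupleElement(on_false, 0));
  xla::XlaOp leaf1 = xla::Select(p, xla::GetTupleElement(on_true, 1),
                                 xla::GetTupleElement(on_false, 1));
  return xla::Tuple(b, {leaf0, leaf1});
}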
@@ -143,46 +143,5 @@ TEST_F(CustomCallTest, SubBuffers) {
   EXPECT_THAT(result.data<float>({1, 1}), ::testing::Each(2));
   EXPECT_THAT(result.data<float>({2}), ::testing::Each(3));
 }
-
-void Callback_TupleSelect(CUstream stream, void** buffers,
-                          const char* /*opaque*/, size_t /*opaque_len*/) {
-  // Set the two output leaf buffers equal to the two input leaf buffers.
-  cudaMemcpyAsync(buffers[2], buffers[0], 10 * sizeof(float),
-                  cudaMemcpyDeviceToDevice, stream);
-  cudaMemcpyAsync(buffers[3], buffers[1], 10 * sizeof(float),
-                  cudaMemcpyDeviceToDevice, stream);
-}
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_TupleSelect, "CUDA");
-
-// Tuple-shaped select is a case where XLA can't know all buffer assignments
-// statically ahead of time and has to walk the on-device tuple sub-buffers.
-TEST_F(CustomCallTest, TupleSelect) {
-  XlaBuilder b(TestName());
-  auto tuple_shape = ShapeUtil::MakeTupleShape({
-      ShapeUtil::MakeShape(F32, {10}),
-      ShapeUtil::MakeShape(F32, {10}),
-  });
-  auto p0 = AddParam(LiteralUtil::CreateR0(false), &b);
-  auto p1 =
-      AddParam(LiteralUtil::MakeTupleOwned(
-                   LiteralUtil::CreateR1<float>(std::vector<float>(10, 1.0f)),
-                   LiteralUtil::CreateR1<float>(std::vector<float>(10, 2.0f))),
-               &b);
-  auto p2 =
-      AddParam(LiteralUtil::MakeTupleOwned(
-                   LiteralUtil::CreateR1<float>(std::vector<float>(10, 10.0f)),
-                   LiteralUtil::CreateR1<float>(std::vector<float>(10, 20.0f))),
-               &b);
-  auto cc = CustomCall(&b, "Callback_TupleSelect",
-                       /*operands=*/{Select(p0, p1, p2)}, tuple_shape,
-                       /*opaque=*/"");
-
-  // Do a tuple-select on the custom-call result to ensure that the custom-call
-  // sets its output tuple index buffers.
-  Select(p0, p1, cc);
-  TF_ASSERT_OK_AND_ASSIGN(auto result, ComputeAndTransfer(&b, {}));
-  EXPECT_THAT(result.data<float>({0}), ::testing::Each(10));
-  EXPECT_THAT(result.data<float>({1}), ::testing::Each(20));
-}
 }  // anonymous namespace
 }  // namespace xla
@@ -531,17 +531,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) {
 }
 
 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
-  auto pred = tuple_select->operand(0);
-  auto on_true = tuple_select->operand(1);
-  auto on_false = tuple_select->operand(2);
-  TF_RET_CHECK(pred->shape().element_type() == PRED);
-  TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
-  TF_RET_CHECK(tuple_select->shape().IsTuple());
-  llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
-                           GetIrArray(*pred, *tuple_select),
-                           GetBasePointer(*on_true), GetBasePointer(*on_false),
-                           &b_);
-  return Status::OK();
+  return InternalError(
+      "Dynamic selection of tuples is not supported. Please file a bug against "
+      "XLA/GPU if you need it");
 }
 
 namespace {
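
After this change, a computation that reaches the GPU IR emitter with a tuple-shaped select fails at compile time with the InternalError above rather than being lowered. The tests disabled below all build such an op through XlaBuilder; a hypothetical minimal repro (names and literal shapes are illustrative):

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Select with tuple-shaped operands builds a tuple-select, which the GPU
// backend now rejects during IR emission with
// "Dynamic selection of tuples is not supported. ...".
void BuildTupleSelectRepro() {
  xla::XlaBuilder b("tuple_select_repro");
  xla::XlaOp pred = xla::ConstantR0<bool>(&b, false);
  xla::XlaOp t1 =
      xla::Tuple(&b, {xla::ConstantR1<float>(&b, {1.f, 2.f, 3.f})});
  xla::XlaOp t2 =
      xla::Tuple(&b, {xla::ConstantR1<float>(&b, {4.f, 5.f, 6.f})});
  xla::Select(pred, t1, t2);  // Now rejected by XLA/GPU at compile time.
}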
@@ -2061,12 +2061,6 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   return Status::OK();
 }
 
-Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
-  AddThunkToThunkSequence(
-      BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
-  return IrEmitter::HandleTupleSelect(tuple_select);
-}
-
 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
   AddThunkToThunkSequence(absl::make_unique<ReplicaIdThunk>(
       GetThunkInfo(hlo), GetAllocationSlice(*hlo)));
@@ -186,7 +186,6 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleSort(HloInstruction* sort) override;
   Status EmitSortFromMlir(MlirEmitterInput mlir_input);
   Status HandleTriangularSolve(HloInstruction* hlo) override;
-  Status HandleTupleSelect(HloInstruction* tuple_select) override;
   Status HandleAllReduce(HloInstruction* crs) override;
   Status HandleAfterAll(HloInstruction* after_all) override;
   Status HandleReplicaId(HloInstruction* hlo) override;
@@ -732,7 +732,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
               ContainsRegex("stream is uninitialized or in an error state"));
 }
 
-XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
+XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(SelectBetweenTuples)) {
   XlaBuilder builder(TestName());
 
   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
@@ -202,7 +202,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
-XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) {
   XlaBuilder b(TestName());
   XlaOp v1, v2;
 
@@ -275,7 +275,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
 }
 
-XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesOnFalse)) {
   // Tests a selection between tuples with "false" path taken.
   XlaBuilder builder(TestName());
 
@@ -292,7 +292,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
-XLA_TEST_F(TupleTest, TuplesInAMap) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(TuplesInAMap)) {
   XlaComputation tuple_computation;
   {
     // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples.
@@ -319,7 +319,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
   ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
 }
 
-XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesOnTrue)) {
   // Tests a selection between tuples with "true" path taken.
   XlaBuilder builder(TestName());
 
@@ -336,7 +336,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
-XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesElementResult)) {
   // Tests a selection between tuples but the final result is an element of the
   // tuple, not the whole tuple.
   XlaBuilder builder(TestName());
@@ -355,7 +355,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
 }
 
 // Cascaded selects between tuple types.
-XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesCascaded)) {
   //
   //   vec1     vec2     vec2     vec1
   //     |        |        |        |
@@ -392,7 +392,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
   ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
 }
 
-XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenTuplesReuseConstants)) {
   // Similar to SelectBetweenTuples, but the constants are shared between the
   // input tuples.
   XlaBuilder builder(TestName());
@@ -535,8 +535,8 @@ XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
 }
 
 // Disabled on interpreter due to lack of outfeed.
-XLA_TEST_F(TupleHloTest,
-           DISABLED_ON_INTERPRETER(NonAmbiguousTopLevelAllocation)) {
+XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
+                             NonAmbiguousTopLevelAllocation))) {
   const char* testcase = R"(
     HloModule tuple
 
@@ -577,7 +577,7 @@ XLA_TEST_F(TupleHloTest,
   EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
 }
 
-XLA_TEST_F(TupleHloTest, TupleSelectOfSort) {
+XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(TupleSelectOfSort)) {
   const char* testcase = R"(
     HloModule sort
 