diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 937a0ea5bbc..74aad5f5bd5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1418,6 +1418,13 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { AddThunkToThunkSequence( absl::make_unique<SequentialThunk>(std::move(thunks), sort)); + if (sort->operand_count() > 1) { + // Emit the tuple as part of the last stage of sorting. + // We are currently in the block sorted.in_bounds.after. + b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); + llvm_ir::EmitTuple(GetIrArray(*sort, *sort), + ConstructIrArrayForOutputs(*sort), &b_); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 9ef589e5511..b6ad44497e6 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -577,5 +577,37 @@ XLA_TEST_F(TupleHloTest, EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } +XLA_TEST_F(TupleHloTest, TupleSelectOfSort) { + const char* testcase = R"( + HloModule sort + + compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY Sort { + keys = f32[2]{0} iota(), iota_dimension=0 + values = s32[2]{0} iota(), iota_dimension=0 + preds = pred[] constant(true) + alt = (f32[2], s32[2]) parameter(0) + + sorted = (f32[2]{0}, s32[2]{0}) sort(keys, values), dimensions={0}, + to_apply=compare + ROOT selected = (f32[2], s32[2]) tuple-select(preds, sorted, alt) + } + )"; + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); + auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}), + LiteralUtil::CreateR1<int32>({3, 4})); + auto expected = LiteralUtil::MakeTupleOwned( +
LiteralUtil::CreateR1<float>({0, 1}), LiteralUtil::CreateR1<int32>({0, 1})); + auto result = ExecuteAndTransfer(std::move(module), {&param}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + } // namespace } // namespace xla