Emit tuple at the end of the Sort emitter.

If sort has more than one operand, the result is a tuple. So far, we didn't
emit the tuple at the end of the emitter.

PiperOrigin-RevId: 317082573
Change-Id: I7bec31302ba2e40556b17654daa081428871a00e
This commit is contained in:
Adrian Kuegel 2020-06-18 05:19:27 -07:00 committed by TensorFlower Gardener
parent 0a541ad1cc
commit 7378fabf90
2 changed files with 39 additions and 0 deletions

View File

@ -1418,6 +1418,13 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
if (sort->operand_count() > 1) {
// Emit the tuple as part of the last stage of sorting.
// We are currently in the block sorted.in_bounds.after.
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
llvm_ir::EmitTuple(GetIrArray(*sort, *sort),
ConstructIrArrayForOutputs(*sort), &b_);
}
return Status::OK();
}

View File

@ -577,5 +577,37 @@ XLA_TEST_F(TupleHloTest,
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
XLA_TEST_F(TupleHloTest, TupleSelectOfSort) {
const char* testcase = R"(
HloModule sort
compare {
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
keys = f32[2]{0} iota(), iota_dimension=0
values = s32[2]{0} iota(), iota_dimension=0
preds = pred[] constant(true)
alt = (f32[2], s32[2]) parameter(0)
sorted = (f32[2]{0}, s32[2]{0}) sort(keys, values), dimensions={0},
to_apply=compare
ROOT selected = (f32[2], s32[2]) tuple-select(preds, sorted, alt)
}
)";
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}),
LiteralUtil::CreateR1<int>({3, 4}));
auto expected = LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<float>({0, 1}), LiteralUtil::CreateR1<int>({0, 1}));
auto result = ExecuteAndTransfer(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
} // namespace
} // namespace xla