Emit tuple at the end of the Sort emitter.
If sort has more than one operand, the result is a tuple. So far, we didn't emit the tuple at the end of the emitter. PiperOrigin-RevId: 317082573 Change-Id: I7bec31302ba2e40556b17654daa081428871a00e
This commit is contained in:
parent
0a541ad1cc
commit
7378fabf90
@ -1418,6 +1418,13 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
|
||||
if (sort->operand_count() > 1) {
|
||||
// Emit the tuple as part of the last stage of sorting.
|
||||
// We are currently in the block sorted.in_bounds.after.
|
||||
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
|
||||
llvm_ir::EmitTuple(GetIrArray(*sort, *sort),
|
||||
ConstructIrArrayForOutputs(*sort), &b_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -577,5 +577,37 @@ XLA_TEST_F(TupleHloTest,
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
|
||||
}
|
||||
|
||||
XLA_TEST_F(TupleHloTest, TupleSelectOfSort) {
|
||||
const char* testcase = R"(
|
||||
HloModule sort
|
||||
|
||||
compare {
|
||||
p.1.lhs = s32[] parameter(2)
|
||||
p.1.rhs = s32[] parameter(3)
|
||||
p.0.lhs = f32[] parameter(0)
|
||||
p.0.rhs = f32[] parameter(1)
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
ENTRY Sort {
|
||||
keys = f32[2]{0} iota(), iota_dimension=0
|
||||
values = s32[2]{0} iota(), iota_dimension=0
|
||||
preds = pred[] constant(true)
|
||||
alt = (f32[2], s32[2]) parameter(0)
|
||||
|
||||
sorted = (f32[2]{0}, s32[2]{0}) sort(keys, values), dimensions={0},
|
||||
to_apply=compare
|
||||
ROOT selected = (f32[2], s32[2]) tuple-select(preds, sorted, alt)
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||
auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}),
|
||||
LiteralUtil::CreateR1<int>({3, 4}));
|
||||
auto expected = LiteralUtil::MakeTupleOwned(
|
||||
LiteralUtil::CreateR1<float>({0, 1}), LiteralUtil::CreateR1<int>({0, 1}));
|
||||
auto result = ExecuteAndTransfer(std::move(module), {¶m});
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user