Emit tuple at the end of the Sort emitter.
If sort has more than one operand, the result is a tuple. So far, we didn't emit the tuple at the end of the emitter. PiperOrigin-RevId: 317082573 Change-Id: I7bec31302ba2e40556b17654daa081428871a00e
This commit is contained in:
parent
0a541ad1cc
commit
7378fabf90
@ -1418,6 +1418,13 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
|||||||
|
|
||||||
AddThunkToThunkSequence(
|
AddThunkToThunkSequence(
|
||||||
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
|
absl::make_unique<SequentialThunk>(std::move(thunks), sort));
|
||||||
|
if (sort->operand_count() > 1) {
|
||||||
|
// Emit the tuple as part of the last stage of sorting.
|
||||||
|
// We are currently in the block sorted.in_bounds.after.
|
||||||
|
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
|
||||||
|
llvm_ir::EmitTuple(GetIrArray(*sort, *sort),
|
||||||
|
ConstructIrArrayForOutputs(*sort), &b_);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -577,5 +577,37 @@ XLA_TEST_F(TupleHloTest,
|
|||||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
|
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(TupleHloTest, TupleSelectOfSort) {
|
||||||
|
const char* testcase = R"(
|
||||||
|
HloModule sort
|
||||||
|
|
||||||
|
compare {
|
||||||
|
p.1.lhs = s32[] parameter(2)
|
||||||
|
p.1.rhs = s32[] parameter(3)
|
||||||
|
p.0.lhs = f32[] parameter(0)
|
||||||
|
p.0.rhs = f32[] parameter(1)
|
||||||
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY Sort {
|
||||||
|
keys = f32[2]{0} iota(), iota_dimension=0
|
||||||
|
values = s32[2]{0} iota(), iota_dimension=0
|
||||||
|
preds = pred[] constant(true)
|
||||||
|
alt = (f32[2], s32[2]) parameter(0)
|
||||||
|
|
||||||
|
sorted = (f32[2]{0}, s32[2]{0}) sort(keys, values), dimensions={0},
|
||||||
|
to_apply=compare
|
||||||
|
ROOT selected = (f32[2], s32[2]) tuple-select(preds, sorted, alt)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie();
|
||||||
|
auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}),
|
||||||
|
LiteralUtil::CreateR1<int>({3, 4}));
|
||||||
|
auto expected = LiteralUtil::MakeTupleOwned(
|
||||||
|
LiteralUtil::CreateR1<float>({0, 1}), LiteralUtil::CreateR1<int>({0, 1}));
|
||||||
|
auto result = ExecuteAndTransfer(std::move(module), {¶m});
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user