Export HLO sort ops with a single operand to XLA HLO
GetTupleElement should be used only if the sort op result is a tuple. PiperOrigin-RevId: 338346572 Change-Id: I6aecb16816d92b2b5605dc31fe301aacd66dae10
This commit is contained in:
parent
a81774fb45
commit
f75608af78
@ -1012,13 +1012,24 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
|
||||
&comparator)))
|
||||
return failure();
|
||||
|
||||
auto tupled = xla::Sort(GetTuple(op.operands(), ctx), comparator,
|
||||
auto sorted = xla::Sort(GetTuple(op.operands(), ctx), comparator,
|
||||
op.dimension(), op.is_stable());
|
||||
|
||||
auto& value_map = *ctx.values;
|
||||
auto shape_or = sorted.builder()->GetShape(sorted);
|
||||
if (!shape_or.ok()) {
|
||||
return op.emitError(shape_or.status().ToString());
|
||||
}
|
||||
|
||||
xla::Shape& shape = shape_or.ValueOrDie();
|
||||
if (!shape.IsTuple()) {
|
||||
value_map[op.getResult(0)] = sorted;
|
||||
return success();
|
||||
}
|
||||
|
||||
// MLIR's sort supports multiple returns, untuple all the results of XLA's.
|
||||
for (auto it : llvm::enumerate(op.getResults())) {
|
||||
value_map[it.value()] = xla::GetTupleElement(tupled, it.index());
|
||||
value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -963,6 +963,22 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0
|
||||
// CHECK: ROOT [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%input0: tensor<16x16xf32>) {
|
||||
%0 = "mhlo.sort"(%input0) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
|
||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>) -> (tensor<16x16xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
|
||||
// CHECK: ROOT %[[CMP:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
|
||||
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] sort(f32[16,16] %Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
|
||||
|
||||
// -----
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user