Export HLO sort ops with a single operand to XLA HLO

GetTupleElement should be used only if the sort op result is a tuple.

PiperOrigin-RevId: 338346572
Change-Id: I6aecb16816d92b2b5605dc31fe301aacd66dae10
This commit is contained in:
Smit Hinsu 2020-10-21 14:56:45 -07:00 committed by TensorFlower Gardener
parent a81774fb45
commit f75608af78
2 changed files with 29 additions and 2 deletions

View File

@ -1012,13 +1012,24 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
&comparator)))
return failure();
auto tupled = xla::Sort(GetTuple(op.operands(), ctx), comparator,
auto sorted = xla::Sort(GetTuple(op.operands(), ctx), comparator,
op.dimension(), op.is_stable());
auto& value_map = *ctx.values;
auto shape_or = sorted.builder()->GetShape(sorted);
if (!shape_or.ok()) {
return op.emitError(shape_or.status().ToString());
}
xla::Shape& shape = shape_or.ValueOrDie();
if (!shape.IsTuple()) {
value_map[op.getResult(0)] = sorted;
return success();
}
// MLIR's sort supports multiple returns, untuple all the results of XLA's.
for (auto it : llvm::enumerate(op.getResults())) {
value_map[it.value()] = xla::GetTupleElement(tupled, it.index());
value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
}
return success();
}

View File

@ -963,6 +963,22 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0
// CHECK: ROOT [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1
// -----
// CHECK: HloModule
func @main(%input0: tensor<16x16xf32>) {
%0 = "mhlo.sort"(%input0) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>) -> (tensor<16x16xf32>)
return
}
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
// CHECK: ROOT %[[CMP:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] sort(f32[16,16] %Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
// -----