Make mhlo.sort return variadic results instead of a tuple

Tuple is only used on XLA's sort to return multiple inputs. MLIR supports
multiple inputs, switch to a tuple return.

PiperOrigin-RevId: 334226937
Change-Id: I269cfca1596064aedaff91449024a2dff7e006e8
This commit is contained in:
Robert Suderman 2020-09-28 13:31:28 -07:00 committed by TensorFlower Gardener
parent 6fe1ca9125
commit e683cc593e
10 changed files with 79 additions and 59 deletions

View File

@ -1198,14 +1198,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
let results = (outs HLO_Tensor);
}
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
let arguments = (ins
Variadic<HLO_Tensor>:$operands,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let results = (outs HLO_TensorOrTuple);
let results = (outs Variadic<HLO_Tensor>);
let regions = (region SizedRegion<1>:$comparator);

View File

@ -2261,10 +2261,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
SmallVector<Type, 2> element_types;
element_types.reserve(operands.size());
for (Value operand : operands) element_types.push_back(operand.getType());
state.addTypes(builder.getTupleType(element_types));
for (Value operand : operands) state.addTypes(operand.getType());
state.addRegion();
}

View File

@ -1010,34 +1010,34 @@ func @constant_invalid() -> () {
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: mhlo.sort
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
// -----
func @sort_no_operands() {
// expected-error @+1 {{op requires at least one input}}
%0 = "mhlo.sort"() ( {
// expected-error @+1 {{expected named operation to have atleast 1 result}}
%0:0 = "mhlo.sort"() ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
%7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
}) {dimension = 1 : i64, is_stable = true} : () -> ()
return
}
// -----
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
@ -1045,23 +1045,23 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
// -----
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
%0 = "mhlo.sort"(%input0, %input1) ( {
// expected-error @+1 {{op requires the same shape for all operands and results}}
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
@ -1069,11 +1069,11 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>)
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}}
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
@ -1081,11 +1081,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}}
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
@ -1093,11 +1093,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block should have 4 arguments}}
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
@ -1105,11 +1105,11 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}

View File

@ -417,13 +417,27 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kSort: {
auto sort_instruction = Cast<HloSortInstruction>(instruction);
llvm::SmallVector<Type, 4> return_types = {result_type};
if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
return_types = llvm::to_vector<6>(tuple_ty.getTypes());
}
auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
loc, result_type, operands,
loc, return_types, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
TF_RETURN_IF_ERROR(
ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
return sort_op.getOperation();
// Check if the output needs to be tupled.
if (return_types.size() == 1 && return_types.front() == result_type) {
return sort_op.getOperation();
}
return func_builder
->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
.getOperation();
}
case HloOpcode::kConditional: {
llvm::SmallVector<Type, 4> rets;

View File

@ -243,11 +243,22 @@ StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
int64 dimension, bool is_stable) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
llvm::SmallVector<mlir::Type, 4> sort_types = {ty};
if (auto tuple_ty = ty.dyn_cast<mlir::TupleType>()) {
sort_types = llvm::to_vector<6>(tuple_ty.getTypes());
}
auto op = builder_.create<mlir::mhlo::SortOp>(
loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension),
builder_.getBoolAttr(is_stable));
loc_, sort_types, GetValues(operands),
builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
return MakeXlaOp(op);
if (ty.isa<mlir::TupleType>()) {
auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}
return MakeXlaOp(op.getResult(0));
}
StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,

View File

@ -1008,9 +1008,14 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
&comparator)))
return failure();
auto tupled = xla::Sort(GetTuple(op.operands(), ctx), comparator,
op.dimension(), op.is_stable());
auto& value_map = *ctx.values;
value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
op.dimension(), op.is_stable());
// MLIR's sort supports multiple returns, untuple all the results of XLA's.
for (auto it : llvm::enumerate(op.getResults())) {
value_map[it.value()] = xla::GetTupleElement(tupled, it.index());
}
return success();
}

View File

@ -316,12 +316,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32>
// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32>
// CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>> {
%res = "mhlo.sort"(%key, %value) ({
func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>) {
%res:2 = "mhlo.sort"(%key, %value) ({
^bb0(%a: tensor<i32>, %b: tensor<i32>, %c: tensor<f32>, %d: tensor<f32>):
%ret = "mhlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%ret) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>>
}) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>)
return %res : tuple<tensor<5x5xi32>, tensor<5x5xf32>>
return %res#0, %res#1 : tensor<5x5xi32>, tensor<5x5xf32>
}

View File

@ -3824,15 +3824,13 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64}
// CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
// CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
// CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
// CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
// CHECK-NEXT: "mhlo.return"(%[[CMP]])
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
// CHECK-NEXT: %[[TUPL0:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
// CHECK-NEXT: %[[TUPL1:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
// CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
// CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: return %[[VAL]], %[[IDX]]
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
@ -4199,12 +4197,11 @@ func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
// CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
// CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32>
// CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]])
// CHECK: [[SORT:%.*]] = "mhlo.sort"([[RNG]], [[INPUT]]) ( {
// CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( {
// CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>):
// CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"}
// CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> tuple<tensor<16xi32>, tensor<16xf32>>
// CHECK: [[RES:%.*]] = "mhlo.get_tuple_element"([[SORT]]) {index = 1 : i32}
// CHECK: return [[RES]]
// CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>)
// CHECK: return [[SORT]]#1
%0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>)
return %0: tensor<16xf32>
}
@ -4213,10 +4210,8 @@ func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> {
// CHECK: mhlo.rng_uniform
// CHECK: mhlo.sort
// CHECK: mhlo.get_tuple_element
// CHECK: mhlo.rng_uniform
// CHECK: mhlo.sort
// CHECK: mhlo.get_tuple_element
%0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>)
return %0: tensor<10240xf32>
}

View File

@ -948,19 +948,20 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: HloModule
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
%0 = "mhlo.sort"(%input0, %input1) ( {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
}
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
// CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(f32[16,16] %Arg_0.1, s32[16,16] %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0
// CHECK: ROOT [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1
// -----

View File

@ -4722,10 +4722,8 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
&rewriter);
// Get the sorted input and index tuple element.
auto tuple_first_element =
rewriter.create<mhlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
auto tuple_second_element =
rewriter.create<mhlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);
auto tuple_first_element = sort_op.getResult(0);
auto tuple_second_element = sort_op.getResult(1);
SmallVector<int64_t, 4> begin_indices(input_rank, 0);
auto end_indices = llvm::to_vector<4>(input_type.getShape());
@ -5011,8 +5009,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
BuildSortComparisonBody({i32_type, input_type.getElementType()},
/*direction=*/"LT", &sorted.comparator(),
&rewriter);
current = rewriter.create<GetTupleElementOp>(op.getLoc(),
sorted.getResult(), 1);
current = sorted.getResult(1);
}
rewriter.replaceOp(op, current);
return success();