diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 57b61461d02..cdc545d5681 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -5768,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let verifier = [{ return Verify(*this); }];
 }
 
 def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 3b836a6188d..1bd9accbb78 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1698,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TopKV2Op
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(TopKV2Op op) {
+  if (!HasRankAtLeast(op.input(), 1))
+    return op.emitOpError(
+        "requires input operand to have at least 1 dimension");
+
+  if (!IsOfRankOrUnranked(op.k(), 0))
+    return op.emitOpError("requires k operand to be 0D tensor");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 1914ca177cc..e064c1a53ef 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1658,3 +1658,19 @@ func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
   %0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
   return
 }
+
+// -----
+
+func @testTopKV2WrongInputRank(%input: tensor<f32>, %k: tensor<i32>) {
+  // expected-error @+1 {{op requires input operand to have at least 1 dimension}}
+  %0:2 = "tf.TopKV2"(%input, %k) : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
+  return
+}
+
+// -----
+
+func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
+  // expected-error @+1 {{op requires k operand to be 0D tensor}}
+  %0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
+  return
+}
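For contrast with the two error tests above, a tf.TopKV2 that the new verifier accepts needs an input of rank at least 1 and a 0-D (scalar) `k`. A minimal sketch, not part of the patch:

func @testTopKV2Valid(%input: tensor<8xf32>, %k: tensor<i32>) {
  %0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
  return
}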
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 8fa33d19363..b2f02bdf76f 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -841,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
   return RankedTensorType::get(shape, ranked_ty.getElementType());
 }
 
+//===----------------------------------------------------------------------===//
+// SortOp
+//===----------------------------------------------------------------------===//
+
+void SortOp::build(Builder* builder, OperationState& state,
+                   ArrayRef<Value*> operands, int64_t dimension,
+                   bool is_stable) {
+  state.addOperands(operands);
+  state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
+  state.addAttribute("is_stable", builder->getBoolAttr(is_stable));
+
+  SmallVector<Type, 2> element_types;
+  element_types.reserve(operands.size());
+  for (Value* operand : operands)
+    element_types.push_back(operand->getType());
+  state.addTypes(builder->getTupleType(element_types));
+
+  state.addRegion();
+}
+
+static LogicalResult Verify(SortOp op) {
+  Operation::operand_range operands = op.operands();
+  if (operands.empty()) return op.emitOpError("requires at least one input");
+
+  // TODO(antiagainst): verify partially dynamic shapes
+  if (llvm::all_of(operands, [](Value* operand) {
+        return operand->getType().cast<ShapedType>().hasRank();
+      })) {
+    ArrayRef<int64_t> input_shape =
+        (*operands.begin())->getType().cast<ShapedType>().getShape();
+
+    if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
+          return operand->getType().cast<ShapedType>().getShape() !=
+                 input_shape;
+        }))
+      return op.emitOpError("requires all inputs to have the same dimensions");
+
+    if (op.dimension().getSExtValue() >= input_shape.size())
+      return op.emitOpError(
+          "dimension attribute value must be less than input rank");
+  }
+
+  Block& block = op.comparator().front();
+  size_t num_operands = op.getOperation()->getNumOperands();
+  if (block.getNumArguments() != 2 * num_operands)
+    return op.emitOpError("comparator block should have ")
+           << 2 * num_operands << " arguments";
+
+  for (auto indexed_operand : llvm::enumerate(operands)) {
+    int index = indexed_operand.index();
+    Type element_type =
+        indexed_operand.value()->getType().cast<ShapedType>().getElementType();
+    Type tensor_type = RankedTensorType::get({}, element_type);
+    for (int i : {2 * index, 2 * index + 1}) {
+      Type arg_type = block.getArgument(i)->getType();
+      if (arg_type != tensor_type)
+        return op.emitOpError("comparator block argument #")
+               << i << " should be of type " << tensor_type << " but got "
+               << arg_type;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index e285b172806..c9b3e7985fc 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -868,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
+  let arguments = (ins
+    Variadic<HLO_Tensor>:$operands,
+    DefaultValuedAttr<I64Attr, "-1">:$dimension,
+    DefaultValuedAttr<BoolAttr, "false">:$is_stable
+  );
+
+  let results = (outs HLO_TensorOrTuple);
+
+  let regions = (region SizedRegion<1>:$comparator);
+
+  let builders = [OpBuilder<
+    "Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
+    "int64_t dimension, bool is_stable"
+  >];
+
+  // TODO(b/129422361): SortOp has special conversion logic to HLO.
+  let hasCustomHLOConverter = 1;
+}
+
 def HLO_ReverseOp: HLO_Op<"reverse",
     [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
   let arguments = (ins
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index a0c790616fa..a6d4210b60c 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
   }];
 }
 
+class BASE_HLO_SortOp {
+  string summary = "Sort operator";
+
+  string description = [{
+    Sorts the given `operands` along the given `dimension` using the given
+    `comparator`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#sort.
+  }];
+}
+
 class BASE_HLO_ReverseOp {
   string summary = "Reverse operator";
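To make the sort semantics concrete, a single-operand descending sort along dimension 0 could be written roughly as below. This is an illustrative sketch based on the op definition above; the function name is made up:

func @sort_1d(%input: tensor<4xf32>) -> tuple<tensor<4xf32>> {
  // One (lhs, rhs) pair of 0-D tensors per operand; "GT" gives descending order.
  %0 = "xla_hlo.sort"(%input) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %1 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tuple<tensor<4xf32>>
  return %0 : tuple<tensor<4xf32>>
}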
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 93716331d0d..e9bf3bac44b 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -624,6 +624,18 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
   return failure();
 }
 
+LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
+  xla::XlaComputation comparator;
+  if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
+                                                     &comparator)))
+    return failure();
+
+  auto& value_map = *ctx.values;
+  value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
+                            op.dimension().getSExtValue(), op.is_stable());
+  return success();
+}
+
 LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 94a445fe8bd..8aa9b5ef101 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1934,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
   // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
   return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// tf.TopKV2 legalization
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: topk_v2_non_const_k
+func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
+  // CHECK: tf.TopKV2
+  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
+  return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
+}
+
+// CHECK-LABEL: topk_v2_unknown_input_last_dim
+func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
+  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: tf.TopKV2
+  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
+  return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
+}
+
+// CHECK-LABEL: topk_v2
+// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
+func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
+  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
+
+  // CHECK:      %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
+  // CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
+  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
+  // CHECK-NEXT:   %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
+  // CHECK-NEXT:   "xla_hlo.return"(%[[CMP]])
+  // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  // CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
+  // CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
+  // CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  // CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  // CHECK-NEXT: return %[[VAL]], %[[IDX]]
+  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
+  return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 225fc97bb22..4f142f294e4 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -416,3 +416,98 @@ func @constants() -> () {
   %3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
   return
 }
+
+// -----
+
+func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // CHECK: xla_hlo.sort
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_no_operands() {
+  // expected-error @+1 {{op requires at least one input}}
+  %0 = "xla_hlo.sort"() ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
+  return
+}
+
+// -----
+
+func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op requires all inputs to have the same dimensions}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x8xf32>, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op dimension attribute value must be less than input rank}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op comparator block should have 4 arguments}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
+
+// -----
+
+func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 70b48fa43c9..ffcc1cc9df3 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -620,3 +620,21 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
   // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
   return %0 : tensor<4xi1>
 }
+
+// -----
+
+// CHECK-LABEL: HloModule
+func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
+  %0 = "xla_hlo.sort"(%input0, %input1) ( {
+  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
+    %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
+  return
+}
+
+// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
+// CHECK:   ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
+
+// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
+// CHECK:   ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index d7f3bf243e5..f0ba67e2fd5 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -438,6 +438,38 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
   return GetI64ElementsAttr(normalized_sizes, builder);
 }
 
+//===----------------------------------------------------------------------===//
+// Sort op utilities.
+//===----------------------------------------------------------------------===//
+
+// Builds the region `body` for xla_hlo.sort's comparator: for each type in
+// `element_types`, creates two block arguments, one for lhs and one for rhs,
+// and generates an xla_hlo.compare op to compare them with the given
+// `direction`.
+//
+// Note that right now this only compares the first pair of block arguments.
+static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
+                                    StringRef direction, Region *body,
+                                    OpBuilder *builder) {
+  OpBuilder::InsertionGuard insertion_point_guard(*builder);
+
+  Block *block = builder->createBlock(body);
+  // Add two arguments for each element type.
+  for (Type element_type : element_types) {
+    TensorType tensor_type = RankedTensorType::get({}, element_type);
+    block->addArguments({tensor_type, tensor_type});
+  }
+
+  Location loc = body->getLoc();
+  StringAttr compare_direction =
+      StringAttr::get(direction, builder->getContext());
+  Value *compare = builder->create<CompareOp>(
+      loc, block->getArgument(0), block->getArgument(1),
+      /*broadcast_dimensions=*/nullptr, compare_direction);
+
+  builder->create<ReturnOp>(loc, compare);
+}
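For illustration, with `element_types` = {f32, i32} and `direction` = "GT" (the tf.TopKV2 case below), the comparator region built by this helper would look roughly like the following; the argument names are invented for readability:

^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>, %lhs_index: tensor<i32>, %rhs_index: tensor<i32>):
  %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  "xla_hlo.return"(%0) : (tensor<i1>) -> ()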
+
 //===----------------------------------------------------------------------===//
 // Op converters.
 //===----------------------------------------------------------------------===//
@@ -1873,6 +1905,101 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
   }
 };
 
+// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a
+// constant scalar.
+//
+// tf.TopKV2 sorts along the last dimension of the input tensor and then
+// returns the top K components' values and indices. This is translated into a
+// few ops in XLA HLO: first generating an integer sequence for the indices,
+// then sorting both the original input tensor and the indices together, and
+// finally slicing out the top K components.
+//
+// For example, for the following IR:
+//
+// %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
+// %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
+//                                 (tensor<16x8xf32>, tensor<16x8xi32>)
+//
+// We will get:
+//
+// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
+// %2 = "xla_hlo.sort"(%input, %1) ( {
+// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
+//      %arg3: tensor<i32>, %arg4: tensor<i32>):
+//   %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
+//   "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+// }) {dimension = 1 : i64, is_stable = true} : ...
+// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
+// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
+// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
+//                           start_indices = dense<0> : tensor<2xi64>,
+//                           strides = dense<1> : tensor<2xi64>} :
+//                          (tensor<16x16xf32>) -> tensor<16x8xf32>
+// %6 = "xla_hlo.slice"(%4) ...
+class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TF::TopKV2Op op,
+                                     PatternRewriter &rewriter) const override {
+    // We can only match when the `k` operand is a constant scalar.
+    DenseIntElementsAttr k_attr;
+    if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure();
+
+    // The last dimension of the input tensor's shape should be known so we
+    // can have clamped end_indices for slices.
+    TensorType input_type = op.input()->getType().cast<TensorType>();
+    if (!input_type.hasRank()) return matchFailure();
+    int64_t input_rank = input_type.getRank();
+    int64_t last_dim_index = input_rank - 1;
+    int64_t last_dim_size = input_type.getDimSize(last_dim_index);
+    if (last_dim_size == ShapedType::kDynamicSize) return matchFailure();
+
+    // Create an Iota op for indices.
+    auto i32_type = rewriter.getIntegerType(32);
+    Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
+    Value *iota_op = rewriter.create<IotaOp>(
+        op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
+
+    // Create the sort op. It takes two inputs, one for the original input,
+    // the other for the indices.
+    auto sort_op = rewriter.create<SortOp>(
+        op.getLoc(), llvm::ArrayRef<Value *>{op.input(), iota_op},
+        last_dim_index, /*is_stable=*/true);
+    BuildSortComparisonBody({input_type.getElementType(), i32_type},
+                            /*direction=*/"GT", &sort_op.comparator(),
+                            &rewriter);
+
+    // Get the sorted input and index tuple elements.
+    auto tuple_first_element =
+        rewriter.create<GetTupleElementOp>(op.getLoc(), sort_op, 0);
+    auto tuple_second_element =
+        rewriter.create<GetTupleElementOp>(op.getLoc(), sort_op, 1);
+
+    SmallVector<int64_t, 4> begin_indices(input_rank, 0);
+    auto end_indices = llvm::to_vector<4>(input_type.getShape());
+    end_indices.back() =
+        std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
+    SmallVector<int64_t, 4> strides(input_rank, 1);
+
+    // Get the slices for the top K elements and their indices.
+    Value *values = rewriter.create<SliceOp>(
+        op.getLoc(), tuple_first_element,
+        GetI64ElementsAttr(begin_indices, &rewriter),
+        GetI64ElementsAttr(end_indices, &rewriter),
+        GetI64ElementsAttr(strides, &rewriter));
+
+    Value *indices = rewriter.create<SliceOp>(
+        op.getLoc(), tuple_second_element,
+        GetI64ElementsAttr(begin_indices, &rewriter),
+        GetI64ElementsAttr(end_indices, &rewriter),
+        GetI64ElementsAttr(strides, &rewriter));
+
+    rewriter.replaceOp(op, {values, indices});
+    return matchSuccess();
+  }
+};
+
 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
 
 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
@@ -1892,10 +2019,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
       ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
       ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
-      ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
-      ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,
-      ConvertConv2DBackpropInputOp, ConvertConv2DBackpropFilterOp>(
-      op->getContext());
+      ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp,
+      ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp,
+      ConvertOneHotOp, ConvertConv2DBackpropInputOp,
+      ConvertConv2DBackpropFilterOp>(op->getContext());
 
   ConversionTarget target(*context);
   target.addLegalDialect<XlaHloDialect>();