Add legalization from tf.TopKV2 to XLA HLO ops

- Tightened tf.TopKV2 verification
- Defined xla_hlo.sort operation
- Added lowering from tf.TopKV2 to XLA HLO ops
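
In short, for a constant scalar k, tf.TopKV2 lowers to an iota over the sort
dimension, a stable two-operand xla_hlo.sort, two get_tuple_element ops, and
two slices; schematically (types abbreviated, drawn from the new tests):

    %iota   = "xla_hlo.iota"() {iota_dimension = 1 : i64}
    %sorted = "xla_hlo.sort"(%input, %iota) ({ /* comparator */ })
                  {dimension = 1 : i64, is_stable = true}
    %vals   = "xla_hlo.get_tuple_element"(%sorted) {index = 0 : i32}
    %idxs   = "xla_hlo.get_tuple_element"(%sorted) {index = 1 : i32}
    %topk_v = "xla_hlo.slice"(%vals) { /* limit k on the last dim */ }
    %topk_i = "xla_hlo.slice"(%idxs) { /* limit k on the last dim */ }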

PiperOrigin-RevId: 283623396
Change-Id: Ia705c72022452617ab1209532c6408c3cb399a9c
Author: Lei Zhang, 2019-12-03 14:30:41 -08:00 (committed by TensorFlower Gardener)
Parent: 8f46707460, Commit: 1c81d32696
11 changed files with 423 additions and 4 deletions

@@ -5768,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let verifier = [{ return Verify(*this); }];
}

def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {

@@ -1698,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//
static LogicalResult Verify(TopKV2Op op) {
if (!HasRankAtLeast(op.input(), 1))
return op.emitOpError(
"requires input operand to have at least 1 dimension");
if (!IsOfRankOrUnranked(op.k(), 0))
return op.emitOpError("requires k operand to be 0D tensor");
return success();
}
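// For illustration, an op that passes both checks above (ranked 1-D input,
// scalar k), matching the legalization tests in this change:
//
//   %0:2 = "tf.TopKV2"(%input, %k)
//       : (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)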
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

@@ -1658,3 +1658,19 @@ func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
%0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
return
}
// -----
func @testTopKV2WrongInputRank(%input: tensor<f32>, %k: tensor<i32>) {
// expected-error @+1 {{op requires input operand to have at least 1 dimension}}
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return
}
// -----
func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
// expected-error @+1 {{op requires k operand to be 0D tensor}}
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
return
}

@@ -841,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
return RankedTensorType::get(shape, ranked_ty.getElementType());
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
void SortOp::build(Builder* builder, OperationState& state,
ArrayRef<Value*> operands, int64_t dimension,
bool is_stable) {
state.addOperands(operands);
state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder->getBoolAttr(dimension));
SmallVector<Type, 2> element_types;
element_types.reserve(operands.size());
for (Value* operand : operands) element_types.push_back(operand->getType());
state.addTypes(builder->getTupleType(element_types));
state.addRegion();
}
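// A minimal usage sketch for this builder (names illustrative; mirrors the
// TopKV2 lowering later in this change). The result is a tuple of the operand
// types, so each sorted tensor is retrieved with xla_hlo.get_tuple_element:
//
//   auto sort_op = rewriter.create<xla_hlo::SortOp>(
//       loc, llvm::ArrayRef<Value *>{keys, indices},
//       /*dimension=*/1, /*is_stable=*/true);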
static LogicalResult Verify(SortOp op) {
Operation::operand_range operands = op.operands();
if (operands.empty()) return op.emitOpError("requires at least one input");
// TODO(antiagainst): verify partially dynamic shapes
if (llvm::all_of(operands, [](Value* operand) {
return operand->getType().cast<ShapedType>().hasRank();
})) {
ArrayRef<int64_t> input_shape =
(*operands.begin())->getType().cast<ShapedType>().getShape();
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
return operand->getType().cast<ShapedType>().getShape() !=
input_shape;
}))
return op.emitOpError("requires all inputs to have the same dimensions");
if (op.dimension().getSExtValue() >= input_shape.size())
return op.emitOpError(
"dimension attribute value must be less than input rank");
}
Block& block = op.comparator().front();
size_t num_operands = op.getOperation()->getNumOperands();
if (block.getNumArguments() != 2 * num_operands)
return op.emitOpError("comparator block should have ")
<< 2 * num_operands << " arguments";
for (auto indexed_operand : llvm::enumerate(operands)) {
int index = indexed_operand.index();
Type element_type =
indexed_operand.value()->getType().cast<ShapedType>().getElementType();
Type tensor_type = RankedTensorType::get({}, element_type);
for (int i : {2 * index, 2 * index + 1}) {
Type arg_type = block.getArgument(i)->getType();
if (arg_type != tensor_type)
return op.emitOpError("comparator block argument #")
<< i << " should be of type " << tensor_type << " but got "
<< arg_type;
}
}
return success();
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

@@ -868,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
let hasCustomHLOConverter = 1;
}
def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
let arguments = (ins
Variadic<HLO_Tensor>:$operands,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let results = (outs HLO_TensorOrTuple);
let regions = (region SizedRegion<1>:$comparator);
let builders = [OpBuilder<
"Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
"int64_t dimension, bool is_stable"
>];
// TODO(b/129422361): SortOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
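// A sketch of the generic textual form, as exercised by the ops.mlir tests in
// this change:
//
//   %0 = "xla_hlo.sort"(%keys, %values) ( {
//   ^bb0(%kl: tensor<f32>, %kr: tensor<f32>, %vl: tensor<i32>, %vr: tensor<i32>):
//     %pred = "xla_hlo.compare"(%kl, %kr) {comparison_direction = "GT"}
//         : (tensor<f32>, tensor<f32>) -> tensor<i1>
//     "xla_hlo.return"(%pred) : (tensor<i1>) -> ()
//   }) {dimension = 1 : i64, is_stable = true}
//      : (tensor<16x16xf32>, tensor<16x16xi32>)
//      -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>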
def HLO_ReverseOp: HLO_Op<"reverse",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
let arguments = (ins

@@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
}];
}
class BASE_HLO_SortOp {
string summary = "Sort operator";
string description = [{
Sorts the given `operands` at the given `dimension` with the given
`comparator`.
See https://www.tensorflow.org/xla/operation_semantics#sort.
}];
}
class BASE_HLO_ReverseOp {
string summary = "Reverse operator";

@@ -624,6 +624,18 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
return failure();
}
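// Exports xla_hlo.sort by lowering the comparator region to an
// xla::XlaComputation and emitting xla::Sort with the op's dimension and
// is_stable attributes.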
LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
xla::XlaComputation comparator;
if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
&comparator)))
return failure();
auto& value_map = *ctx.values;
value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator,
op.dimension().getSExtValue(), op.is_stable());
return success();
}
LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));

@@ -1934,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
}
//===----------------------------------------------------------------------===//
// tf.TopKV2 legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: topk_v2_non_const_k
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
// CHECK: tf.TopKV2
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
}
// CHECK-LABEL: topk_v2_unknown_input_last_dim
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// CHECK: tf.TopKV2
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
}
// CHECK-LABEL: topk_v2
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
// CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
// CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
// CHECK-NEXT: "xla_hlo.return"(%[[CMP]])
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
// CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
// CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
// CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-NEXT: return %[[VAL]], %[[IDX]]
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}

@@ -416,3 +416,98 @@ func @constants() -> () {
%3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
return
}
// -----
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: xla_hlo.sort
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_no_operands() {
// expected-error @+1 {{op requires at least one input}}
%0 = "xla_hlo.sort"() ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
%7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
return
}
// -----
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op dimension attribute value must be less than input rank}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block should have 4 arguments}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// -----
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}

@@ -620,3 +620,21 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
return %0 : tensor<4xi1>
}
// -----
// CHECK-LABEL: HloModule
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
%0 = "xla_hlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]

@@ -438,6 +438,38 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
return GetI64ElementsAttr(normalized_sizes, builder);
}
//===----------------------------------------------------------------------===//
// Sort op utilities.
//===----------------------------------------------------------------------===//
// Builds the region `body` for xla_hlo.sort's comparator: for each type in
// `element_types`, create two block arguments, one for lhs and one for rhs,
// and generate an xla_hlo.compare op to compare them with the given
// `direction`.
//
// Note that this currently only compares the first pair of block arguments.
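//
// For element_types = {f32, i32} and direction = "GT", the generated region
// is, schematically (cf. the legalize-tf tests in this change):
//
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>,
//        %li: tensor<i32>, %ri: tensor<i32>):
//     %cmp = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "GT"}
//     "xla_hlo.return"(%cmp)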
static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
StringRef direction, Region *body,
OpBuilder *builder) {
OpBuilder::InsertionGuard insertion_point_guard(*builder);
Block *block = builder->createBlock(body);
// Add two arguments for each element type.
for (Type element_type : element_types) {
TensorType tensor_type = RankedTensorType::get({}, element_type);
block->addArguments({tensor_type, tensor_type});
}
Location loc = body->getLoc();
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
Value *compare = builder->create<xla_hlo::CompareOp>(
loc, block->getArgument(0), block->getArgument(1),
/*broadcast_dimensions=*/nullptr, compare_direction);
builder->create<xla_hlo::ReturnOp>(loc, compare);
}
//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//
@@ -1873,6 +1905,101 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
}
};
// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
//
// tf.TopKV2 sorts along the last dimension of the input tensor and then
// returns the values and indices of the top K components. This is translated
// into a few ops in XLA HLO: first generating an integer sequence for the
// indices, then sorting both the original input tensor and the indices
// together, and finally slicing out the top K components.
//
// For example, for the following IR:
//
// %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
// (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "xla_hlo.sort"(%input, %1) ( {
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
// %arg3: tensor<i32>, %arg4: tensor<i32>):
// %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
// "xla_hlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
// start_indices dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "xla_hlo.slice"(%4) ...
class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
public:
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(TF::TopKV2Op op,
PatternRewriter &rewriter) const override {
// We can only match when the `k` operand is a constant scalar.
DenseIntElementsAttr k_attr;
if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure();
// The last dimension of the input tensor's shape must be statically known so
// that we can clamp the slices' end_indices.
TensorType input_type = op.input()->getType().cast<TensorType>();
if (!input_type.hasRank()) return matchFailure();
int64_t input_rank = input_type.getRank();
int64_t last_dim_index = input_rank - 1;
int64_t last_dim_size = input_type.getDimSize(last_dim_index);
if (last_dim_size == ShapedType::kDynamicSize) return matchFailure();
// Create an iota op for the indices.
auto i32_type = rewriter.getIntegerType(32);
Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
Value *iota_op = rewriter.create<xla_hlo::IotaOp>(
op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
// Create the sort op. It takes two inputs, one for the original input, the
// other for the indices.
auto sort_op = rewriter.create<xla_hlo::SortOp>(
op.getLoc(), llvm::ArrayRef<Value *>{op.input(), iota_op},
last_dim_index, /*is_stable=*/true);
BuildSortComparisonBody({input_type.getElementType(), i32_type},
/*direction=*/"GT", &sort_op.comparator(),
&rewriter);
// Get the sorted input and index tuple element.
auto tuple_first_element =
rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
auto tuple_second_element =
rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);
SmallVector<int64_t, 4> begin_indices(input_rank, 0);
auto end_indices = llvm::to_vector<4>(input_type.getShape());
end_indices.back() =
std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
SmallVector<int64_t, 4> strides(input_rank, 1);
// Get the slice for the top K elements.
Value *values = rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), tuple_first_element,
GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
Value *indices = rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), tuple_second_element,
GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
rewriter.replaceOp(op, {values, indices});
return matchSuccess();
}
};
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
@ -1892,10 +2019,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>, ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp, ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp,
ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp, ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp,
ConvertConv2DBackpropInputOp, ConvertConv2DBackpropFilterOp>( ConvertOneHotOp, ConvertConv2DBackpropInputOp,
op->getContext()); ConvertConv2DBackpropFilterOp>(op->getContext());
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>(); target.addLegalDialect<XlaHloDialect>();