Add legalization from tf.TopKV2 to XLA HLO ops
- Tightened tf.TopKV2 verification
- Defined xla_hlo.sort operation
- Added lowering from tf.TopKV2 to XLA HLO ops

PiperOrigin-RevId: 283623396
Change-Id: Ia705c72022452617ab1209532c6408c3cb399a9c
This commit is contained in:
parent
8f46707460
commit
1c81d32696
@ -5768,6 +5768,8 @@ If two elements are equal, the lower-index element appears first.
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TF_TransposeOp : TF_Op<"Transpose", [NoSideEffect]> {
|
||||
|
@ -1698,6 +1698,21 @@ static LogicalResult Verify(TensorListStackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TopKV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Checks tf.TopKV2 operand invariants: the input must have rank >= 1 (or be
// unranked) and `k` must be a scalar (or unranked) tensor.
static LogicalResult Verify(TopKV2Op op) {
  // Values/indices are selected along the input's last dimension, so at
  // least one dimension is required.
  if (!HasRankAtLeast(op.input(), 1)) {
    return op.emitOpError(
        "requires input operand to have at least 1 dimension");
  }

  // `k` is the number of top elements to return and must be 0-D.
  if (!IsOfRankOrUnranked(op.k(), 0)) {
    return op.emitOpError("requires k operand to be 0D tensor");
  }

  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1658,3 +1658,19 @@ func @testTernaryEinsum(%arg0: tensor<2x3xf32>){
|
||||
%0 = "tf.Einsum"(%arg0, %arg0, %arg0) {equation = "ab,cd,ef->"} : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTopKV2WrongInputRank(%input: tensor<f32>, %k: tensor<i32>) {
|
||||
// expected-error @+1 {{op requires input operand to have at least 1 dimension}}
|
||||
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<f32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTopKV2WrongKRank(%input: tensor<8xf32>, %k: tensor<5xi32>) {
|
||||
// expected-error @+1 {{op requires k operand to be 0D tensor}}
|
||||
%0:2 = "tf.TopKV2"(%input, %k) : (tensor<8xf32>, tensor<5xi32>) -> (tensor<*xf32>, tensor<*xi32>)
|
||||
return
|
||||
}
|
||||
|
@ -841,6 +841,70 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
|
||||
return RankedTensorType::get(shape, ranked_ty.getElementType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SortOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Builds a xla_hlo.sort op that sorts `operands` along `dimension`. The
// result type is a tuple wrapping one tensor per operand; the `comparator`
// region is created empty and must be populated by the caller.
void SortOp::build(Builder* builder, OperationState& state,
                   ArrayRef<Value*> operands, int64_t dimension,
                   bool is_stable) {
  state.addOperands(operands);
  state.addAttribute("dimension", builder->getI64IntegerAttr(dimension));
  // BUGFIX: the is_stable attribute was previously built from `dimension`
  // (builder->getBoolAttr(dimension)), silently setting is_stable = true for
  // any non-zero dimension. It must reflect the `is_stable` argument.
  state.addAttribute("is_stable", builder->getBoolAttr(is_stable));

  // Result type: a tuple of the operand types.
  SmallVector<Type, 2> element_types;
  element_types.reserve(operands.size());
  for (Value* operand : operands) element_types.push_back(operand->getType());
  state.addTypes(builder->getTupleType(element_types));

  // Reserve the comparator region (filled in by the caller).
  state.addRegion();
}
|
||||
|
||||
// Verifies xla_hlo.sort: there must be at least one operand; when all
// operands are ranked they must share one shape and `dimension` must index
// into it; and the comparator block must take one (lhs, rhs) pair of scalar
// arguments per operand, typed with that operand's element type.
static LogicalResult Verify(SortOp op) {
  Operation::operand_range operands = op.operands();
  if (operands.empty()) return op.emitOpError("requires at least one input");

  // Shape agreement is only checked when every operand is ranked.
  // TODO(antiagainst): verify partially dynamic shapes
  if (llvm::all_of(operands, [](Value* operand) {
        return operand->getType().cast<ShapedType>().hasRank();
      })) {
    // All operands must have exactly the first operand's shape.
    ArrayRef<int64_t> input_shape =
        (*operands.begin())->getType().cast<ShapedType>().getShape();

    if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
          return operand->getType().cast<ShapedType>().getShape() !=
                 input_shape;
        }))
      return op.emitOpError("requires all inputs to have the same dimensions");

    // NOTE(review): signed/unsigned comparison — a negative dimension (e.g.
    // the attribute's -1 default) converts to a large unsigned value and is
    // rejected here. Confirm that rejecting the default is intended.
    if (op.dimension().getSExtValue() >= input_shape.size())
      return op.emitOpError(
          "dimension attribute value must be less than input rank");
  }

  // The comparator receives two scalar arguments (lhs, rhs) per operand.
  Block& block = op.comparator().front();
  size_t num_operands = op.getOperation()->getNumOperands();
  if (block.getNumArguments() != 2 * num_operands)
    return op.emitOpError("comparator block should have ")
           << 2 * num_operands << " arguments";

  // Arguments 2*i and 2*i+1 must be rank-0 tensors of operand i's element
  // type.
  for (auto indexed_operand : llvm::enumerate(operands)) {
    int index = indexed_operand.index();
    Type element_type =
        indexed_operand.value()->getType().cast<ShapedType>().getElementType();
    Type tensor_type = RankedTensorType::get({}, element_type);
    for (int i : {2 * index, 2 * index + 1}) {
      Type arg_type = block.getArgument(i)->getType();
      if (arg_type != tensor_type)
        return op.emitOpError("comparator block argument #")
               << i << " should be of type " << tensor_type << " but got "
               << arg_type;
    }
  }

  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -868,6 +868,26 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
// xla_hlo.sort: sorts a variadic list of tensors along `dimension` using the
// `comparator` region; the result is the sorted operand(s).
def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    // NOTE(review): presumably -1 follows the XLA convention of meaning the
    // last dimension — confirm against the HLO exporter.
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );

  let results = (outs HLO_TensorOrTuple);

  // The comparator region defines the ordering between element pairs.
  let regions = (region SizedRegion<1>:$comparator);

  let builders = [OpBuilder<
    "Builder *builder, OperationState &state, ArrayRef<Value *> operands, "
    "int64_t dimension, bool is_stable"
  >];

  // TODO(b/129422361): SortOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
|
||||
|
||||
def HLO_ReverseOp: HLO_Op<"reverse",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
|
||||
let arguments = (ins
|
||||
|
@ -832,6 +832,17 @@ class BASE_HLO_SelectAndScatterOp {
|
||||
}];
|
||||
}
|
||||
|
||||
// Base class providing the shared summary/description strings for the sort
// op definition.
class BASE_HLO_SortOp {
  string summary = "Sort operator";

  string description = [{
    Sorts the given `operands` at the given `dimension` with the given
    `comparator`.

    See https://www.tensorflow.org/xla/operation_semantics#sort.
  }];
}
|
||||
|
||||
class BASE_HLO_ReverseOp {
|
||||
string summary = "Reverse operator";
|
||||
|
||||
|
@ -624,6 +624,18 @@ LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Exports xla_hlo.sort to an XLA Sort instruction: the comparator region is
// first lowered into a standalone XlaComputation, then the sort is emitted
// with the op's dimension and stability attributes.
LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
  // The comparator must be converted before the sort instruction can refer
  // to it.
  xla::XlaComputation comparator_computation;
  if (failed(ctx.converter->LowerRegionAsComputation(
          &op.comparator(), &comparator_computation))) {
    return failure();
  }

  auto& value_map = *ctx.values;
  value_map[op] =
      xla::Sort(GetTuple(op.operands(), ctx), comparator_computation,
                op.dimension().getSExtValue(), op.is_stable());
  return success();
}
|
||||
|
||||
LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
|
||||
|
@ -1934,3 +1934,42 @@ func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf
|
||||
// CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
|
||||
return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf.TopKV2 legalization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: topk_v2_non_const_k
|
||||
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
|
||||
// CHECK: tf.TopKV2
|
||||
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
|
||||
return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: topk_v2_unknown_input_last_dim
|
||||
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
|
||||
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: tf.TopKV2
|
||||
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
|
||||
return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: topk_v2
|
||||
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
|
||||
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
|
||||
%k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||
|
||||
// CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64}
|
||||
// CHECK-NEXT: %[[SORT:.*]] = "xla_hlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
|
||||
// CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
|
||||
// CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"}
|
||||
// CHECK-NEXT: "xla_hlo.return"(%[[CMP]])
|
||||
// CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
// CHECK-NEXT: %[[TUPL0:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[TUPL1:.*]] = "xla_hlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[VAL:.*]] = "xla_hlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-NEXT: %[[IDX:.*]] = "xla_hlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-NEXT: return %[[VAL]], %[[IDX]]
|
||||
%0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
|
||||
return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
|
||||
}
|
||||
|
@ -416,3 +416,98 @@ func @constants() -> () {
|
||||
%3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor<i32>} : () -> (tensor<*xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// CHECK: xla_hlo.sort
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_no_operands() {
|
||||
// expected-error @+1 {{op requires at least one input}}
|
||||
%0 = "xla_hlo.sort"() ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op dimension attribute value must be less than input rank}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op comparator block should have 4 arguments}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
@ -620,3 +620,21 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
|
||||
return %0 : tensor<4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||
%0 = "xla_hlo.sort"(%input0, %input1) ( {
|
||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] {
|
||||
// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT
|
||||
|
||||
// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) {
|
||||
// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]]
|
||||
|
@ -438,6 +438,38 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
|
||||
return GetI64ElementsAttr(normalized_sizes, builder);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sort op utilities.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Builds the region `body` for xla_hlo.sort's comparator: for each type in
|
||||
// `element_types`, create two block arguments, one for lhs and one for rhs, and
|
||||
// generates xla_hlo.compare op to compare them with the given `direction`.
|
||||
//
|
||||
// Note that this right now only does comparsion on the first pair of block
|
||||
// arguments.
|
||||
static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
|
||||
StringRef direction, Region *body,
|
||||
OpBuilder *builder) {
|
||||
OpBuilder::InsertionGuard insertion_point_gurad(*builder);
|
||||
|
||||
Block *block = builder->createBlock(body);
|
||||
// Add two arguments for each element type.
|
||||
for (Type element_type : element_types) {
|
||||
TensorType tensor_type = RankedTensorType::get({}, element_type);
|
||||
block->addArguments({tensor_type, tensor_type});
|
||||
}
|
||||
|
||||
Location loc = body->getLoc();
|
||||
StringAttr compare_direction =
|
||||
StringAttr::get(direction, builder->getContext());
|
||||
Value *compare = builder->create<xla_hlo::CompareOp>(
|
||||
loc, block->getArgument(0), block->getArgument(1),
|
||||
/*broadcast_dimensions=*/nullptr, compare_direction);
|
||||
|
||||
builder->create<xla_hlo::ReturnOp>(loc, compare);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op converters.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1873,6 +1905,101 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
//
// tf.TopKV2 sorts along last dimension of the input tensor and then returns
// the top K components' values and indices. This is translated into a few
// ops in XLA HLO: first generating an integer sequence for the indices,
// then sort both the original input tensor and the indices together, and
// at last slice out the top K components.
//
// For example, for the following IR:
//
// %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
//                                 (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "xla_hlo.sort"(%input, %1) ( {
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
//      %arg3: tensor<i32>, %arg4: tensor<i32>):
//   %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
//   "xla_hlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
//                           start_indices = dense<0> : tensor<2xi64>,
//                           strides = dense<1> : tensor<2xi64>} :
//                           (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "xla_hlo.slice"(%4) ...
class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::TopKV2Op op,
                                     PatternRewriter &rewriter) const override {
    // We can only match when the `k` operand is a constant scalar.
    DenseIntElementsAttr k_attr;
    if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure();

    // The last dimension of the input tensor's shape should be known so we can
    // have clamped end_indices for slices.
    TensorType input_type = op.input()->getType().cast<TensorType>();
    if (!input_type.hasRank()) return matchFailure();
    int64_t input_rank = input_type.getRank();
    int64_t last_dim_index = input_rank - 1;
    int64_t last_dim_size = input_type.getDimSize(last_dim_index);
    if (last_dim_size == ShapedType::kDynamicSize) return matchFailure();

    // Create an Iota op for indices, iterating along the last dimension.
    auto i32_type = rewriter.getIntegerType(32);
    Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
    Value *iota_op = rewriter.create<xla_hlo::IotaOp>(
        op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));

    // Create the sort op. It takes two inputs, one for the original input, the
    // other for the indices. The comparator compares values with "GT" so the
    // largest elements come first.
    auto sort_op = rewriter.create<xla_hlo::SortOp>(
        op.getLoc(), llvm::ArrayRef<Value *>{op.input(), iota_op},
        last_dim_index, /*is_stable=*/true);
    BuildSortComparisonBody({input_type.getElementType(), i32_type},
                            /*direction=*/"GT", &sort_op.comparator(),
                            &rewriter);

    // Get the sorted input and index tuple element.
    auto tuple_first_element =
        rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
    auto tuple_second_element =
        rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);

    // Slice bounds: full extent in every dimension except the last, which is
    // clamped to min(k, last_dim_size).
    SmallVector<int64_t, 4> begin_indices(input_rank, 0);
    auto end_indices = llvm::to_vector<4>(input_type.getShape());
    end_indices.back() =
        std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
    SmallVector<int64_t, 4> strides(input_rank, 1);

    // Get the slice for the top K elements.

    Value *values = rewriter.create<xla_hlo::SliceOp>(
        op.getLoc(), tuple_first_element,
        GetI64ElementsAttr(begin_indices, &rewriter),
        GetI64ElementsAttr(end_indices, &rewriter),
        GetI64ElementsAttr(strides, &rewriter));

    Value *indices = rewriter.create<xla_hlo::SliceOp>(
        op.getLoc(), tuple_second_element,
        GetI64ElementsAttr(begin_indices, &rewriter),
        GetI64ElementsAttr(end_indices, &rewriter),
        GetI64ElementsAttr(strides, &rewriter));

    rewriter.replaceOp(op, {values, indices});
    return matchSuccess();
  }
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
|
||||
|
||||
LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
|
||||
@ -1892,10 +2019,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
|
||||
ConvertSigmoidOp, ConvertSizeOp, ConvertMaxPoolOp, ConvertRangeOp,
|
||||
ConvertSigmoidOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp,
|
||||
ConvertStridedSliceOp, ConvertMeanOp, ConvertSumOp, ConvertMaxOp,
|
||||
ConvertTileOp, ConvertMaxPoolGradOp, ConvertOneHotOp,
|
||||
ConvertConv2DBackpropInputOp, ConvertConv2DBackpropFilterOp>(
|
||||
op->getContext());
|
||||
ConvertStridedSliceOp, ConvertTopKV2Op, ConvertMeanOp,
|
||||
ConvertSumOp, ConvertMaxOp, ConvertTileOp, ConvertMaxPoolGradOp,
|
||||
ConvertOneHotOp, ConvertConv2DBackpropInputOp,
|
||||
ConvertConv2DBackpropFilterOp>(op->getContext());
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<XlaHloDialect>();
|
||||
|
Loading…
Reference in New Issue
Block a user