From 46379cc745e109464cd34c7422776d92b64d8ea4 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 11 Jan 2021 15:15:14 -0800 Subject: [PATCH] This PR adds a `tf.SizeOp` lowering pattern so that TFLite users can use `tf.size` without `SELECT_TF_OPS`. Because the number of elements is the product of each dimension size, it can be lowered to `Prod(Shape, reduction_indices=0, keep_dims=false)`. The pattern is safe to apply to any kind of Tensor. For a ranked tensor with a static shape, the operations can be folded to a constant. For a ranked tensor with a dynamic shape, or an unranked tensor, the pattern is applied as written. --- .../mlir/tensorflow/tests/lower_tf.mlir | 20 ++++++ .../mlir/tensorflow/transforms/lower_tf.td | 20 +++++- .../compiler/mlir/xla/tests/legalize-tf.mlir | 70 ------------------- .../mlir/xla/transforms/legalize_tf.cc | 55 +-------------- 4 files changed, 39 insertions(+), 126 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index d0fc9b1a351..605a71f1476 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -946,3 +946,23 @@ func @xlogy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return %[[RESULT]] return %0 : tensor<*xf32> } + +// CHECK-LABEL: size_to_prod_shape_i32 +func @size_to_prod_shape_i32(%arg0 : tensor<1x?x2x3xf32>) -> tensor { + %0 = "tf.Size"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor + return %0 : tensor + // CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<4xi32> + // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi32>, tensor) -> tensor + // CHECK: return %[[PROD]] +} + +// CHECK-LABEL: size_to_prod_shape_i64 +func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor { + %0 = "tf.Size"(%arg0) : 
(tensor<1x?x2x3xf32>) -> tensor + return %0 : tensor + // CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<4xi64> + // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor) -> tensor + // CHECK: return %[[PROD]] +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 1dbae15ea68..4b7f5ce82e0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -37,6 +37,12 @@ class GetI64ScalarElementsAttr : def TrueBoolAttr : AttrConstraint>; +def CreateTFShapeOp : NativeCodeCall< + "$_builder.create($0.getLoc(), $1, $2)">; + +def IsI32 : NativeCodeCall< + "$_builder.getBoolAttr(getElementTypeOrSelf($0.getType()).isInteger(32))">; + //===----------------------------------------------------------------------===// // BiasAddGrad op patterns. //===----------------------------------------------------------------------===// @@ -252,6 +258,18 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs), (TF_ConstOp (GetScalarOfType<-2> $rhs))))>; +//===----------------------------------------------------------------------===// +// Size op patterns. +//===----------------------------------------------------------------------===// + +// Size(x) = Prod(Shape(x), reduction_indices=0, keep_dims=false) +def : Pat<(TF_SizeOp:$res $arg), + (TF_ProdOp + (CreateTFShapeOp $res, $arg, (IsI32 $res)), + /*reduction_indices=*/(TF_ConstOp (GetScalarOfType<0> $res)), + /*keep_dims=*/ConstBoolAttrFalse)>; + + //===----------------------------------------------------------------------===// // TanhGrad op patterns. //===----------------------------------------------------------------------===// @@ -279,8 +297,6 @@ def LowerFakeQuantWithMinMaxArgs : // ZerosLike op patterns. 
//===----------------------------------------------------------------------===// -def CreateTFShapeOp : NativeCodeCall< - "$_builder.create($0.getLoc(), $1, $2)">; class LowerInitializationOp : Pat<(FromOp:$src_op diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 9055385d52d..b281df37d1f 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -3816,76 +3816,6 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { return %result : tensor<10xf32> } -//===----------------------------------------------------------------------===// -// tf.Size legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @size_scalar_i32 -func @size_scalar_i32(%input: tensor) -> (tensor) { - // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> - // CHECK-SAME: tensor - // CHECK: %[[CAST:.*]] = tensor.cast %[[CONST]] : tensor to tensor - %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor) -> tensor - // CHECK: return %[[CAST]] - return %size : tensor -} - -// CHECK-LABEL: @size_scalar_i64 -func @size_scalar_i64(%input: tensor) -> (tensor) { - // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> - // CHECK-SAME: tensor - // CHECK: %[[CAST:.*]] = tensor.cast %[[CONST]] : tensor to tensor - %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor) -> tensor - // CHECK: return %[[CAST]] - return %size : tensor -} - -// CHECK-LABEL: @size_rank_one_i64 -// CHECK-SAME: (%[[INPUT:.*]]: tensor) -func @size_rank_one_i64(%input: tensor) -> (tensor) { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<1> - // CHECK-SAME: tensor - - // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) - // CHECK-SAME: dimension = 0 - // CHECK-SAME: tensor - - // CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor) -> 
tensor - // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply %[[INIT]], %[[CAST_DIM_0]] - - %size = "tf.Size"(%input) : (tensor) -> tensor - // CHECK: return %[[RESULT]] - return %size : tensor -} - -// CHECK-LABEL: @size_ranked -// CHECK-SAME: (%[[INPUT:.*]]: tensor<2x?x8xf32>) -func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { - // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> - // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) - // CHECK-SAME: dimension = 0 - // CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor) -> tensor - // CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[CAST_DIM_0]] - // CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) - // CHECK-SAME: dimension = 1 - // CHECK: %[[CAST_DIM_1:.*]] = "mhlo.convert"(%[[DIM_1]]) : (tensor) -> tensor - // CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[CAST_DIM_1]] - // CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) - // CHECK-SAME: dimension = 2 - // CHECK: %[[CAST_DIM_2:.*]] = "mhlo.convert"(%[[DIM_2]]) : (tensor) -> tensor - // CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[CAST_DIM_2]] - %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor - // CHECK: return %[[MUL_2]] - return %size : tensor -} - -// CHECK-LABEL: @size_unranked -func @size_unranked(%input: tensor<*xf32>) -> (tensor) { - // CHECK: tf.Size - %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor - return %size : tensor -} - //===----------------------------------------------------------------------===// // tf.Split legalization //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index c16290662d2..0e998d4e734 100644 --- 
a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -2751,59 +2751,6 @@ class ConvertSoftmaxOp : public OpRewritePattern { } }; -// Converts Size to HLO ops, computing the size of a ranked input tensor. -// TODO(b/145253252): Update this to not require ranked input tensor shapes. -// -// The main logic of this pattern is to calculate the size by multiplying every -// dimension of the input tensor's shape together. -// -// For example, the following source IR: -// -// %size = "tf.Size"(%input) : (tensor<2x?x8xf32>) -> tensor -// -// will be converted into: -// -// %const = mhlo.constant dense<1> : tensor -// %dim_0 = "mhlo.get_dimension_size"(%input) {dimension = 0 : i64} : -// (tensor<2x?x8xf32>) -> tensor -// %prod_0 = mhlo.multiply %const, %dim_0 : tensor -// %dim_1 = "mhlo.get_dimension_size"(%input) {dimension = 1 : i64} : -// (tensor<2x?x8xf32>) -> tensor -// %prod_1 = mhlo.multiply %prod_0, %dim_1 : tensor -// %dim_2 = "mhlo.get_dimension_size"(%input) {dimension = 2 : i64} : -// (tensor<2x?x8xf32>) -> tensor -// %size = mhlo.multiply %prod_1, %dim_2 : tensor -class ConvertSizeOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SizeOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - auto input_ty = input.getType().dyn_cast(); - if (!input_ty) return failure(); - - const int64_t rank = input_ty.getRank(); - auto result_ty = op.getResult().getType(); - auto element_ty = result_ty.cast().getElementType(); - Value size = GetScalarConstOfType(element_ty, op.getLoc(), 1, &rewriter); - for (int64_t i = 0; i < rank; ++i) { - auto i32_ty = rewriter.getIntegerType(32); - auto size_ty = RankedTensorType::get({}, i32_ty); - auto dim_index = rewriter.getI64IntegerAttr(i); - Value dim = rewriter.create(op.getLoc(), size_ty, - input, dim_index); - dim = rewriter.create(op.getLoc(), result_ty, dim); - size 
= rewriter.create( - op.getLoc(), size, dim, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); - } - rewriter.replaceOp(op, size); - - return success(); - } -}; - static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, Value *out_lhs, Value *out_rhs, PatternRewriter *rewriter) { @@ -6228,7 +6175,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op, - ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, + ConvertSigmoidOp, ConvertShapeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,