From 2f960b4bc79af52caabcb664d615b5b3e94db10f Mon Sep 17 00:00:00 2001
From: Jaesung Chung
Date: Tue, 3 Nov 2020 19:11:27 -0800
Subject: [PATCH] Convert the TF BroadcastTo op to the corresponding TFL
 BroadcastTo op when inputs with 5 or more dimensions are given

Low-dimensional inputs, whose rank is at most 4, are still lowered
through the Mul op as before, so as not to break acceleration support.

PiperOrigin-RevId: 340567733
Change-Id: I764aa75f2fdd701478f78904cde72d8f5b97de5d
---
 tensorflow/compiler/mlir/lite/BUILD           |  1 +
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td   | 51 ++++++++++++++
 .../compiler/mlir/lite/tests/legalize-tf.mlir | 18 +++++
 tensorflow/compiler/mlir/lite/tests/ops.mlir  | 18 +++++
 .../compiler/mlir/lite/tests/prepare-tf.mlir  | 68 +++++++++++++++++++
 .../mlir/lite/transforms/legalize_patterns.td |  3 +
 .../mlir/lite/transforms/prepare_tf.cc        |  4 +-
 7 files changed, 161 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 9667b58613a..76fde446b15 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -265,6 +265,7 @@ cc_library(
     deps = [
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:mangling_util",
+        "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:status",
         "//tensorflow/stream_executor/lib",
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 539c28133b9..84a46c3be30 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -4407,4 +4407,55 @@ def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
   let regions = (region SizedRegion<1>:$body);
 }
 
+def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
+    PredOpTrait<"input and output must have same element type",
+      TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    TFL_OperandHasRankAtMost<0, 8>,
+    TFL_OperandHasRank<1, 1>,
+    PredOpTrait<"output dimension count must be at most 8",
+      Or<[TFL_OperandIsUnrankedPred<1>,
+          TFL_OperandDimIsAtMost<1, 0, 8>]>>,
+    NoSideEffect]> {
+  let summary = "Broadcast an array for a compatible shape.";
+
+  let description = [{
+Broadcasting is the process of making arrays have compatible shapes
+for arithmetic operations. Two shapes are compatible if, for each
+dimension pair, they are either equal or one of them is one. When
+broadcasting a Tensor to a shape, it starts with the trailing dimensions
+and works its way forward.
+
+For example,
+
+>>> x = tf.constant([1, 2, 3])
+>>> y = tf.broadcast_to(x, [3, 3])
+>>> print(y)
+tf.Tensor(
+    [[1 2 3]
+     [1 2 3]
+     [1 2 3]], shape=(3, 3), dtype=int32)
+
+In the above example, the input Tensor with the shape of `[3]`
+is broadcast to the output Tensor with the shape of `[3, 3]`.
+
+When doing broadcasted operations such as multiplying a tensor
+by a scalar, broadcasting (usually) confers some time or space
+benefit, as the broadcasted tensor is never materialized.
+
+However, `broadcast_to` does not carry with it any such benefits.
+The newly created tensor takes the full memory of the broadcasted
+shape. (In a graph context, `broadcast_to` might be fused into a
+subsequent operation and then optimized away, however.)
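+
+For example, broadcasting a 5-dimensional input:
+
+>>> x = tf.ones([1, 2, 1, 3, 1])
+>>> y = tf.broadcast_to(x, [4, 2, 5, 3, 6])
+>>> y.shape
+TensorShape([4, 2, 5, 3, 6])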
+  }];
+
+  let arguments = (ins
+    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
+    TFL_I32OrI64Tensor:$shape
+  );
+
+  let results = (outs
+    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
+  );
+}
+
 #endif // TFL_OPS
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 97a496a0b89..b712278fb17 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1520,6 +1520,24 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
 // CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
 // CHECK: }
 
+func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+  return %0: tensor<3x3xf32>
+
+// CHECK-LABEL: broadcast_to_f32
+// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+// CHECK: return [[BCT]] : tensor<3x3xf32>
+}
+
+func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
+  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
+  return %0: tensor<3x3xi32>
+
+// CHECK-LABEL: broadcast_to_i32
+// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
+// CHECK: return [[BCT]] : tensor<3x3xi32>
+}
+
 func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 341dbe3359e..3a98f6db0c4 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -2419,3 +2419,21 @@ func @invalid_two_dynamic_dims_on_reshape(%arg0: tensor<3x4xi32>, %arg1: tensor<?x?xi32>) -> tensor<1x3x4xi32> {
   %0 = "tfl.reshape"(%arg0, %arg1) : (tensor<3x4xi32>, tensor<?x?xi32>) -> tensor<1x3x4xi32>
   return %0 : tensor<1x3x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
+func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
+^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
+  // CHECK: "tfl.broadcast_to"(%arg0, %arg1)
+  %0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?x?x?xf32>
+}
+
+// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
+func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
+^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
+  // CHECK: "tfl.broadcast_to"(%arg0, %arg1)
+  %0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?x?x?xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 88de48cf1f9..88582bf17a2 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -571,6 +571,73 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
 // CHECK: return %[[RES]]
 }
 
+func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+  return %0: tensor<3x3xf32>
+
+// CHECK-LABEL: broadcast_to_f32_low_dim
+// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
+// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
+// CHECK: return [[MUL]] : tensor<3x3xf32>
+}
+
+func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
+  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
+  return %0: tensor<3x3xi32>
+
+// CHECK-LABEL: broadcast_to_i32_low_dim
+// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
+// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
+// CHECK: return [[MUL]] : tensor<3x3xi32>
+}
+
+func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
+  return %0: tensor<3x3xf32>
+
+// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
+// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
+// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
+// CHECK: return [[MUL]] : tensor<3x3xf32>
+}
+
+func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
+  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
+  return %0: tensor<*xi32>
+
+// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
+// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
+// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
+// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
+// CHECK: return [[MUL]] : tensor<*xi32>
+}
+
+func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
+  return %0: tensor<7x8x1x2x3x4x5x6xf32>
+
+// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
+// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
+// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
+}
+
+func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
+  return %0: tensor<*xf32>
+
+// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
+// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
+// CHECK: return [[BCT]] : tensor<*xf32>
+}
+
+func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
+  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
+  return %0: tensor<*xf32>
+
+// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
+// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
+}
+
 // CHECK-LABEL: xla_conv
 func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
   %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1")
@@ -647,3 +714,4 @@ func @DoNotConvertConv2DWhenFilterTypeDimIsNotDecided(%arg0 : tensor
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+def LegalizeBroadcastTo : Pat<(TF_BroadcastToOp $input, $dim),
+                              (TFL_BroadcastToOp $input, $dim)>;
+
 def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
 
 def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 3fb5c2cc6f7..5671bfa547d 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -716,9 +716,9 @@ struct ConvertTFBroadcastTo : public RewritePattern {
 
     // Allow lowering when low dimension inputs are given and its type is F32 or
     // I32.
-    if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
+    if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
           (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
-           shape_type.getDimSize(0) <= 5)))
+           shape_type.getDimSize(0) <= 4)))
       return failure();
 
     if (!(element_type.isa<BFloat16Type, Float32Type>() ||
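
The sketch below is not part of the patch; it shows one way to exercise both lowering paths end to end. It assumes a TensorFlow 2.x Python environment, and the helper name `convert_broadcast_model` and the chosen shapes are illustrative only.

import tensorflow as tf

def convert_broadcast_model(rank):
    """Builds a tf.BroadcastTo of the given output rank and converts it."""
    input_shape = [1] * (rank - 1) + [3]    # e.g. rank 5 -> [1, 1, 1, 1, 3]
    target_shape = [2] * (rank - 1) + [3]   # broadcast every unit dim to 2

    @tf.function(input_signature=[tf.TensorSpec(input_shape, tf.float32)])
    def model(x):
        return tf.broadcast_to(x, target_shape)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [model.get_concrete_function()])
    return converter.convert()  # returns the serialized .tflite flatbuffer

# Rank <= 4: ConvertTFBroadcastTo in prepare_tf.cc rewrites tf.BroadcastTo
# into a Mul with a constant of ones, so accelerated Mul kernels keep working.
low_dim_model = convert_broadcast_model(4)

# Rank >= 5: the rewrite above bails out (rank > 4), and the new
# LegalizeBroadcastTo pattern lowers the op to tfl.broadcast_to instead.
high_dim_model = convert_broadcast_model(5)

With the threshold change above, the rank-4 model keeps the Mul-based lowering, where a broadcasted Mul remains friendly to accelerated kernels, while the rank-5 model is legalized to the new tfl.broadcast_to op; the resulting flatbuffers can be inspected with a model visualizer to confirm which operator was emitted.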