Convert the TF BroadcastTo op to the corresponding TFL BroadcastTo op when inputs with 5 or more dimensions are given.
Low-dimension inputs, whose rank is at most 4, are still lowered to the Mul op as before, so as not to break acceleration support.

PiperOrigin-RevId: 340567733
Change-Id: I764aa75f2fdd701478f78904cde72d8f5b97de5d
parent 4565291a98
commit 2f960b4bc7
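A rough sketch of the resulting converter behavior in MLIR (the function names here are illustrative, not from the change; the behavior mirrors the tests added below):

// Rank <= 4: prepare-tf still lowers tf.BroadcastTo to a splat-of-ones
// constant multiplied with the input via tf.Mul, which delegates can accelerate.
func @low_rank_example(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  return %0 : tensor<3x3xf32>
}

// Rank >= 5: prepare-tf now leaves tf.BroadcastTo untouched, and the new
// LegalizeBroadcastTo pattern rewrites it to "tfl.broadcast_to"(%arg0, %arg1).
func @high_rank_example(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
  return %0 : tensor<7x8x1x2x3x4x5x6xf32>
}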
@@ -265,6 +265,7 @@ cc_library(
    deps = [
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:status",
        "//tensorflow/stream_executor/lib",
@@ -4407,4 +4407,55 @@ def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
  let regions = (region SizedRegion<1>:$body);
}

def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankAtMost<0, 8>,
    TFL_OperandHasRank<1, 1>,
    PredOpTrait<"output dimension count must be at most 8",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 8>]>>,
    NoSideEffect]> {
  let summary = "Broadcast an array for a compatible shape.";

  let description = [{
Broadcasting is the process of making arrays to have compatible shapes
for arithmetic operations. Two shapes are compatible if for each
dimension pair they are either equal or one of them is one. When trying
to broadcast a Tensor to a shape, it starts with the trailing dimensions,
and works its way forward.

For example,

>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
[[1 2 3]
 [1 2 3]
 [1 2 3]], shape=(3, 3), dtype=int32)

In the above example, the input Tensor with the shape of `[1, 3]`
is broadcasted to output Tensor with shape of `[3, 3]`.

When doing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or space
benefit, as the broadcasted tensor is never materialized.

However, `broadcast_to` does not carry with it any such benefits.
The newly-created tensor takes the full memory of the broadcasted
shape. (In a graph context, `broadcast_to` might be fused to
subsequent operation and then be optimized away, however.)
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
    TFL_I32OrI64Tensor:$shape
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
  );
}

#endif // TFL_OPS
@@ -1520,6 +1520,24 @@ func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
// CHECK: }

func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  return %0: tensor<3x3xf32>

// CHECK-LABEL: broadcast_to_f32
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
// CHECK: return [[BCT]] : tensor<3x3xf32>
}

func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
  return %0: tensor<3x3xi32>

// CHECK-LABEL: broadcast_to_i32
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
// CHECK: return [[BCT]] : tensor<3x3xi32>
}

func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
  %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} :
    (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32>
@@ -2419,3 +2419,21 @@ func @invalid_two_dynamic_dims_on_reshape(%arg0: tensor<3x4xi32>, %arg1: tensor<
  %0 = "tfl.reshape"(%arg0, %arg1) : (tensor<3x4xi32>, tensor<?x?x4xi32>) -> tensor<1x3x4xi32>
  return %0 : tensor<1x3x4xi32>
}

// -----

// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
  // CHECK: "tfl.broadcast_to"(%arg0, %arg1)
  %0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}

// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
  // CHECK: "tfl.broadcast_to"(%arg0, %arg1)
  %0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}
@@ -571,6 +571,73 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
  // CHECK: return %[[RES]]
}

func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  return %0: tensor<3x3xf32>

// CHECK-LABEL: broadcast_to_f32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}

func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
  return %0: tensor<3x3xi32>

// CHECK-LABEL: broadcast_to_i32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}

func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
  return %0: tensor<3x3xf32>

// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}

func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
  %0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>

// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
// CHECK: return [[MUL]] : tensor<*xi32>
}

func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
  return %0: tensor<7x8x1x2x3x4x5x6xf32>

// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
}

func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
  return %0: tensor<*xf32>

// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
// CHECK: return [[BCT]] : tensor<*xf32>
}

func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
  %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
  return %0: tensor<*xf32>

// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
}

// CHECK-LABEL: xla_conv
func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
  %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1")
@@ -647,3 +714,4 @@ func @DoNotConvertConv2DWhenFilterTypeDimIsNotDecided(%arg0 : tensor<?x?x?x96xf3
}

}
@@ -116,6 +116,9 @@ def LegalizeArgMax : Pat<(TF_ArgMaxOp $input, $dim),
def LegalizeArgMin : Pat<(TF_ArgMinOp $input, $dim),
                         (TFL_ArgMinOp $input, $dim)>;

def LegalizeBroadcastTo : Pat<(TF_BroadcastToOp $input, $dim),
                              (TFL_BroadcastToOp $input, $dim)>;

def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;

def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
@@ -716,9 +716,9 @@ struct ConvertTFBroadcastTo : public RewritePattern {

    // Allow lowering when low dimension inputs are given and its type is F32 or
    // I32.
-   if (!((output_type.hasRank() && output_type.getRank() <= 5) ||
+   if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
          (shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
-          shape_type.getDimSize(0) <= 5)))
+          shape_type.getDimSize(0) <= 4)))
      return failure();

    if (!(element_type.isa<BFloat16Type, Float32Type>() ||