Add new pattern matching rules for TFL Broadcastable ops, which have high dimenion inputs.

- Low dimensions <= 4, for example, should avoid adding BroadcastTo op as usual, which means we need to lower broadcast-able ops as before without BroadcastTo op and lower the BroadcastTo op to the hardware accelerator supported ops as well.

- This explicit BroadcastTo op needs to be inserted only when a higher dimension is needed, which will unlock the new opportunity.

- There are the broadcast-able 20 TFLite ops (about 15 % of TFLite op set) as the followings:

  Comparison: LessEqual, GreaterEqual, NotEqual, Greater, Less, Equal (up to four dim.)
  Activation: PRelu (up to four dim.)
  Arithmetic: Add, Mul, Div, Sub, FloorDiv, FloorMod, Pow, Maximum, Minimum,
                    SquaredDifference (up to four or five dim.)
  Dimension: SelectV2  (up to four dim.), BroadcastTo (supported via lowering)

PiperOrigin-RevId: 340575786
Change-Id: I007b19487512560e1042e99321b0a37e3123c0f4
This commit is contained in:
Jaesung Chung 2020-11-03 20:40:50 -08:00 committed by TensorFlower Gardener
parent 83ab6303b0
commit 247746808e
6 changed files with 407 additions and 48 deletions
tensorflow
compiler/mlir/lite
lite/testing/op_tests

View File

@ -148,18 +148,10 @@ bool IsI64Type(Type element_type) {
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type)) {
IsQUI8Type(element_type) || IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
@ -211,20 +203,13 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
/*max_bcast_rank=*/4);
}
// Allows F32 output when the operands have valid shapes, which are
// broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
// Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
// are broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type) ||
element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);

View File

@ -25,13 +25,6 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: te
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
}
// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32>
@ -1568,7 +1561,11 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tens
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
return %0 : tensor<1x1x1x2x3x4xf32>
// CHECK-LABEL: select_v2_with_6d_broadcasting
// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
}
// -----
@ -1578,7 +1575,9 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
return %0 : tensor<1x1x1x1x8x16xf32>
// CHECK-LABEL: maximum_with_6d_broadcasting
// CHECK: "tf.Maximum"(%arg0, %arg1)
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.maximum"(%arg0, [[BCT]])
}
// -----
@ -1587,7 +1586,171 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
return %0 : tensor<1x1x1x3x4xi32>
// CHECK-LABEL: add_with_int32_5d_inputs
// CHECK: "tf.Add"(%arg0, %arg1)
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.add [[BCT]], [[BCT_0]]
}
// CHECK-LABEL: testAddWithBroadcastToOps
func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testSubWithBroadcastToOps
func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMulWithBroadcastToOps
func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testDivWithBroadcastToOps
func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testFloorDivWithBroadcastToOps
func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testFloorModWithBroadcastToOps
func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testPowWithBroadcastToOps
func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMaximumWithBroadcastToOps
func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMinimumWithBroadcastToOps
func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testSelectV2WithBroadcastToOps
func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
// CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testLessEqualWithBroadcastToOps
func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testGreaterEqualWithBroadcastToOps
func @testGreaterEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.greater_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testEqualWithBroadcastToOps
func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testNotEqualWithBroadcastToOps
func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testLessWithBroadcastToOps
func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testGreaterWithBroadcastToOps
func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {

View File

@ -267,7 +267,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeBiasAdd : Pat<
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
(TFL_AddOp $l, $r, TFL_AF_None)>;
(TF_AddV2Op $l, $r)>;
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),

View File

@ -636,11 +636,155 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
}
};
void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns;
auto* context = &getContext();
auto func = getFunction();
// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
// to make binary broadcast-able op conversion always successful and does not
// require flex delegate.
template <typename SourceOp>
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<SourceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SourceOp src_op,
PatternRewriter& rewriter) const override {
Operation* op = static_cast<Operation*>(src_op);
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
// Should have static shapes to calculate the broadcasted shape.
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
!rhs.getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
if (lhs_shape == rhs_shape) {
return failure();
}
// Calculate the broadcasted shape.
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
result_shape)) {
return failure();
}
RankedTensorType result_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
// Create a const op, that stores the above broadcasted shape.
auto new_shape_attr = mlir::DenseIntElementsAttr::get(
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
result_shape);
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
// Apply BroadcastTo ops to each input.
auto broadcast_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(lhs.getType()));
if (result_type.getShape() != lhs_shape) {
lhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
new_shape)
.output();
}
if (result_type.getShape() != rhs_shape) {
rhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
new_shape)
.output();
}
// Recreate an op with the above Broadcast op results.
rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
return success();
}
};
// This specialization is for TF SelectV2 op. SelectV2 op have three inputs and
// they should have broadcastable shapes.
template <>
class ApplyExplicitBroadcasting<TF::SelectV2Op>
: public OpRewritePattern<TF::SelectV2Op> {
public:
using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
PatternRewriter& rewriter) const override {
Operation* op = static_cast<Operation*>(src_op);
auto cond = op->getOperand(0);
auto lhs = op->getOperand(1);
auto rhs = op->getOperand(2);
// Should have static shapes to calculate the broadcasted shape.
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
!rhs.getType().cast<ShapedType>().hasStaticShape() ||
!cond.getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
auto cond_shape = cond.getType().cast<ShapedType>().getShape();
if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
return failure();
}
// Calculate the broadcasted shape.
SmallVector<int64_t, 4> broadcasted_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
broadcasted_shape)) {
return failure();
}
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
result_shape)) {
return failure();
}
// Create a const op, that stores the above broadcasted shape.
auto shape_type =
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
auto new_shape_attr =
mlir::DenseIntElementsAttr::get(shape_type, result_shape);
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
// Apply BroadcastTo ops to each input.
auto cond_result_type =
RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
auto result_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(lhs.getType()));
if (result_shape != cond_shape) {
cond = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
cond, new_shape)
.output();
}
if (result_shape != lhs_shape) {
lhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
new_shape)
.output();
}
if (result_shape != rhs_shape) {
rhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
new_shape)
.output();
}
// Recreate an op with the above Broadcast op results.
rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
rhs);
return success();
}
};
void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
// Add TF->TF lowering patterns.
TF::PopulateLoweringTFPatterns(context, &patterns);
@ -656,7 +800,25 @@ void LegalizeTF::runOnFunction() {
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,
LegalizeUnidirectionalSequenceRnn>(context);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
}
void applyPatterns(FuncOp func, ConversionTarget& target,
FrozenRewritePatternList& frozenPatterns) {
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.
// Currently unit-test doesn't do multiple tries, so we need this.
const int max_iterations = 15;
for (int i = 0; i < max_iterations; ++i) {
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
return;
}
}
}
void LegalizeTF::runOnFunction() {
auto* context = &getContext();
auto func = getFunction();
ConversionTarget target(*context);
// It is legal to have TF ops in the graph still which can be
@ -690,16 +852,42 @@ void LegalizeTF::runOnFunction() {
return success(current_thread_id == llvm::get_threadid());
});
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.
// Currently unit-test doesn't do multiple tries, so we need this.
const int max_iterations = 15;
for (int i = 0; i < max_iterations; ++i) {
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
return;
}
}
OwningRewritePatternList stage1Patterns;
addPatterns(context, stage1Patterns);
FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
applyPatterns(func, target, stage1FrozenPatterns);
// Explict BroadcastTo addition for left-over broadcast-able ops.
// The following pattern matchings should be done after the other legalization
// rules in order not to add unnecessary BroadcastTo ops.
OwningRewritePatternList stage2Patterns;
addPatterns(context, stage2Patterns);
stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
ApplyExplicitBroadcasting<TF::NotEqualOp>,
ApplyExplicitBroadcasting<TF::GreaterOp>,
ApplyExplicitBroadcasting<TF::LessOp>,
ApplyExplicitBroadcasting<TF::EqualOp>,
ApplyExplicitBroadcasting<TF::AddOp>,
ApplyExplicitBroadcasting<TF::AddV2Op>,
ApplyExplicitBroadcasting<TF::MulOp>,
ApplyExplicitBroadcasting<TF::DivOp>,
ApplyExplicitBroadcasting<TF::RealDivOp>,
ApplyExplicitBroadcasting<TF::SubOp>,
ApplyExplicitBroadcasting<TF::FloorDivOp>,
ApplyExplicitBroadcasting<TF::FloorModOp>,
ApplyExplicitBroadcasting<TF::PowOp>,
ApplyExplicitBroadcasting<TF::MaximumOp>,
ApplyExplicitBroadcasting<TF::MinimumOp>,
ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
applyPatterns(func, target, stage2FrozenPatterns);
}
} // namespace

View File

@ -178,6 +178,19 @@ def make_binary_op_tests(options,
},
]
# High dimension broadcasting support in MLIR converter.
if options.use_experimental_converter:
test_parameters = test_parameters + [
{
"dtype": [tf.float32],
"input_shape_1": [[8, 7, 6, 5, 4, 3, 2, 1]],
"input_shape_2": [[4, 3, 2, 1]],
"activation": [False],
"fully_quantize": [False],
"dynamic_range_quantize": [False],
},
]
# test_parameters include fully_quantize option only when
# allow_fully_quantize is True.
if not allow_fully_quantize:

View File

@ -40,6 +40,16 @@ def make_where_tests(options):
},
]
# High dimension broadcasting support in MLIR converter.
if options.use_experimental_converter:
test_parameters = test_parameters + [
{
"input_dtype": [tf.float32, tf.int32],
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
"use_where_v2": [True],
},
]
def build_graph(parameters):
"""Build the where op testing graph."""
input_value1 = tf.compat.v1.placeholder(