Add new pattern matching rules for TFL Broadcastable ops, which have high dimenion inputs.
- Low dimensions <= 4, for example, should avoid adding BroadcastTo op as usual, which means we need to lower broadcast-able ops as before without BroadcastTo op and lower the BroadcastTo op to the hardware accelerator supported ops as well. - This explicit BroadcastTo op needs to be inserted only when a higher dimension is needed, which will unlock the new opportunity. - There are the broadcast-able 20 TFLite ops (about 15 % of TFLite op set) as the followings: Comparison: LessEqual, GreaterEqual, NotEqual, Greater, Less, Equal (up to four dim.) Activation: PRelu (up to four dim.) Arithmetic: Add, Mul, Div, Sub, FloorDiv, FloorMod, Pow, Maximum, Minimum, SquaredDifference (up to four or five dim.) Dimension: SelectV2 (up to four dim.), BroadcastTo (supported via lowering) PiperOrigin-RevId: 340575786 Change-Id: I007b19487512560e1042e99321b0a37e3123c0f4
This commit is contained in:
parent
83ab6303b0
commit
247746808e
tensorflow
compiler/mlir/lite
lite/testing/op_tests
@ -148,18 +148,10 @@ bool IsI64Type(Type element_type) {
|
||||
bool VerifyAddOpShapeConstraints(AddOp op) {
|
||||
auto element_type = getElementTypeOrSelf(op.output().getType());
|
||||
|
||||
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
|
||||
// Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
|
||||
// which are broadcastable shapes up to five dimension or have same shapes.
|
||||
if (element_type.isF32() || IsQI8Type(element_type) ||
|
||||
IsQUI8Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
}
|
||||
|
||||
// Allows I32 output when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type)) {
|
||||
IsQUI8Type(element_type) || IsI32Type(element_type)) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/4);
|
||||
@ -211,20 +203,13 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
|
||||
}
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
/*max_bcast_rank=*/4);
|
||||
}
|
||||
|
||||
// Allows F32 output when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to five dimension or have same shapes.
|
||||
if (element_type.isF32()) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/5);
|
||||
}
|
||||
|
||||
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
|
||||
// broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
|
||||
// Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
|
||||
// are broadcastable shapes up to four dimension or have same shapes.
|
||||
if (IsI32Type(element_type) || IsQI16Type(element_type) ||
|
||||
element_type.isF32()) {
|
||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||
/*max_bcast_rank=*/4);
|
||||
|
@ -25,13 +25,6 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: te
|
||||
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddTooHighBroadcastableDims
|
||||
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %2: tensor<1xf32>
|
||||
@ -1568,7 +1561,11 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tens
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
|
||||
return %0 : tensor<1x1x1x2x3x4xf32>
|
||||
// CHECK-LABEL: select_v2_with_6d_broadcasting
|
||||
// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64>
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
|
||||
// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -1578,7 +1575,9 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
|
||||
return %0 : tensor<1x1x1x1x8x16xf32>
|
||||
|
||||
// CHECK-LABEL: maximum_with_6d_broadcasting
|
||||
// CHECK: "tf.Maximum"(%arg0, %arg1)
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64>
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.maximum"(%arg0, [[BCT]])
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -1587,7 +1586,171 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
|
||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
|
||||
return %0 : tensor<1x1x1x3x4xi32>
|
||||
// CHECK-LABEL: add_with_int32_5d_inputs
|
||||
// CHECK: "tf.Add"(%arg0, %arg1)
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64>
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.add [[BCT]], [[BCT_0]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddWithBroadcastToOps
|
||||
func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSubWithBroadcastToOps
|
||||
func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMulWithBroadcastToOps
|
||||
func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testDivWithBroadcastToOps
|
||||
func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFloorDivWithBroadcastToOps
|
||||
func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFloorModWithBroadcastToOps
|
||||
func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testPowWithBroadcastToOps
|
||||
func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMaximumWithBroadcastToOps
|
||||
func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testMinimumWithBroadcastToOps
|
||||
func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectV2WithBroadcastToOps
|
||||
func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
|
||||
// CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLessEqualWithBroadcastToOps
|
||||
func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testGreaterEqualWithBroadcastToOps
|
||||
func @testGreaterEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.greater_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testEqualWithBroadcastToOps
|
||||
func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testNotEqualWithBroadcastToOps
|
||||
func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLessWithBroadcastToOps
|
||||
func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testGreaterWithBroadcastToOps
|
||||
func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
||||
// CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
||||
}
|
||||
|
||||
func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
|
@ -267,7 +267,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
|
||||
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def LegalizeBiasAdd : Pat<
|
||||
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
|
||||
(TFL_AddOp $l, $r, TFL_AF_None)>;
|
||||
(TF_AddV2Op $l, $r)>;
|
||||
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
|
||||
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
|
||||
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),
|
||||
|
@ -636,11 +636,155 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto* context = &getContext();
|
||||
auto func = getFunction();
|
||||
// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
|
||||
// to make binary broadcast-able op conversion always successful and does not
|
||||
// require flex delegate.
|
||||
template <typename SourceOp>
|
||||
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SourceOp src_op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* op = static_cast<Operation*>(src_op);
|
||||
auto lhs = op->getOperand(0);
|
||||
auto rhs = op->getOperand(1);
|
||||
|
||||
// Should have static shapes to calculate the broadcasted shape.
|
||||
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
|
||||
!rhs.getType().cast<ShapedType>().hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
|
||||
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (lhs_shape == rhs_shape) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Calculate the broadcasted shape.
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
|
||||
result_shape)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
RankedTensorType result_type = RankedTensorType::get(
|
||||
result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
|
||||
|
||||
// Create a const op, that stores the above broadcasted shape.
|
||||
auto new_shape_attr = mlir::DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
|
||||
result_shape);
|
||||
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
|
||||
|
||||
// Apply BroadcastTo ops to each input.
|
||||
auto broadcast_type = RankedTensorType::get(
|
||||
result_shape, getElementTypeOrSelf(lhs.getType()));
|
||||
|
||||
if (result_type.getShape() != lhs_shape) {
|
||||
lhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
if (result_type.getShape() != rhs_shape) {
|
||||
rhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
|
||||
// Recreate an op with the above Broadcast op results.
|
||||
rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This specialization is for TF SelectV2 op. SelectV2 op have three inputs and
|
||||
// they should have broadcastable shapes.
|
||||
template <>
|
||||
class ApplyExplicitBroadcasting<TF::SelectV2Op>
|
||||
: public OpRewritePattern<TF::SelectV2Op> {
|
||||
public:
|
||||
using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* op = static_cast<Operation*>(src_op);
|
||||
auto cond = op->getOperand(0);
|
||||
auto lhs = op->getOperand(1);
|
||||
auto rhs = op->getOperand(2);
|
||||
|
||||
// Should have static shapes to calculate the broadcasted shape.
|
||||
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
|
||||
!rhs.getType().cast<ShapedType>().hasStaticShape() ||
|
||||
!cond.getType().cast<ShapedType>().hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
|
||||
auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
|
||||
auto cond_shape = cond.getType().cast<ShapedType>().getShape();
|
||||
|
||||
if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Calculate the broadcasted shape.
|
||||
SmallVector<int64_t, 4> broadcasted_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
|
||||
broadcasted_shape)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
|
||||
result_shape)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Create a const op, that stores the above broadcasted shape.
|
||||
auto shape_type =
|
||||
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
|
||||
auto new_shape_attr =
|
||||
mlir::DenseIntElementsAttr::get(shape_type, result_shape);
|
||||
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
|
||||
|
||||
// Apply BroadcastTo ops to each input.
|
||||
auto cond_result_type =
|
||||
RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
|
||||
auto result_type = RankedTensorType::get(
|
||||
result_shape, getElementTypeOrSelf(lhs.getType()));
|
||||
|
||||
if (result_shape != cond_shape) {
|
||||
cond = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
|
||||
cond, new_shape)
|
||||
.output();
|
||||
}
|
||||
if (result_shape != lhs_shape) {
|
||||
lhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
if (result_shape != rhs_shape) {
|
||||
rhs = rewriter
|
||||
.create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
|
||||
new_shape)
|
||||
.output();
|
||||
}
|
||||
|
||||
// Recreate an op with the above Broadcast op results.
|
||||
rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
|
||||
rhs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
|
||||
// Add TF->TF lowering patterns.
|
||||
TF::PopulateLoweringTFPatterns(context, &patterns);
|
||||
|
||||
@ -656,7 +800,25 @@ void LegalizeTF::runOnFunction() {
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
LegalizeUnidirectionalSequenceRnn>(context);
|
||||
FrozenRewritePatternList frozenPatterns(std::move(patterns));
|
||||
}
|
||||
|
||||
void applyPatterns(FuncOp func, ConversionTarget& target,
|
||||
FrozenRewritePatternList& frozenPatterns) {
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converge.
|
||||
// Currently unit-test doesn't do multiple tries, so we need this.
|
||||
const int max_iterations = 15;
|
||||
for (int i = 0; i < max_iterations; ++i) {
|
||||
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
auto* context = &getContext();
|
||||
auto func = getFunction();
|
||||
|
||||
ConversionTarget target(*context);
|
||||
// It is legal to have TF ops in the graph still which can be
|
||||
@ -690,16 +852,42 @@ void LegalizeTF::runOnFunction() {
|
||||
return success(current_thread_id == llvm::get_threadid());
|
||||
});
|
||||
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converge.
|
||||
// Currently unit-test doesn't do multiple tries, so we need this.
|
||||
const int max_iterations = 15;
|
||||
for (int i = 0; i < max_iterations; ++i) {
|
||||
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
OwningRewritePatternList stage1Patterns;
|
||||
|
||||
addPatterns(context, stage1Patterns);
|
||||
|
||||
FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
|
||||
applyPatterns(func, target, stage1FrozenPatterns);
|
||||
|
||||
// Explict BroadcastTo addition for left-over broadcast-able ops.
|
||||
// The following pattern matchings should be done after the other legalization
|
||||
// rules in order not to add unnecessary BroadcastTo ops.
|
||||
OwningRewritePatternList stage2Patterns;
|
||||
|
||||
addPatterns(context, stage2Patterns);
|
||||
|
||||
stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::NotEqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::GreaterOp>,
|
||||
ApplyExplicitBroadcasting<TF::LessOp>,
|
||||
ApplyExplicitBroadcasting<TF::EqualOp>,
|
||||
ApplyExplicitBroadcasting<TF::AddOp>,
|
||||
ApplyExplicitBroadcasting<TF::AddV2Op>,
|
||||
ApplyExplicitBroadcasting<TF::MulOp>,
|
||||
ApplyExplicitBroadcasting<TF::DivOp>,
|
||||
ApplyExplicitBroadcasting<TF::RealDivOp>,
|
||||
ApplyExplicitBroadcasting<TF::SubOp>,
|
||||
ApplyExplicitBroadcasting<TF::FloorDivOp>,
|
||||
ApplyExplicitBroadcasting<TF::FloorModOp>,
|
||||
ApplyExplicitBroadcasting<TF::PowOp>,
|
||||
ApplyExplicitBroadcasting<TF::MaximumOp>,
|
||||
ApplyExplicitBroadcasting<TF::MinimumOp>,
|
||||
ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
|
||||
ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
|
||||
|
||||
FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
|
||||
applyPatterns(func, target, stage2FrozenPatterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -178,6 +178,19 @@ def make_binary_op_tests(options,
|
||||
},
|
||||
]
|
||||
|
||||
# High dimension broadcasting support in MLIR converter.
|
||||
if options.use_experimental_converter:
|
||||
test_parameters = test_parameters + [
|
||||
{
|
||||
"dtype": [tf.float32],
|
||||
"input_shape_1": [[8, 7, 6, 5, 4, 3, 2, 1]],
|
||||
"input_shape_2": [[4, 3, 2, 1]],
|
||||
"activation": [False],
|
||||
"fully_quantize": [False],
|
||||
"dynamic_range_quantize": [False],
|
||||
},
|
||||
]
|
||||
|
||||
# test_parameters include fully_quantize option only when
|
||||
# allow_fully_quantize is True.
|
||||
if not allow_fully_quantize:
|
||||
|
@ -40,6 +40,16 @@ def make_where_tests(options):
|
||||
},
|
||||
]
|
||||
|
||||
# High dimension broadcasting support in MLIR converter.
|
||||
if options.use_experimental_converter:
|
||||
test_parameters = test_parameters + [
|
||||
{
|
||||
"input_dtype": [tf.float32, tf.int32],
|
||||
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
|
||||
"use_where_v2": [True],
|
||||
},
|
||||
]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the where op testing graph."""
|
||||
input_value1 = tf.compat.v1.placeholder(
|
||||
|
Loading…
Reference in New Issue
Block a user