Rollback of BroadcastTo op additions (part 1)
Rolling back until discussion about builtin ops schema issue is discussed. PiperOrigin-RevId: 322867083 Change-Id: I85bc33675a00ea5ff7253d9d1eb53b047f9f4658
This commit is contained in:
parent
5a58103a7d
commit
5882f49288
@ -147,10 +147,18 @@ bool IsI64Type(Type element_type) {
|
|||||||
bool VerifyAddOpShapeConstraints(AddOp op) {
|
bool VerifyAddOpShapeConstraints(AddOp op) {
|
||||||
auto element_type = getElementTypeOrSelf(op.output().getType());
|
auto element_type = getElementTypeOrSelf(op.output().getType());
|
||||||
|
|
||||||
// Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
|
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
|
||||||
// which are broadcastable shapes up to five dimension or have same shapes.
|
// which are broadcastable shapes up to five dimension or have same shapes.
|
||||||
if (element_type.isF32() || IsQI8Type(element_type) ||
|
if (element_type.isF32() || IsQI8Type(element_type) ||
|
||||||
IsQUI8Type(element_type) || IsI32Type(element_type)) {
|
IsQUI8Type(element_type)) {
|
||||||
|
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||||
|
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||||
|
/*max_bcast_rank=*/5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allows I32 output when the operands have valid shapes, which are
|
||||||
|
// broadcastable shapes up to four dimension or have same shapes.
|
||||||
|
if (IsI32Type(element_type)) {
|
||||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||||
/*max_bcast_rank=*/4);
|
/*max_bcast_rank=*/4);
|
||||||
@ -202,13 +210,20 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
|
|||||||
}
|
}
|
||||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||||
/*max_bcast_rank=*/4);
|
/*max_bcast_rank=*/5);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
|
// Allows F32 output when the operands have valid shapes, which are
|
||||||
// are broadcastable shapes up to four dimension or have same shapes.
|
// broadcastable shapes up to five dimension or have same shapes.
|
||||||
if (IsI32Type(element_type) || IsQI16Type(element_type) ||
|
if (element_type.isF32()) {
|
||||||
element_type.isF32()) {
|
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||||
|
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||||
|
/*max_bcast_rank=*/5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
|
||||||
|
// broadcastable shapes up to four dimension or have same shapes.
|
||||||
|
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
|
||||||
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
|
||||||
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
|
||||||
/*max_bcast_rank=*/4);
|
/*max_bcast_rank=*/4);
|
||||||
|
@ -25,6 +25,13 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: te
|
|||||||
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
|
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testAddTooHighBroadcastableDims
|
||||||
|
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||||
|
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
|
||||||
|
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||||
|
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||||
|
}
|
||||||
|
|
||||||
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||||
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
return %2: tensor<1xf32>
|
return %2: tensor<1xf32>
|
||||||
@ -1523,11 +1530,7 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tens
|
|||||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
|
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
|
||||||
return %0 : tensor<1x1x1x2x3x4xf32>
|
return %0 : tensor<1x1x1x2x3x4xf32>
|
||||||
// CHECK-LABEL: select_v2_with_6d_broadcasting
|
// CHECK-LABEL: select_v2_with_6d_broadcasting
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64>
|
// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
|
||||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
|
|
||||||
// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
@ -1537,9 +1540,7 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
|
|||||||
return %0 : tensor<1x1x1x1x8x16xf32>
|
return %0 : tensor<1x1x1x1x8x16xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: maximum_with_6d_broadcasting
|
// CHECK-LABEL: maximum_with_6d_broadcasting
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64>
|
// CHECK: "tf.Maximum"(%arg0, %arg1)
|
||||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.maximum"(%arg0, [[BCT]])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
@ -1548,169 +1549,5 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
|
|||||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
|
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
|
||||||
return %0 : tensor<1x1x1x3x4xi32>
|
return %0 : tensor<1x1x1x3x4xi32>
|
||||||
// CHECK-LABEL: add_with_int32_5d_inputs
|
// CHECK-LABEL: add_with_int32_5d_inputs
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64>
|
// CHECK: "tf.Add"(%arg0, %arg1)
|
||||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.add [[BCT]], [[BCT_0]]
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testAddWithBroadcastToOps
|
|
||||||
func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testSubWithBroadcastToOps
|
|
||||||
func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testMulWithBroadcastToOps
|
|
||||||
func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testDivWithBroadcastToOps
|
|
||||||
func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testFloorDivWithBroadcastToOps
|
|
||||||
func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testFloorModWithBroadcastToOps
|
|
||||||
func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testPowWithBroadcastToOps
|
|
||||||
func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testMaximumWithBroadcastToOps
|
|
||||||
func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testMinimumWithBroadcastToOps
|
|
||||||
func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testSelectV2WithBroadcastToOps
|
|
||||||
func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
|
|
||||||
// CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
|
|
||||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testLessEqualWithBroadcastToOps
|
|
||||||
func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testGreaterEqualWithBroadcastToOps
|
|
||||||
func @testGreaterEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.greater_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testEqualWithBroadcastToOps
|
|
||||||
func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testNotEqualWithBroadcastToOps
|
|
||||||
func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testLessWithBroadcastToOps
|
|
||||||
func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
%0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testGreaterWithBroadcastToOps
|
|
||||||
func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
|
|
||||||
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
|
|
||||||
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
|
|
||||||
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
|
|
||||||
// CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
|
|
||||||
return %0 : tensor<1x2x3x4x5x6xi1>
|
|
||||||
}
|
}
|
||||||
|
@ -256,7 +256,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
|
|||||||
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
|
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
|
||||||
def LegalizeBiasAdd : Pat<
|
def LegalizeBiasAdd : Pat<
|
||||||
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
|
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
|
||||||
(TF_AddV2Op $l, $r)>;
|
(TFL_AddOp $l, $r, TFL_AF_None)>;
|
||||||
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
|
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
|
||||||
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
|
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
|
||||||
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),
|
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),
|
||||||
|
@ -631,156 +631,6 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
|
|
||||||
// to make binary broadcast-able op conversion always successful and does not
|
|
||||||
// require flex delegate.
|
|
||||||
template <typename SourceOp>
|
|
||||||
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(SourceOp src_op,
|
|
||||||
PatternRewriter& rewriter) const override {
|
|
||||||
Operation* op = static_cast<Operation*>(src_op);
|
|
||||||
auto lhs = op->getOperand(0);
|
|
||||||
auto rhs = op->getOperand(1);
|
|
||||||
|
|
||||||
// Should have static shapes to calculate the broadcasted shape.
|
|
||||||
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
|
|
||||||
!rhs.getType().cast<ShapedType>().hasStaticShape()) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the broadcasted shape.
|
|
||||||
SmallVector<int64_t, 4> result_shape;
|
|
||||||
if (!OpTrait::util::getBroadcastedShape(
|
|
||||||
lhs.getType().cast<ShapedType>().getShape(),
|
|
||||||
rhs.getType().cast<ShapedType>().getShape(), result_shape)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
RankedTensorType result_type = RankedTensorType::get(
|
|
||||||
result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
|
|
||||||
|
|
||||||
// Create a const op, that stores the above broadcasted shape.
|
|
||||||
auto new_shape_attr = mlir::DenseIntElementsAttr::get(
|
|
||||||
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
|
|
||||||
result_shape);
|
|
||||||
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
|
|
||||||
|
|
||||||
// Apply BroadcastTo ops to each input.
|
|
||||||
auto broadcast_type = RankedTensorType::get(
|
|
||||||
result_shape, getElementTypeOrSelf(lhs.getType()));
|
|
||||||
|
|
||||||
if (result_type.getShape() != lhs.getType().cast<ShapedType>().getShape()) {
|
|
||||||
lhs = rewriter
|
|
||||||
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
|
|
||||||
new_shape)
|
|
||||||
.output();
|
|
||||||
}
|
|
||||||
if (result_type.getShape() != rhs.getType().cast<ShapedType>().getShape()) {
|
|
||||||
rhs = rewriter
|
|
||||||
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
|
|
||||||
new_shape)
|
|
||||||
.output();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recreate an op with the above Broadcast op results.
|
|
||||||
rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// This specialization is for TF SelectV2 op. SelectV2 op have three inputs and
|
|
||||||
// they should have broadcastable shapes.
|
|
||||||
template <>
|
|
||||||
class ApplyExplicitBroadcasting<TF::SelectV2Op>
|
|
||||||
: public OpRewritePattern<TF::SelectV2Op> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
|
|
||||||
PatternRewriter& rewriter) const override {
|
|
||||||
Operation* op = static_cast<Operation*>(src_op);
|
|
||||||
auto cond = op->getOperand(0);
|
|
||||||
auto lhs = op->getOperand(1);
|
|
||||||
auto rhs = op->getOperand(2);
|
|
||||||
|
|
||||||
// Should have static shapes to calculate the broadcasted shape.
|
|
||||||
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
|
|
||||||
!rhs.getType().cast<ShapedType>().hasStaticShape() ||
|
|
||||||
!cond.getType().cast<ShapedType>().hasStaticShape()) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the broadcasted shape.
|
|
||||||
SmallVector<int64_t, 4> broadcasted_shape;
|
|
||||||
if (!OpTrait::util::getBroadcastedShape(
|
|
||||||
lhs.getType().cast<ShapedType>().getShape(),
|
|
||||||
rhs.getType().cast<ShapedType>().getShape(), broadcasted_shape)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> result_shape;
|
|
||||||
if (!OpTrait::util::getBroadcastedShape(
|
|
||||||
broadcasted_shape, cond.getType().cast<ShapedType>().getShape(),
|
|
||||||
result_shape)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a const op, that stores the above broadcasted shape.
|
|
||||||
auto shape_type =
|
|
||||||
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
|
|
||||||
auto new_shape_attr =
|
|
||||||
mlir::DenseIntElementsAttr::get(shape_type, result_shape);
|
|
||||||
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
|
|
||||||
|
|
||||||
// Apply BroadcastTo ops to each input.
|
|
||||||
auto cond_result_type =
|
|
||||||
RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
|
|
||||||
auto result_type = RankedTensorType::get(
|
|
||||||
result_shape, getElementTypeOrSelf(lhs.getType()));
|
|
||||||
|
|
||||||
if (result_shape != cond.getType().cast<ShapedType>().getShape()) {
|
|
||||||
cond = rewriter
|
|
||||||
.create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
|
|
||||||
cond, new_shape)
|
|
||||||
.output();
|
|
||||||
}
|
|
||||||
if (result_shape != lhs.getType().cast<ShapedType>().getShape()) {
|
|
||||||
lhs = rewriter
|
|
||||||
.create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
|
|
||||||
new_shape)
|
|
||||||
.output();
|
|
||||||
}
|
|
||||||
if (result_shape != rhs.getType().cast<ShapedType>().getShape()) {
|
|
||||||
rhs = rewriter
|
|
||||||
.create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
|
|
||||||
new_shape)
|
|
||||||
.output();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recreate an op with the above Broadcast op results.
|
|
||||||
rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
|
|
||||||
rhs);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void applyPatterns(FuncOp func, ConversionTarget& target,
|
|
||||||
const OwningRewritePatternList& patterns) {
|
|
||||||
// Keep trying to convert.
|
|
||||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
|
||||||
// Look if there is a function that tries until it converge.
|
|
||||||
// Currently unit-test doesn't do multiple tries, so we need this.
|
|
||||||
const int max_iterations = 15;
|
|
||||||
for (int i = 0; i < max_iterations; ++i) {
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void LegalizeTF::runOnFunction() {
|
void LegalizeTF::runOnFunction() {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto* context = &getContext();
|
auto* context = &getContext();
|
||||||
@ -831,32 +681,16 @@ void LegalizeTF::runOnFunction() {
|
|||||||
return success(current_thread_id == llvm::get_threadid());
|
return success(current_thread_id == llvm::get_threadid());
|
||||||
});
|
});
|
||||||
|
|
||||||
applyPatterns(func, target, patterns);
|
// Keep trying to convert.
|
||||||
|
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||||
// Explict BroadcastTo addition for left-over broadcast-able ops.
|
// Look if there is a function that tries until it converge.
|
||||||
// The following pattern matchings should be done after the other legalization
|
// Currently unit-test doesn't do multiple tries, so we need this.
|
||||||
// rules in order not to add unnecessary BroadcastTo ops.
|
const int max_iterations = 15;
|
||||||
patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
|
for (int i = 0; i < max_iterations; ++i) {
|
||||||
ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
|
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||||
ApplyExplicitBroadcasting<TF::NotEqualOp>,
|
return;
|
||||||
ApplyExplicitBroadcasting<TF::GreaterOp>,
|
}
|
||||||
ApplyExplicitBroadcasting<TF::LessOp>,
|
}
|
||||||
ApplyExplicitBroadcasting<TF::EqualOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::AddOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::AddV2Op>,
|
|
||||||
ApplyExplicitBroadcasting<TF::MulOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::DivOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::RealDivOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::SubOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::FloorDivOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::FloorModOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::PowOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::MaximumOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::MinimumOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
|
|
||||||
ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
|
|
||||||
|
|
||||||
applyPatterns(func, target, patterns);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -178,19 +178,6 @@ def make_binary_op_tests(options,
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
# High dimension broadcasting support in MLIR converter.
|
|
||||||
if options.use_experimental_converter:
|
|
||||||
test_parameters = test_parameters + [
|
|
||||||
{
|
|
||||||
"dtype": [tf.float32],
|
|
||||||
"input_shape_1": [[8, 7, 6, 5, 4, 3, 2, 1]],
|
|
||||||
"input_shape_2": [[4, 3, 2, 1]],
|
|
||||||
"activation": [False],
|
|
||||||
"fully_quantize": [False],
|
|
||||||
"dynamic_range_quantize": [False],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# test_parameters include fully_quantize option only when
|
# test_parameters include fully_quantize option only when
|
||||||
# allow_fully_quantize is True.
|
# allow_fully_quantize is True.
|
||||||
if not allow_fully_quantize:
|
if not allow_fully_quantize:
|
||||||
|
@ -35,16 +35,6 @@ def make_where_tests(options):
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
# High dimension broadcasting support in MLIR converter.
|
|
||||||
if options.use_experimental_converter:
|
|
||||||
test_parameters = test_parameters + [
|
|
||||||
{
|
|
||||||
"input_dtype": [tf.float32, tf.int32],
|
|
||||||
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
|
|
||||||
"use_where_v2": [True],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
def build_graph(parameters):
|
def build_graph(parameters):
|
||||||
"""Build the where op testing graph."""
|
"""Build the where op testing graph."""
|
||||||
input_value1 = tf.compat.v1.placeholder(
|
input_value1 = tf.compat.v1.placeholder(
|
||||||
|
Loading…
Reference in New Issue
Block a user