Rollback of BroadcastTo op additions (part 1)

Rolling back until discussion about builtin ops schema issue is discussed.

PiperOrigin-RevId: 322867083
Change-Id: I85bc33675a00ea5ff7253d9d1eb53b047f9f4658
This commit is contained in:
Jaesung Chung 2020-07-23 14:30:11 -07:00 committed by TensorFlower Gardener
parent 5a58103a7d
commit 5882f49288
6 changed files with 43 additions and 380 deletions

View File

@ -147,10 +147,18 @@ bool IsI64Type(Type element_type) {
bool VerifyAddOpShapeConstraints(AddOp op) { bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType()); auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes, // Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes. // which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsQI8Type(element_type) || if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type) || IsI32Type(element_type)) { IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape( return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1}, /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4); /*max_bcast_rank=*/4);
@ -202,13 +210,20 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
} }
return VerifyOperandsHaveSameShapesOrBroadcastableShape( return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1}, /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4); /*max_bcast_rank=*/5);
} }
// Allows I32, QI16 and F32 outputs when the operands have valid shapes, which // Allows F32 output when the operands have valid shapes, which are
// are broadcastable shapes up to four dimension or have same shapes. // broadcastable shapes up to five dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type) || if (element_type.isF32()) {
element_type.isF32()) { return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape( return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1}, /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4); /*max_bcast_rank=*/4);

View File

@ -25,6 +25,13 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: te
return %0 : tensor<1x2x3x4x5x6x7x8xi32> return %0 : tensor<1x2x3x4x5x6x7x8xi32>
} }
// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> { func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32> %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32> return %2: tensor<1xf32>
@ -1523,11 +1530,7 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tens
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> %0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
return %0 : tensor<1x1x1x2x3x4xf32> return %0 : tensor<1x1x1x2x3x4xf32>
// CHECK-LABEL: select_v2_with_6d_broadcasting // CHECK-LABEL: select_v2_with_6d_broadcasting
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64> // CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
} }
// ----- // -----
@ -1537,9 +1540,7 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tenso
return %0 : tensor<1x1x1x1x8x16xf32> return %0 : tensor<1x1x1x1x8x16xf32>
// CHECK-LABEL: maximum_with_6d_broadcasting // CHECK-LABEL: maximum_with_6d_broadcasting
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64> // CHECK: "tf.Maximum"(%arg0, %arg1)
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.maximum"(%arg0, [[BCT]])
} }
// ----- // -----
@ -1548,169 +1549,5 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x
%0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> %0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
return %0 : tensor<1x1x1x3x4xi32> return %0 : tensor<1x1x1x3x4xi32>
// CHECK-LABEL: add_with_int32_5d_inputs // CHECK-LABEL: add_with_int32_5d_inputs
// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64> // CHECK: "tf.Add"(%arg0, %arg1)
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.add [[BCT]], [[BCT_0]]
}
// CHECK-LABEL: testAddWithBroadcastToOps
func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testSubWithBroadcastToOps
func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMulWithBroadcastToOps
func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testDivWithBroadcastToOps
func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testFloorDivWithBroadcastToOps
func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testFloorModWithBroadcastToOps
func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testPowWithBroadcastToOps
func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
%0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMaximumWithBroadcastToOps
func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testMinimumWithBroadcastToOps
func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testSelectV2WithBroadcastToOps
func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
// CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
// CHECK-LABEL: testLessEqualWithBroadcastToOps
func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testGreaterEqualWithBroadcastToOps
func @testGreaterEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.greater_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testEqualWithBroadcastToOps
func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testNotEqualWithBroadcastToOps
func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testLessWithBroadcastToOps
func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
}
// CHECK-LABEL: testGreaterWithBroadcastToOps
func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
// CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
// CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
// CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
// CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
return %0 : tensor<1x2x3x4x5x6xi1>
} }

View File

@ -256,7 +256,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>; (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeBiasAdd : Pat< def LegalizeBiasAdd : Pat<
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format), (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
(TF_AddV2Op $l, $r)>; (TFL_AddOp $l, $r, TFL_AF_None)>;
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs), def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>; (TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs), def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),

View File

@ -631,156 +631,6 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
} }
}; };
// Put two TFL BroadcastTo ops in front of the given TF binary broadcast op to
// to make binary broadcast-able op conversion always successful and does not
// require flex delegate.
template <typename SourceOp>
class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<SourceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SourceOp src_op,
PatternRewriter& rewriter) const override {
Operation* op = static_cast<Operation*>(src_op);
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
// Should have static shapes to calculate the broadcasted shape.
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
!rhs.getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
// Calculate the broadcasted shape.
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(
lhs.getType().cast<ShapedType>().getShape(),
rhs.getType().cast<ShapedType>().getShape(), result_shape)) {
return failure();
}
RankedTensorType result_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
// Create a const op, that stores the above broadcasted shape.
auto new_shape_attr = mlir::DenseIntElementsAttr::get(
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
result_shape);
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
// Apply BroadcastTo ops to each input.
auto broadcast_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(lhs.getType()));
if (result_type.getShape() != lhs.getType().cast<ShapedType>().getShape()) {
lhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
new_shape)
.output();
}
if (result_type.getShape() != rhs.getType().cast<ShapedType>().getShape()) {
rhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
new_shape)
.output();
}
// Recreate an op with the above Broadcast op results.
rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
return success();
}
};
// This specialization is for TF SelectV2 op. SelectV2 op have three inputs and
// they should have broadcastable shapes.
template <>
class ApplyExplicitBroadcasting<TF::SelectV2Op>
: public OpRewritePattern<TF::SelectV2Op> {
public:
using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
PatternRewriter& rewriter) const override {
Operation* op = static_cast<Operation*>(src_op);
auto cond = op->getOperand(0);
auto lhs = op->getOperand(1);
auto rhs = op->getOperand(2);
// Should have static shapes to calculate the broadcasted shape.
if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
!rhs.getType().cast<ShapedType>().hasStaticShape() ||
!cond.getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
// Calculate the broadcasted shape.
SmallVector<int64_t, 4> broadcasted_shape;
if (!OpTrait::util::getBroadcastedShape(
lhs.getType().cast<ShapedType>().getShape(),
rhs.getType().cast<ShapedType>().getShape(), broadcasted_shape)) {
return failure();
}
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(
broadcasted_shape, cond.getType().cast<ShapedType>().getShape(),
result_shape)) {
return failure();
}
// Create a const op, that stores the above broadcasted shape.
auto shape_type =
RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64));
auto new_shape_attr =
mlir::DenseIntElementsAttr::get(shape_type, result_shape);
auto new_shape = rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
// Apply BroadcastTo ops to each input.
auto cond_result_type =
RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
auto result_type = RankedTensorType::get(
result_shape, getElementTypeOrSelf(lhs.getType()));
if (result_shape != cond.getType().cast<ShapedType>().getShape()) {
cond = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
cond, new_shape)
.output();
}
if (result_shape != lhs.getType().cast<ShapedType>().getShape()) {
lhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
new_shape)
.output();
}
if (result_shape != rhs.getType().cast<ShapedType>().getShape()) {
rhs = rewriter
.create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
new_shape)
.output();
}
// Recreate an op with the above Broadcast op results.
rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
rhs);
return success();
}
};
void applyPatterns(FuncOp func, ConversionTarget& target,
const OwningRewritePatternList& patterns) {
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge.
// Currently unit-test doesn't do multiple tries, so we need this.
const int max_iterations = 15;
for (int i = 0; i < max_iterations; ++i) {
if (failed(applyPartialConversion(func, target, patterns))) {
return;
}
}
}
void LegalizeTF::runOnFunction() { void LegalizeTF::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto* context = &getContext(); auto* context = &getContext();
@ -831,32 +681,16 @@ void LegalizeTF::runOnFunction() {
return success(current_thread_id == llvm::get_threadid()); return success(current_thread_id == llvm::get_threadid());
}); });
applyPatterns(func, target, patterns); // Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Explict BroadcastTo addition for left-over broadcast-able ops. // Look if there is a function that tries until it converge.
// The following pattern matchings should be done after the other legalization // Currently unit-test doesn't do multiple tries, so we need this.
// rules in order not to add unnecessary BroadcastTo ops. const int max_iterations = 15;
patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>, for (int i = 0; i < max_iterations; ++i) {
ApplyExplicitBroadcasting<TF::GreaterEqualOp>, if (failed(applyPartialConversion(func, target, patterns))) {
ApplyExplicitBroadcasting<TF::NotEqualOp>, return;
ApplyExplicitBroadcasting<TF::GreaterOp>, }
ApplyExplicitBroadcasting<TF::LessOp>, }
ApplyExplicitBroadcasting<TF::EqualOp>,
ApplyExplicitBroadcasting<TF::AddOp>,
ApplyExplicitBroadcasting<TF::AddV2Op>,
ApplyExplicitBroadcasting<TF::MulOp>,
ApplyExplicitBroadcasting<TF::DivOp>,
ApplyExplicitBroadcasting<TF::RealDivOp>,
ApplyExplicitBroadcasting<TF::SubOp>,
ApplyExplicitBroadcasting<TF::FloorDivOp>,
ApplyExplicitBroadcasting<TF::FloorModOp>,
ApplyExplicitBroadcasting<TF::PowOp>,
ApplyExplicitBroadcasting<TF::MaximumOp>,
ApplyExplicitBroadcasting<TF::MinimumOp>,
ApplyExplicitBroadcasting<TF::SquaredDifferenceOp>,
ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
applyPatterns(func, target, patterns);
} }
} // namespace } // namespace

View File

@ -178,19 +178,6 @@ def make_binary_op_tests(options,
}, },
] ]
# High dimension broadcasting support in MLIR converter.
if options.use_experimental_converter:
test_parameters = test_parameters + [
{
"dtype": [tf.float32],
"input_shape_1": [[8, 7, 6, 5, 4, 3, 2, 1]],
"input_shape_2": [[4, 3, 2, 1]],
"activation": [False],
"fully_quantize": [False],
"dynamic_range_quantize": [False],
},
]
# test_parameters include fully_quantize option only when # test_parameters include fully_quantize option only when
# allow_fully_quantize is True. # allow_fully_quantize is True.
if not allow_fully_quantize: if not allow_fully_quantize:

View File

@ -35,16 +35,6 @@ def make_where_tests(options):
}, },
] ]
# High dimension broadcasting support in MLIR converter.
if options.use_experimental_converter:
test_parameters = test_parameters + [
{
"input_dtype": [tf.float32, tf.int32],
"input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
"use_where_v2": [True],
},
]
def build_graph(parameters): def build_graph(parameters):
"""Build the where op testing graph.""" """Build the where op testing graph."""
input_value1 = tf.compat.v1.placeholder( input_value1 = tf.compat.v1.placeholder(