diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index d593c0ec836..63ac51eef32 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -148,18 +148,10 @@ bool IsI64Type(Type element_type) {
 bool VerifyAddOpShapeConstraints(AddOp op) {
   auto element_type = getElementTypeOrSelf(op.output().getType());
 
-  // Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
-  // which are broadcastable shapes up to five dimension or have same shapes.
+  // Allows F32, QI8, QUI8 and I32 outputs when the operands have valid shapes,
+  // which are broadcastable shapes up to four dimensions or have same shapes.
   if (element_type.isF32() || IsQI8Type(element_type) ||
-      IsQUI8Type(element_type)) {
-    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
-        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
-        /*max_bcast_rank=*/5);
-  }
-
-  // Allows I32 output when the operands have valid shapes, which are
-  // broadcastable shapes up to four dimension or have same shapes.
-  if (IsI32Type(element_type)) {
+      IsQUI8Type(element_type) || IsI32Type(element_type)) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/4);
@@ -211,20 +203,13 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
     }
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
-        /*max_bcast_rank=*/5);
+        /*max_bcast_rank=*/4);
   }
 
-  // Allows F32 output when the operands have valid shapes, which are
-  // broadcastable shapes up to five dimension or have same shapes.
-  if (element_type.isF32()) {
-    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
-        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
-        /*max_bcast_rank=*/5);
-  }
-
-  // Allows I32 and QI16 outputs when the operands have valid shapes, which are
-  // broadcastable shapes up to four dimension or have same shapes.
+  // Allows I32, QI16 and F32 outputs when the operands have valid shapes,
+  // which are broadcastable shapes up to four dimensions or have same shapes.
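+  // For example, a mul of tensor<1x2x3x4xi32> and tensor<1x2x1x4xi32> is
+  // accepted here, while broadcasts of rank five or higher are expected to
+  // have been made explicit with tfl.broadcast_to ops during legalization.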
+  if (IsI32Type(element_type) || IsQI16Type(element_type) ||
+      element_type.isF32()) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/4);
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index b712278fb17..5e36f4af802 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -25,13 +25,6 @@ func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
   return %0 : tensor<1x2x3x4x5x6x7x8xi32>
 }
 
-// CHECK-LABEL: testAddTooHighBroadcastableDims
-func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
-  // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
-  return %0 : tensor<1x2x3x4x5x6xi32>
-}
-
 func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
   return %2: tensor<1xf32>
@@ -1568,7 +1561,11 @@ func @select_v2_with_6d_broadcasting(%arg0: tensor<1x1x1x1x3x1xi1>, %arg1 : tensor<1x1x1x1x1x4xf32>, %arg2: tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32> {
   %0 = "tf.SelectV2"(%arg0, %arg1, %arg2): (tensor<1x1x1x1x3x1xi1>, tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x2x1x1xf32>) -> tensor<1x1x1x2x3x4xf32>
   return %0 : tensor<1x1x1x2x3x4xf32>
 // CHECK-LABEL: select_v2_with_6d_broadcasting
-// CHECK: "tf.SelectV2"(%arg0, %arg1, %arg2)
+// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 2, 3, 4]> : tensor<6xi64>
+// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+// CHECK: [[BCT_1:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
+// CHECK: "tfl.select"([[BCT]], [[BCT_0]], [[BCT_1]])
 }
 
 // -----
 
@@ -1578,7 +1575,9 @@ func @maximum_with_6d_broadcasting(%arg0: tensor<1x1x1x1x8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<1x1x1x1x8x16xf32> {
   %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x1x1x1x8x16xf32>, tensor<8x16xf32>) -> tensor<1x1x1x1x8x16xf32>
   return %0 : tensor<1x1x1x1x8x16xf32>
 // CHECK-LABEL: maximum_with_6d_broadcasting
-// CHECK: "tf.Maximum"(%arg0, %arg1)
+// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 1, 8, 16]> : tensor<6xi64>
+// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+// CHECK: "tfl.maximum"(%arg0, [[BCT]])
 }
 
 // -----
 
@@ -1587,7 +1586,171 @@ func @add_with_int32_5d_inputs(%arg0: tensor<1x1x1x3x1xi32>, %arg1 : tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32> {
   %0 = "tf.Add"(%arg0, %arg1): (tensor<1x1x1x3x1xi32>, tensor<1x1x1x1x4xi32>) -> tensor<1x1x1x3x4xi32>
   return %0 : tensor<1x1x1x3x4xi32>
 // CHECK-LABEL: add_with_int32_5d_inputs
-// CHECK: "tf.Add"(%arg0, %arg1)
+// CHECK: [[CST:%.*]] = constant dense<[1, 1, 1, 3, 4]> : tensor<5xi64>
+// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+// CHECK: [[BCT_0:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+// CHECK: tfl.add [[BCT]], [[BCT_0]]
+}
+
+// CHECK-LABEL: testAddWithBroadcastToOps
+func @testAddWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: tfl.add [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
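+
+// The remaining broadcastable binary and comparison ops are exercised below.
+// Each lowers the same way: one shared shape constant plus a tfl.broadcast_to
+// per operand whose shape differs from the result shape.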
+
+// CHECK-LABEL: testSubWithBroadcastToOps
+func @testSubWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: tfl.sub [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testMulWithBroadcastToOps
+func @testMulWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: tfl.mul [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testDivWithBroadcastToOps
+func @testDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: tfl.div [[BCAST]], [[BCAST_1]] {fused_activation_function = "NONE"} : tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Div"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testFloorDivWithBroadcastToOps
+func @testFloorDivWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: tfl.floor_div [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testFloorModWithBroadcastToOps
+func @testFloorModWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: "tfl.floor_mod"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
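+
+// Note: ops without a custom printer, such as tfl.floor_mod, are matched in
+// their generic form ("tfl.floor_mod"(...)), while ops with one, such as
+// tfl.add, appear in shorthand form (tfl.add %lhs, %rhs).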
+
+// CHECK-LABEL: testPowWithBroadcastToOps
+func @testPowWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: tfl.pow [[BCAST]], [[BCAST_1]] : tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Pow"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testMaximumWithBroadcastToOps
+func @testMaximumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: "tfl.maximum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testMinimumWithBroadcastToOps
+func @testMinimumWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: "tfl.minimum"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi32>
+  %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testSelectV2WithBroadcastToOps
+func @testSelectV2WithBroadcastToOps(%arg0: tensor<1x2x1x4x1x6xi1>, %arg1: tensor<1x2x3x4x1x1xi32>, %arg2: tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: [[BCAST_2:%.*]] = "tfl.broadcast_to"(%arg2, [[CST]])
+  // CHECK: "tfl.select"([[BCAST]], [[BCAST_1]], [[BCAST_2]])
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1x2x1x4x1x6xi1>, tensor<1x2x3x4x1x1xi32>, tensor<1x2x1x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
+// CHECK-LABEL: testLessEqualWithBroadcastToOps
+func @testLessEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> {
+  // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>
+  // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]])
+  // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]])
+  // CHECK: "tfl.less_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1>
+  %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1>
+  return %0 : tensor<1x2x3x4x5x6xi1>
+}
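+
+// Comparison ops broadcast their operands in the same way but produce an i1
+// result tensor instead of reusing the operand element type.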
"tf.GreaterEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> + return %0 : tensor<1x2x3x4x5x6xi1> +} + +// CHECK-LABEL: testEqualWithBroadcastToOps +func @testEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> { + // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64> + // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]]) + // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]]) + // CHECK: "tfl.equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1> + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> + return %0 : tensor<1x2x3x4x5x6xi1> +} + +// CHECK-LABEL: testNotEqualWithBroadcastToOps +func @testNotEqualWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> { + // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64> + // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]]) + // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]]) + // CHECK: "tfl.not_equal"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1> + %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> + return %0 : tensor<1x2x3x4x5x6xi1> +} + +// CHECK-LABEL: testLessWithBroadcastToOps +func @testLessWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> { + // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64> + // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]]) + // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]]) + // CHECK: "tfl.less"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1> + %0 = "tf.Less"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> + return %0 : tensor<1x2x3x4x5x6xi1> +} + +// CHECK-LABEL: testGreaterWithBroadcastToOps +func @testGreaterWithBroadcastToOps(%arg0: tensor<1x2x1x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> { + // CHECK: [[CST:%.*]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64> + // CHECK: [[BCAST:%.*]] = "tfl.broadcast_to"(%arg0, [[CST]]) + // CHECK: [[BCAST_1:%.*]] = "tfl.broadcast_to"(%arg1, [[CST]]) + // CHECK: "tfl.greater"([[BCAST]], [[BCAST_1]]) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x6xi32>) -> tensor<1x2x3x4x5x6xi1> + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1x2x1x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi1> + return %0 : tensor<1x2x3x4x5x6xi1> } func @tranpose_int32_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 6a15e22326e..728bb0852f7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -267,7 +267,7 @@ def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; def LegalizeBiasAdd : Pat< (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format), - (TFL_AddOp $l, $r, TFL_AF_None)>; + (TF_AddV2Op $l, $r)>; def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; def LegalizeMul : 
 def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
                       (TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
 def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index c5398d290a9..5f96bd1198a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -636,11 +636,155 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
   }
 };
 
-void LegalizeTF::runOnFunction() {
-  OwningRewritePatternList patterns;
-  auto* context = &getContext();
-  auto func = getFunction();
+// Puts two BroadcastTo ops in front of the given TF binary broadcast op to
+// make the conversion of broadcastable binary ops always successful, so that
+// the flex delegate is not required.
+template <typename SourceOp>
+class ApplyExplicitBroadcasting : public OpRewritePattern<SourceOp> {
+ public:
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SourceOp src_op,
+                                PatternRewriter& rewriter) const override {
+    Operation* op = static_cast<Operation*>(src_op);
+    auto lhs = op->getOperand(0);
+    auto rhs = op->getOperand(1);
+
+    // Should have static shapes to calculate the broadcasted shape.
+    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
+        !rhs.getType().cast<ShapedType>().hasStaticShape()) {
+      return failure();
+    }
+
+    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
+    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
+
+    if (lhs_shape == rhs_shape) {
+      return failure();
+    }
+
+    // Calculate the broadcasted shape.
+    SmallVector<int64_t, 4> result_shape;
+    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
+                                            result_shape)) {
+      return failure();
+    }
+
+    RankedTensorType result_type = RankedTensorType::get(
+        result_shape, getElementTypeOrSelf(op->getResult(0).getType()));
+
+    // Create a const op that stores the above broadcasted shape.
+    auto new_shape_attr = mlir::DenseIntElementsAttr::get(
+        RankedTensorType::get(result_shape.size(), rewriter.getIntegerType(64)),
+        result_shape);
+    auto new_shape =
+        rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
+
+    // Apply BroadcastTo ops to each input.
+    auto broadcast_type = RankedTensorType::get(
+        result_shape, getElementTypeOrSelf(lhs.getType()));
+
+    if (result_type.getShape() != lhs_shape) {
+      lhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, lhs,
+                                           new_shape)
+                .output();
+    }
+    if (result_type.getShape() != rhs_shape) {
+      rhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), broadcast_type, rhs,
+                                           new_shape)
+                .output();
+    }
+
+    // Recreate the op with the above BroadcastTo op results.
+    rewriter.replaceOpWithNewOp<SourceOp>(op, result_type, lhs, rhs);
+    return success();
+  }
+};
+
+// This specialization is for the TF SelectV2 op. SelectV2 has three inputs,
+// all of which should have broadcastable shapes.
+template <>
+class ApplyExplicitBroadcasting<TF::SelectV2Op>
+    : public OpRewritePattern<TF::SelectV2Op> {
+ public:
+  using OpRewritePattern<TF::SelectV2Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TF::SelectV2Op src_op,
+                                PatternRewriter& rewriter) const override {
+    Operation* op = static_cast<Operation*>(src_op);
+    auto cond = op->getOperand(0);
+    auto lhs = op->getOperand(1);
+    auto rhs = op->getOperand(2);
+
+    // Should have static shapes to calculate the broadcasted shape.
+    if (!lhs.getType().cast<ShapedType>().hasStaticShape() ||
+        !rhs.getType().cast<ShapedType>().hasStaticShape() ||
+        !cond.getType().cast<ShapedType>().hasStaticShape()) {
+      return failure();
+    }
+
+    auto lhs_shape = lhs.getType().cast<ShapedType>().getShape();
+    auto rhs_shape = rhs.getType().cast<ShapedType>().getShape();
+    auto cond_shape = cond.getType().cast<ShapedType>().getShape();
+
+    if (lhs_shape == rhs_shape && cond_shape == lhs_shape) {
+      return failure();
+    }
+
+    // Calculate the broadcasted shape of the two data operands first, then
+    // broadcast it against the condition shape.
+    SmallVector<int64_t, 4> broadcasted_shape;
+    if (!OpTrait::util::getBroadcastedShape(lhs_shape, rhs_shape,
+                                            broadcasted_shape)) {
+      return failure();
+    }
+
+    SmallVector<int64_t, 4> result_shape;
+    if (!OpTrait::util::getBroadcastedShape(broadcasted_shape, cond_shape,
+                                            result_shape)) {
+      return failure();
+    }
+
+    // Create a const op that stores the above broadcasted shape.
+    auto shape_type = RankedTensorType::get(result_shape.size(),
+                                            rewriter.getIntegerType(64));
+    auto new_shape_attr =
+        mlir::DenseIntElementsAttr::get(shape_type, result_shape);
+    auto new_shape =
+        rewriter.create<TF::ConstOp>(op->getLoc(), new_shape_attr);
+
+    // Apply BroadcastTo ops to each input.
+    auto cond_result_type =
+        RankedTensorType::get(result_shape, rewriter.getIntegerType(1));
+    auto result_type = RankedTensorType::get(
+        result_shape, getElementTypeOrSelf(lhs.getType()));
+
+    if (result_shape != cond_shape) {
+      cond = rewriter
+                 .create<TF::BroadcastToOp>(op->getLoc(), cond_result_type,
+                                            cond, new_shape)
+                 .output();
+    }
+    if (result_shape != lhs_shape) {
+      lhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), result_type, lhs,
+                                           new_shape)
+                .output();
+    }
+    if (result_shape != rhs_shape) {
+      rhs = rewriter
+                .create<TF::BroadcastToOp>(op->getLoc(), result_type, rhs,
+                                           new_shape)
+                .output();
+    }
+
+    // Recreate the op with the above BroadcastTo op results.
+    rewriter.replaceOpWithNewOp<TF::SelectV2Op>(op, result_type, cond, lhs,
+                                                rhs);
+    return success();
+  }
+};
+
+void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
   // Add TF->TF lowering patterns.
   TF::PopulateLoweringTFPatterns(context, &patterns);
 
@@ -656,7 +800,25 @@ void LegalizeTF::runOnFunction() {
   // Ophint python converter converted tf node pattern.
   patterns.insert<LegalizeUnidirectionalSequenceLstm,
                   LegalizeUnidirectionalSequenceRnn>(context);
-  FrozenRewritePatternList frozenPatterns(std::move(patterns));
+}
+
+void applyPatterns(FuncOp func, ConversionTarget& target,
+                   FrozenRewritePatternList& frozenPatterns) {
+  // Keep trying to convert.
+  // TODO(karimnosseir): This is similar to what applying greedy patterns
+  // does. Look if there is a function that tries until it converges.
+  // Currently the unit test doesn't do multiple tries, so we need this.
+  const int max_iterations = 15;
+  for (int i = 0; i < max_iterations; ++i) {
+    if (failed(applyPartialConversion(func, target, frozenPatterns))) {
+      return;
+    }
+  }
+}
+
+void LegalizeTF::runOnFunction() {
+  auto* context = &getContext();
+  auto func = getFunction();
 
   ConversionTarget target(*context);
   // It is legal to have TF ops in the graph still which can be
@@ -690,16 +852,42 @@ void LegalizeTF::runOnFunction() {
     return success(current_thread_id == llvm::get_threadid());
   });
 
-  // Keep trying to convert.
-  // TODO(karimnosseir): This is similar to what apply greedy patterns does.
-  // Look if there is a function that tries until it converge.
-  // Currently unit-test doesn't do multiple tries, so we need this.
-  const int max_iterations = 15;
-  for (int i = 0; i < max_iterations; ++i) {
-    if (failed(applyPartialConversion(func, target, frozenPatterns))) {
-      return;
-    }
-  }
+  OwningRewritePatternList stage1Patterns;
+
+  addPatterns(context, stage1Patterns);
+
+  FrozenRewritePatternList stage1FrozenPatterns(std::move(stage1Patterns));
+  applyPatterns(func, target, stage1FrozenPatterns);
+
+  // Stage 2: explicit BroadcastTo addition for left-over broadcastable ops.
+  // These patterns must run after the other legalization rules so that no
+  // unnecessary BroadcastTo ops are added.
+  OwningRewritePatternList stage2Patterns;
+
+  addPatterns(context, stage2Patterns);
+
+  stage2Patterns.insert<ApplyExplicitBroadcasting<TF::LessEqualOp>,
+                        ApplyExplicitBroadcasting<TF::GreaterEqualOp>,
+                        ApplyExplicitBroadcasting<TF::NotEqualOp>,
+                        ApplyExplicitBroadcasting<TF::GreaterOp>,
+                        ApplyExplicitBroadcasting<TF::LessOp>,
+                        ApplyExplicitBroadcasting<TF::EqualOp>,
+                        ApplyExplicitBroadcasting<TF::AddOp>,
+                        ApplyExplicitBroadcasting<TF::AddV2Op>,
+                        ApplyExplicitBroadcasting<TF::MulOp>,
+                        ApplyExplicitBroadcasting<TF::DivOp>,
+                        ApplyExplicitBroadcasting<TF::RealDivOp>,
+                        ApplyExplicitBroadcasting<TF::SubOp>,
+                        ApplyExplicitBroadcasting<TF::FloorDivOp>,
+                        ApplyExplicitBroadcasting<TF::FloorModOp>,
+                        ApplyExplicitBroadcasting<TF::PowOp>,
+                        ApplyExplicitBroadcasting<TF::MaximumOp>,
+                        ApplyExplicitBroadcasting<TF::MinimumOp>,
+                        ApplyExplicitBroadcasting<TF::SelectV2Op>>(context);
+
+  FrozenRewritePatternList stage2FrozenPatterns(std::move(stage2Patterns));
+  applyPatterns(func, target, stage2FrozenPatterns);
 }
 
 }  // namespace
diff --git a/tensorflow/lite/testing/op_tests/binary_op.py b/tensorflow/lite/testing/op_tests/binary_op.py
index 17ed2f3522d..936563cc63d 100644
--- a/tensorflow/lite/testing/op_tests/binary_op.py
+++ b/tensorflow/lite/testing/op_tests/binary_op.py
@@ -178,6 +178,19 @@ def make_binary_op_tests(options,
       },
   ]
 
+  # High-dimension broadcasting support in the MLIR converter.
+  if options.use_experimental_converter:
+    test_parameters = test_parameters + [
+        {
+            "dtype": [tf.float32],
+            "input_shape_1": [[8, 7, 6, 5, 4, 3, 2, 1]],
+            "input_shape_2": [[4, 3, 2, 1]],
+            "activation": [False],
+            "fully_quantize": [False],
+            "dynamic_range_quantize": [False],
+        },
+    ]
+
   # test_parameters include fully_quantize option only when
   # allow_fully_quantize is True.
   if not allow_fully_quantize:
diff --git a/tensorflow/lite/testing/op_tests/where.py b/tensorflow/lite/testing/op_tests/where.py
index 90db8d56f25..7c5a899e00b 100644
--- a/tensorflow/lite/testing/op_tests/where.py
+++ b/tensorflow/lite/testing/op_tests/where.py
@@ -40,6 +40,16 @@ def make_where_tests(options):
       },
   ]
 
+  # High-dimension broadcasting support in the MLIR converter.
+  if options.use_experimental_converter:
+    test_parameters = test_parameters + [
+        {
+            "input_dtype": [tf.float32, tf.int32],
+            "input_shape_set": [([8, 7, 6, 5, 4, 3, 2, 1], [4, 3, 2, 1]),],
+            "use_where_v2": [True],
+        },
+    ]
+
   def build_graph(parameters):
     """Build the where op testing graph."""
     input_value1 = tf.compat.v1.placeholder(