diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 89541d7a76a..de0774ded1d 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -299,14 +299,27 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
       "getElementTypeOrSelf($_op.getResult(" # i # "))) == "
       "quant::QuantizedType::castToStorageType("
       "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;
+
 //===----------------------------------------------------------------------===//
 // TFL op common constraints.
 //===----------------------------------------------------------------------===//
 
 // This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
-// Binary ops lhs & rhs should have the same value type.
+// Binary ops lhs & rhs should have the same value type, and the constraint
+// can compare quantized element types as well.
 def BinaryOpSameElementTypeConstraint :
-  PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<0, 1>>;
+  PredOpTrait<"operands have same element type",
+    Or<[
+      TCopVTEtIsSameAs<0, 1>,
+      // Two operands' values are both quantized and their types have the same
+      // underlying storage type.
+      And<[
+        SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(0))",
+          quant_QuantizedType.predicate>,
+        CPred<"quant::QuantizedType::castToStorageType("
+              "getElementTypeOrSelf($_op.getOperand(0))) == "
+              "quant::QuantizedType::castToStorageType("
+              "getElementTypeOrSelf($_op.getOperand(1)))">]>]>>;
 
 //===----------------------------------------------------------------------===//
 // TFL common builders.
@@ -1937,6 +1950,8 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
 
 def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape,
                                NoSideEffect,
                                Commutative,
+                               BinaryOpSameElementTypeConstraint,
+                               TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
                                TFL_GpuTargetOp]> {
   let summary = "Multiplication operator";
 
diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD
index 0d612cec961..58d5afb5864 100644
--- a/tensorflow/compiler/mlir/lite/tests/BUILD
+++ b/tensorflow/compiler/mlir/lite/tests/BUILD
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
 glob_lit_tests(
     data = [":test_utilities"],
     driver = "@llvm-project//mlir:run_lit.sh",
+    exclude = ["load-quantization-recipe.mlir"],
     tags_override = {
         "legalize-tf.mlir": ["no_rocm"],
         "optimize.mlir": ["no_rocm"],
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 4359d8e0f4a..38f736ee378 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -277,6 +277,34 @@ func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
   return %0#0 : tensor<? x i32>
 }
 
+// CHECK-LABEL: testMulNonQuantizedOperandsAndQuantizedResult
+func @testMulNonQuantizedOperandsAndQuantizedResult(tensor<? x f32>, tensor<? x f32>) -> tensor<? x !quant.any<i16:f32>> {
+^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
+  // CHECK: "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<? x f32>, tensor<? x f32>) -> tensor<? x !quant.any<i16:f32>>
+  return %0#0 : tensor<? x !quant.any<i16:f32>>
+}
+
+// -----
+
+func @testMulInvalidOperands(tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32> {
+^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x i32>):
+  // expected-error @+1 {{failed to verify that operands have same element type}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32>
+  return %0#0 : tensor<? x i32>
+}
+
+// -----
+
+func @testMulInvalidQuantizedOperands(tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>> {
+^bb0(%arg0: tensor<* x !quant.any<i16:f32>>, %arg1: tensor<* x !quant.any<i8:f32>>):
+  // expected-error @+1 {{failed to verify that operands have same element type}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>>
+  return %0#0 : tensor<* x !quant.any<i16:f32>>
+}
+
+// -----
+
 // CHECK-LABEL: testDiv
 func @testDiv(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
 ^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
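
Note on the relaxed constraint: the new Or<> branch of BinaryOpSameElementTypeConstraint compares only the operands' underlying storage types (via quant::QuantizedType::castToStorageType), so two quantized operands whose scales or zero points differ still verify, as long as both store the same integer type. A minimal sketch of such a case, not part of the patch; the function name and quantization parameters below are illustrative:

    // Both operands use i8 storage, so castToStorageType() yields i8 on each
    // side and the constraint passes even though the scales (0.1 vs. 0.25) differ.
    func @testMulSameStorageType(%arg0: tensor<4 x !quant.uniform<i8:f32, 0.1>>,
                                 %arg1: tensor<4 x !quant.uniform<i8:f32, 0.25>>)
        -> tensor<4 x !quant.uniform<i8:f32, 0.5>> {
      %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"}
          : (tensor<4 x !quant.uniform<i8:f32, 0.1>>, tensor<4 x !quant.uniform<i8:f32, 0.25>>)
          -> tensor<4 x !quant.uniform<i8:f32, 0.5>>
      return %0 : tensor<4 x !quant.uniform<i8:f32, 0.5>>
    }

By contrast, the testMulInvalidQuantizedOperands case above mixes i16 and i8 storage, so neither branch of the Or<> predicate holds and the verifier reports the "operands have same element type" failure.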