Add sanity checks to TFLite's MulOp
It also includes the following changes:
- Provides better quantization type checking in BinaryOpSameElementTypeConstraint and TFL_TCresVTEtIsSameAsOp.
- Disables the load-quantization-recipe test.

PiperOrigin-RevId: 308758999
Change-Id: Ic3564ee9caa15ca2ed3f064b24bb8d2e02de3eb5
parent 7b357dd06b
commit 428cdeda09
@@ -299,14 +299,27 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
             "getElementTypeOrSelf($_op.getResult(" # i # "))) == "
           "quant::QuantizedType::castToStorageType("
             "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;
 
 //===----------------------------------------------------------------------===//
 // TFL op common constraints.
 //===----------------------------------------------------------------------===//
 
 // This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
-// Binary ops lhs & rhs should have the same value type.
+// Binary ops lhs & rhs should have the same value type, and are capable of
+// comparing quantization types as well.
 def BinaryOpSameElementTypeConstraint :
-  PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<0, 1>>;
+  PredOpTrait<"operands have same element type",
+    Or<[
+      TCopVTEtIsSameAs<0, 1>,
+      // Two operands' values are both quantized and their types have the same
+      // underlying storage type.
+      And<[
+        SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(0))",
+          quant_QuantizedType.predicate>,
+        CPred<"quant::QuantizedType::castToStorageType("
+                "getElementTypeOrSelf($_op.getOperand(0))) == "
+              "quant::QuantizedType::castToStorageType("
+                "getElementTypeOrSelf($_op.getOperand(1)))">]>]>>;
+
 //===----------------------------------------------------------------------===//
 // TFL common builders.
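The new Or<[...]> predicate accepts either identical element types or two quantized operand types that share the same storage type. As a hedged illustration only (not part of this change; the function name and quantization parameters below are made up), two i8-quantized operands with different scales would still satisfy the constraint, since both cast to the same i8 storage type:

// Hypothetical example, not from the commit: matching i8 storage types satisfy
// the "operands have same element type" check even with different scales.
func @testMulSameStorageType(tensor<4 x !quant.uniform<i8:f32, 0.1>>, tensor<4 x !quant.uniform<i8:f32, 0.25:3>>) -> tensor<4 x !quant.uniform<i8:f32, 0.5>> {
^bb0(%arg0: tensor<4 x !quant.uniform<i8:f32, 0.1>>, %arg1: tensor<4 x !quant.uniform<i8:f32, 0.25:3>>):
  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<4 x !quant.uniform<i8:f32, 0.1>>, tensor<4 x !quant.uniform<i8:f32, 0.25:3>>) -> tensor<4 x !quant.uniform<i8:f32, 0.5>>
  return %0 : tensor<4 x !quant.uniform<i8:f32, 0.5>>
}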
@@ -1937,6 +1950,8 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
 def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape,
                                NoSideEffect,
                                Commutative,
+                               BinaryOpSameElementTypeConstraint,
+                               TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
                                TFL_GpuTargetOp]> {
   let summary = "Multiplication operator";
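TFL_MulOp now also carries TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, which requires operands 0 and 1 to have identical or broadcastable shapes (the trailing 5 reads as a bound on the rank supported for broadcasting). A hedged sketch of a case that should still verify, with a made-up function name:

// Hypothetical example, not from the commit: [1, 8] broadcasts against [4, 8].
func @testMulBroadcastableShapes(tensor<1 x 8 x f32>, tensor<4 x 8 x f32>) -> tensor<4 x 8 x f32> {
^bb0(%arg0: tensor<1 x 8 x f32>, %arg1: tensor<4 x 8 x f32>):
  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1 x 8 x f32>, tensor<4 x 8 x f32>) -> tensor<4 x 8 x f32>
  return %0 : tensor<4 x 8 x f32>
}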
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
 glob_lit_tests(
     data = [":test_utilities"],
     driver = "@llvm-project//mlir:run_lit.sh",
+    exclude = ["load-quantization-recipe.mlir"],
     tags_override = {
         "legalize-tf.mlir": ["no_rocm"],
         "optimize.mlir": ["no_rocm"],
@@ -277,6 +277,34 @@ func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
   return %0#0 : tensor<? x i32>
 }
 
+// CHECK-LABEL: testMulNonQuantizedOperandsandQuantizedResult
+func @testMulNonQuantizedOperandsandQuantizedResult(tensor<? x f32>, tensor<? x f32>) -> tensor<? x !quant.any<i16:f32>> {
+^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
+  // CHECK: "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<? x f32>, tensor<? x f32>) -> tensor<? x !quant.any<i16:f32>>
+  return %0#0 : tensor<? x !quant.any<i16:f32>>
+}
+
+// -----
+
+func @testMulInvalidOperands(tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32> {
+^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x i32>):
+  // expected-error @+1 {{failed to verify that operands have same element type}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32>
+  return %0#0 : tensor<? x i32>
+}
+
+// -----
+
+func @testMulInvalidQuantizedOperands(tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>> {
+^bb0(%arg0: tensor<* x !quant.any<i16:f32>>, %arg1: tensor<* x !quant.any<i8:f32>>):
+  // expected-error @+1 {{failed to verify that operands have same element type}}
+  %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>>
+  return %0#0 : tensor<* x !quant.any<i16:f32>>
+}
+
+// -----
+
 // CHECK-LABEL: testDiv
 func @testDiv(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
 ^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):