Fix the same element type trait of tfl.greater_equal op

PiperOrigin-RevId: 348483743
Change-Id: I34f65e5b9fd45c18642c3ebf3882b35929d43493
This commit is contained in:
Feng Liu 2020-12-21 09:59:10 -08:00 committed by TensorFlower Gardener
parent 3fb987e825
commit f6972931f2
2 changed files with 23 additions and 13 deletions

View File

@ -369,11 +369,8 @@ class TFL_TCopVTEtAreSameAt<int i, int j, int num=8> : Or<[
// TFL op common constraints.
//===----------------------------------------------------------------------===//
// This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
// Binary ops' lhs & rhs should have the same value type, and the constraint is
// capable of comparing quantization types as well.
def BinaryOpSameElementTypeConstraint :
PredOpTrait<"operands have same element type",
class OperandsSameElementTypeConstraintBase<string op> :
PredOpTrait<op # " operands have same element type",
Or<[
TCopVTEtIsSameAs<0, 1>,
// Two operands' values are both quantized and their type have the same
@ -386,6 +383,18 @@ def BinaryOpSameElementTypeConstraint :
"quant::QuantizedType::castToStorageType("
"getElementTypeOrSelf($_op.getOperand(1)))">]>]>>;
// This is a constraint for most of the binary ops, e.g., add, mul, div, etc.
// Binary ops' lhs & rhs should have the same value type, and the constraint is
// capable of comparing quantization types as well.
def BinaryOpSameElementTypeConstraint :
OperandsSameElementTypeConstraintBase<"binary op">;
// This is a constraint for most of the comparison ops, e.g., equal, not_equal,
// greater, greater_equal, less, etc. Comparison ops' lhs & rhs should have the
// same value type, and the constraint is capable of comparing quantization
// types as well.
def ComparisonOpSameElementTypeConstraint :
OperandsSameElementTypeConstraintBase<"comparison op">;
//===----------------------------------------------------------------------===//
// TFL common builders.
//===----------------------------------------------------------------------===//
@ -1100,7 +1109,7 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [
// Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait.
def TFL_LessEqualOp : TFL_Op<"less_equal", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
ComparisonOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect]> {
let summary = "Less_equal operator";
@ -1164,6 +1173,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
ResultsBroadcastableShape,
ComparisonOpSameElementTypeConstraint,
NoSideEffect]> {
let summary = "Greater_equal operator";
@ -1355,7 +1365,7 @@ larger than 0.
def TFL_NotEqualOp : TFL_Op<"not_equal", [
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
BinaryOpSameElementTypeConstraint,
ComparisonOpSameElementTypeConstraint,
ResultsBroadcastableShape,
Commutative,
NoSideEffect,
@ -1462,7 +1472,7 @@ def TFL_EqualOp: TFL_Op<"equal", [
NoSideEffect,
ResultsBroadcastableShape,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
ComparisonOpSameElementTypeConstraint]> {
let summary = "Equal operator";
let description = [{
@ -1669,7 +1679,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [
def TFL_GreaterOp : TFL_Op<"greater", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
ComparisonOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect]> {
let summary = "Greater operator";
@ -1768,7 +1778,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
def TFL_LessOp : TFL_Op<"less", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
ComparisonOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect]> {
let summary = "Less operator";

View File

@ -357,7 +357,7 @@ func @testMulNonQuantizedOperandsandQuantizedResult(tensor<? x f32>, tensor<? x
func @testMulInvalidOperands(tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x i32>):
// expected-error @+1 {{failed to verify that operands have same element type}}
// expected-error @+1 {{failed to verify that binary op operands have same element type}}
%0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
@ -366,7 +366,7 @@ func @testMulInvalidOperands(tensor<? x f32>, tensor<? x i32>) -> tensor<? x i32
func @testMulInvalidQuantizedOperands(tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>> {
^bb0(%arg0: tensor<* x !quant.any<i16:f32>>, %arg1: tensor<* x !quant.any<i8:f32>>):
// expected-error @+1 {{failed to verify that operands have same element type}}
// expected-error @+1 {{failed to verify that binary op operands have same element type}}
%0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<* x !quant.any<i16:f32>>, tensor<* x !quant.any<i8:f32>>) -> tensor<* x !quant.any<i16:f32>>
return %0#0 : tensor<* x !quant.any<i16:f32>>
}
@ -412,7 +412,7 @@ func @testFloorDivF32(tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32> {
// -----
func @testFloorDivF32(%arg0: tensor<2 x f32>, %arg1: tensor<2 x i32>) -> tensor<2 x f32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
// expected-error @+1 {{failed to verify that binary op operands have same element type}}
%0 = "tfl.floor_div"(%arg0, %arg1) : (tensor<2 x f32>, tensor<2 x i32>) -> tensor<2 x f32>
return %0#0 : tensor<2 x f32>
}