Relax the op traits of "tfl.squared_difference" and "tfl.rsqrt" to allow quantized types.

Quantized types are supported in the TFLite kernels.

PiperOrigin-RevId: 345609162
Change-Id: Id529ee8e480a7312329658558587ea2403d35de6
This commit is contained in:
Jing Pu 2020-12-03 22:34:31 -08:00 committed by TensorFlower Gardener
parent b27a179f63
commit fedf3f45fb
2 changed files with 29 additions and 14 deletions

View File

@ -349,7 +349,11 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
"quant::QuantizedType::castToStorageType("
"getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
def TFL_SameFirstOperandAndFirstResultElementType :
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>;
// This is a quantization-aware version of TCopVTEtAreSameAt
class TFL_TCopVTEtAreSameAt<int i, int j> : Or<[
TCopVTEtAreSameAt<[i, j]>,
TFL_TFOperandTypesWithSameBits<i, j, 8>,
@ -717,8 +721,7 @@ def TFL_CeilOp: TFL_Op<"ceil", [
def TFL_ConcatenationOp : TFL_Op<"concatenation",
[
NoSideEffect,
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_SameFirstOperandAndFirstResultElementType,
SameOperandsAndResultsScale
]> {
let summary = "Concatenation operator";
@ -2277,8 +2280,7 @@ def TFL_NegOp: TFL_Op<"neg", [
}
def TFL_PackOp : TFL_Op<"pack", [
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_SameFirstOperandAndFirstResultElementType,
NoSideEffect,
SameOperandsAndResultsScale]> {
let summary = "Packs a list of tensors along a dimension into one tensor";
@ -2630,7 +2632,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
}
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
SameOperandsAndResultType,
TFL_SameFirstOperandAndFirstResultElementType,
SameOperandsAndResultShape,
NoQuantizableResult]> {
let summary = "Reciprocal of square root operator";
@ -2639,9 +2641,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
Computes element-wise reverse square root of input
}];
let arguments = (ins TFL_FpTensor:$x);
let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
let results = (outs TFL_FpTensor:$y);
let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
let hasFolder = 1;
}
@ -2919,11 +2921,10 @@ def TFL_SubOp : TFL_Op<"sub", [
let hasOptions = 1;
}
// TODO(jpienaar): Expand the kernel implementation to support all types besides
// I32 and F32.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
SameOperandsAndResultElementType,
BinaryOpSameElementTypeConstraint,
TFL_SameFirstOperandAndFirstResultElementType,
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult]> {
@ -2934,10 +2935,10 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
}];
let arguments = (
ins TFL_TensorOf<[F32, I32]>:$lhs,
TFL_TensorOf<[F32, I32]>:$rhs);
ins TFL_TensorOf<[F32, I32, QI8]>:$lhs,
TFL_TensorOf<[F32, I32, QI8]>:$rhs);
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, QI8]>:$output);
let builders = [TFL_BroadcastableBinaryBuilder];

View File

@ -134,6 +134,12 @@ func @testRsqrt(tensor<? x f32>) -> tensor<? x f32> {
return %0 : tensor<? x f32>
}
// CHECK-LABEL: testRsqrtQuant
func @testRsqrtQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.048358432948589325:-128>>) -> tensor<1x80x1x!quant.uniform<i8:f32, 0.0066055487841367722:-128>> {
%0 = "tfl.rsqrt"(%arg0) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.048358432948589325:-128>>) -> tensor<1x80x1x!quant.uniform<i8:f32, 0.0066055487841367722:-128>>
return %0 : tensor<1x80x1x!quant.uniform<i8:f32, 0.0066055487841367722:-128>>
}
// CHECK-LABEL: testSin
func @testSin(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
@ -174,6 +180,14 @@ func @testSquareWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
// -----
// CHECK-LABEL: testSquaredDifferenceQuant
func @testSquaredDifferenceQuant(%arg0: tensor<1x80x128x!quant.uniform<i8:f32, 0.089839041233062744:10>>, %arg1: tensor<1x80x128x!quant.uniform<i8:f32, 0.0019308560295030475:-6>>) -> tensor<1x80x128x!quant.uniform<i8:f32, 0.60070550441741943:-128>> {
%0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x80x128x!quant.uniform<i8:f32, 0.089839041233062744:10>>, tensor<1x80x128x!quant.uniform<i8:f32, 0.0019308560295030475:-6>>) -> tensor<1x80x128x!quant.uniform<i8:f32, 0.60070550441741943:-128>>
return %0 : tensor<1x80x128x!quant.uniform<i8:f32, 0.60070550441741943:-128>>
}
// -----
// CHECK-LABEL: testSqrt
func @testSqrt(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):