Relax the op traits of "tfl.squared_difference" and "tfl.rsqrt" to allow quantized types.
Quantized types are supported in the TFLite kernels. PiperOrigin-RevId: 345609162 Change-Id: Id529ee8e480a7312329658558587ea2403d35de6
This commit is contained in:
parent
b27a179f63
commit
fedf3f45fb
@ -349,7 +349,11 @@ class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
|
||||
"quant::QuantizedType::castToStorageType("
|
||||
"getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;
|
||||
|
||||
// This is a quantization-aware version of TCresVTEtIsSameAsOp
|
||||
def TFL_SameFirstOperandAndFirstResultElementType :
|
||||
PredOpTrait<"values and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>;
|
||||
|
||||
// This is a quantization-aware version of TCopVTEtAreSameAt
|
||||
class TFL_TCopVTEtAreSameAt<int i, int j> : Or<[
|
||||
TCopVTEtAreSameAt<[i, j]>,
|
||||
TFL_TFOperandTypesWithSameBits<i, j, 8>,
|
||||
@ -717,8 +721,7 @@ def TFL_CeilOp: TFL_Op<"ceil", [
|
||||
def TFL_ConcatenationOp : TFL_Op<"concatenation",
|
||||
[
|
||||
NoSideEffect,
|
||||
PredOpTrait<"values and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_SameFirstOperandAndFirstResultElementType,
|
||||
SameOperandsAndResultsScale
|
||||
]> {
|
||||
let summary = "Concatenation operator";
|
||||
@ -2277,8 +2280,7 @@ def TFL_NegOp: TFL_Op<"neg", [
|
||||
}
|
||||
|
||||
def TFL_PackOp : TFL_Op<"pack", [
|
||||
PredOpTrait<"values and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_SameFirstOperandAndFirstResultElementType,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Packs a list of tensors along a dimension into one tensor";
|
||||
@ -2630,7 +2632,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
|
||||
}
|
||||
|
||||
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
TFL_SameFirstOperandAndFirstResultElementType,
|
||||
SameOperandsAndResultShape,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Reciprocal of square root operator";
|
||||
@ -2639,9 +2641,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
|
||||
Computes element-wise reverse square root of input
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_FpTensor:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
|
||||
|
||||
let results = (outs TFL_FpTensor:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
@ -2919,11 +2921,10 @@ def TFL_SubOp : TFL_Op<"sub", [
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Expand the kernel implementation to support all types besides
|
||||
// I32 and F32.
|
||||
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
|
||||
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
|
||||
SameOperandsAndResultElementType,
|
||||
BinaryOpSameElementTypeConstraint,
|
||||
TFL_SameFirstOperandAndFirstResultElementType,
|
||||
ResultsBroadcastableShape,
|
||||
NoSideEffect,
|
||||
NoQuantizableResult]> {
|
||||
@ -2934,10 +2935,10 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32]>:$rhs);
|
||||
ins TFL_TensorOf<[F32, I32, QI8]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, QI8]>:$rhs);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, QI8]>:$output);
|
||||
|
||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||
|
||||
|
||||
@ -134,6 +134,12 @@ func @testRsqrt(tensor<? x f32>) -> tensor<? x f32> {
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRsqrtQuant
|
||||
func @testRsqrtQuant(%arg0: tensor<1x80x1x!quant.uniform<i8:f32, 0.048358432948589325:-128>>) -> tensor<1x80x1x!quant.uniform<i8:f32, 0.0066055487841367722:-128>> {
|
||||
%0 = "tfl.rsqrt"(%arg0) : (tensor<1x80x1x!quant.uniform<i8:f32, 0.048358432948589325:-128>>) -> tensor<1x80x1x!quant.uniform<i8:f32, 0.0066055487841367722:-128>>
|
||||
return %0 : tensor<1x80x1x!quant.uniform<i8:f32, 0.0066055487841367722:-128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSin
|
||||
func @testSin(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
@ -174,6 +180,14 @@ func @testSquareWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testSquaredDifferenceQuant
|
||||
func @testSquaredDifferenceQuant(%arg0: tensor<1x80x128x!quant.uniform<i8:f32, 0.089839041233062744:10>>, %arg1: tensor<1x80x128x!quant.uniform<i8:f32, 0.0019308560295030475:-6>>) -> tensor<1x80x128x!quant.uniform<i8:f32, 0.60070550441741943:-128>> {
|
||||
%0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x80x128x!quant.uniform<i8:f32, 0.089839041233062744:10>>, tensor<1x80x128x!quant.uniform<i8:f32, 0.0019308560295030475:-6>>) -> tensor<1x80x128x!quant.uniform<i8:f32, 0.60070550441741943:-128>>
|
||||
return %0 : tensor<1x80x128x!quant.uniform<i8:f32, 0.60070550441741943:-128>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testSqrt
|
||||
func @testSqrt(tensor<? x f32>) -> tensor<? x f32> {
|
||||
^bb0(%arg0: tensor<? x f32>):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user