From fedf3f45fba7d7f0998f7da4177d949ae162838b Mon Sep 17 00:00:00 2001 From: Jing Pu Date: Thu, 3 Dec 2020 22:34:31 -0800 Subject: [PATCH] Relax the op traits of "tfl.squared_difference" and "tfl.rsqrt" to allow quantized types. Quantized types are supported in the TFLite kernels. PiperOrigin-RevId: 345609162 Change-Id: Id529ee8e480a7312329658558587ea2403d35de6 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 29 ++++++++++---------- tensorflow/compiler/mlir/lite/tests/ops.mlir | 14 ++++++++++ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 7e653d32519..4ca957c173e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -349,7 +349,11 @@ class TFL_TCresVTEtIsSameAsOp : And<[ "quant::QuantizedType::castToStorageType(" "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>; -// This is a quantization-aware version of TCresVTEtIsSameAsOp +def TFL_SameFirstOperandAndFirstResultElementType : + PredOpTrait<"values and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>; + +// This is a quantization-aware version of TCopVTEtAreSameAt class TFL_TCopVTEtAreSameAt : Or<[ TCopVTEtAreSameAt<[i, j]>, TFL_TFOperandTypesWithSameBits, @@ -717,8 +721,7 @@ def TFL_CeilOp: TFL_Op<"ceil", [ def TFL_ConcatenationOp : TFL_Op<"concatenation", [ NoSideEffect, - PredOpTrait<"values and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_SameFirstOperandAndFirstResultElementType, SameOperandsAndResultsScale ]> { let summary = "Concatenation operator"; @@ -2277,8 +2280,7 @@ def TFL_NegOp: TFL_Op<"neg", [ } def TFL_PackOp : TFL_Op<"pack", [ - PredOpTrait<"values and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_SameFirstOperandAndFirstResultElementType, NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Packs a list of tensors along a dimension into one tensor"; @@ -2630,7 +2632,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension } def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, - SameOperandsAndResultType, + TFL_SameFirstOperandAndFirstResultElementType, SameOperandsAndResultShape, NoQuantizableResult]> { let summary = "Reciprocal of square root operator"; @@ -2639,9 +2641,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, Computes element-wise reverse square root of input }]; - let arguments = (ins TFL_FpTensor:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x); - let results = (outs TFL_FpTensor:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y); let hasFolder = 1; } @@ -2919,11 +2921,10 @@ def TFL_SubOp : TFL_Op<"sub", [ let hasOptions = 1; } -// TODO(jpienaar): Expand the kernel implementation to support all types besides -// I32 and F32. def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [ TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, - SameOperandsAndResultElementType, + BinaryOpSameElementTypeConstraint, + TFL_SameFirstOperandAndFirstResultElementType, ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { @@ -2934,10 +2935,10 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I32]>:$lhs, - TFL_TensorOf<[F32, I32]>:$rhs); + ins TFL_TensorOf<[F32, I32, QI8]>:$lhs, + TFL_TensorOf<[F32, I32, QI8]>:$rhs); - let results = (outs TFL_TensorOf<[F32, I32]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, QI8]>:$output); let builders = [TFL_BroadcastableBinaryBuilder]; diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 82ab9c779c1..07193955404 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -134,6 +134,12 @@ func @testRsqrt(tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: testRsqrtQuant +func @testRsqrtQuant(%arg0: tensor<1x80x1x!quant.uniform>) -> tensor<1x80x1x!quant.uniform> { + %0 = "tfl.rsqrt"(%arg0) : (tensor<1x80x1x!quant.uniform>) -> tensor<1x80x1x!quant.uniform> + return %0 : tensor<1x80x1x!quant.uniform> +} + // CHECK-LABEL: testSin func @testSin(tensor) -> tensor { ^bb0(%arg0: tensor): @@ -174,6 +180,14 @@ func @testSquareWithWrongInputType(tensor) -> tensor { // ----- +// CHECK-LABEL: testSquaredDifferenceQuant +func @testSquaredDifferenceQuant(%arg0: tensor<1x80x128x!quant.uniform>, %arg1: tensor<1x80x128x!quant.uniform>) -> tensor<1x80x128x!quant.uniform> { + %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x80x128x!quant.uniform>, tensor<1x80x128x!quant.uniform>) -> tensor<1x80x128x!quant.uniform> + return %0 : tensor<1x80x128x!quant.uniform> +} + +// ----- + // CHECK-LABEL: testSqrt func @testSqrt(tensor) -> tensor { ^bb0(%arg0: tensor):