Skip setting the input stats if there are quantize-emulation nodes for the inputs from training

We should unconditionally respect the emulation nodes from training.

PiperOrigin-RevId: 293834268
Change-Id: Ie6b9948cf373071147a24b16ed845933a4e53044
This commit is contained in:
Feng Liu 2020-02-07 09:55:20 -08:00 committed by TensorFlower Gardener
parent f04ae00e19
commit c879a09690
2 changed files with 18 additions and 1 deletions
tensorflow/compiler/mlir/lite

View File

@@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only" | FileCheck %s
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
// CHECK-LABEL: quantize_float_placeholder_only
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
@@ -11,6 +11,15 @@ func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>
// CHECK-NEXT: %[[dq]], %arg1, %[[dq_0]]
}
// CHECK-LABEL: not_reset_input
// Regression test for the prepare-quantize pass: when a function input already
// feeds a tfl.quantize op (an emulation node inserted during training), the
// pass must respect that existing quantization and must NOT reset the input's
// quantization parameters from the test-quantize-whitelist min/max values.
// The CHECK lines verify the original i16 quantize op is preserved unchanged.
func @not_reset_input(%arg0: tensor<f32>) -> (tensor<!quant.uniform<i16:f32, 1.0>>) {
%0 = "tfl.quantize"(%arg0) {qtype = tensor<!quant.uniform<i16:f32, 1.0>>} : (tensor<f32>) -> tensor<!quant.uniform<i16:f32, 1.0>>
return %0: tensor<!quant.uniform<i16:f32, 1.0>>
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<!quant.uniform<i16:f32, 1.000000e+00>>}
// CHECK-NEXT: return %[[q]]
}
// CHECK-LABEL: DequantizeAndQuantize
func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
%cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
@@ -144,6 +145,13 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
int i) {
if (auto shaped = input_type.dyn_cast<ShapedType>()) {
if (shaped.getElementType().isa<FloatType>()) {
// If there are existing quantize ops, they are from training and we
// should respect them.
if (arg.hasOneUse() &&
llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
return;
}
auto min_max = GetMinMaxValuesForArgument(func_name, i);
TypeAttr params = quant::GetQuantizedTypeAttr(
builder, input_type, builder.getF64FloatAttr(min_max.first),