Skip the input stats if there are emulation nodes for the inputs from training
We should unconditionally respect the emulation nodes from training. PiperOrigin-RevId: 293834268 Change-Id: Ie6b9948cf373071147a24b16ed845933a4e53044
This commit is contained in:
parent
f04ae00e19
commit
c879a09690
tensorflow/compiler/mlir/lite
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only" | FileCheck %s
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: quantize_float_placeholder_only
|
||||
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
|
||||
@ -11,6 +11,15 @@ func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>
|
||||
// CHECK-NEXT: %[[dq]], %arg1, %[[dq_0]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: not_reset_input
|
||||
func @not_reset_input(%arg0: tensor<f32>) -> (tensor<!quant.uniform<i16:f32, 1.0>>) {
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<!quant.uniform<i16:f32, 1.0>>} : (tensor<f32>) -> tensor<!quant.uniform<i16:f32, 1.0>>
|
||||
return %0: tensor<!quant.uniform<i16:f32, 1.0>>
|
||||
|
||||
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<!quant.uniform<i16:f32, 1.000000e+00>>}
|
||||
// CHECK-NEXT: return %[[q]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DequantizeAndQuantize
|
||||
func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {
|
||||
%cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>>
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
@ -144,6 +145,13 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
||||
int i) {
|
||||
if (auto shaped = input_type.dyn_cast<ShapedType>()) {
|
||||
if (shaped.getElementType().isa<FloatType>()) {
|
||||
// If there are existing quantize ops, they are from training and we
|
||||
// should respect them.
|
||||
if (arg.hasOneUse() &&
|
||||
llvm::isa<quant::QuantizeCastOp>(*arg.user_begin())) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto min_max = GetMinMaxValuesForArgument(func_name, i);
|
||||
TypeAttr params = quant::GetQuantizedTypeAttr(
|
||||
builder, input_type, builder.getF64FloatAttr(min_max.first),
|
||||
|
Loading…
Reference in New Issue
Block a user