preserve the argument order when inserting the fake quant ops

Previously, the insertion order depended on pointer values, which introduced test flakiness.
This CL makes the order deterministic by using the order in which the ops are visited.

PiperOrigin-RevId: 256266368
This commit is contained in:
Feng Liu 2019-07-02 16:48:51 -07:00 committed by TensorFlower Gardener
parent 576d32c57d
commit aabcdcbdff
3 changed files with 16 additions and 13 deletions

View File

@ -141,11 +141,11 @@ func @QuantizeConcatResToAll(tensor<2xf32>, tensor<2xf32>) -> tensor<2x2x!quant.
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
// CHECK: %2 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
// CHECK: %4 = "tfl.concatenation"(%1, %3) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.concatenation"(%3, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK: %5 = "tfl.quantize"(%4) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: return %5 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}

View File

@ -138,9 +138,9 @@ func @QuantizeConcat(tensor<2xf32>, tensor<2xf32>) -> tensor<2x2x!quant.uniform<
%1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %1 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %1 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %2 = "tfl.concatenation"(%0, %1) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
// CHECK: %2 = "tfl.concatenation"(%1, %0) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %2 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}

View File

@ -267,7 +267,9 @@ class QuantizationDriver {
}
cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
if (is_argument) {
arg_states_[llvm::cast<BlockArgument>(in)] = cached.first->second;
auto *arg = llvm::cast<BlockArgument>(in);
arg_states_[arg] = cached.first->second;
args_.push_back(arg);
}
}
@ -299,11 +301,15 @@ class QuantizationDriver {
// the values from `operand_states_` and `result_state_`.
std::unordered_map<int, RequantizeState> rescale_states_;
// Maps of indexes to the propagation state vector from the ops results and
// op operands. Both maps are unmodified after initialization.
// Maps of indexes to the propagation state vector from the ops operands,
// results and arguments.
llvm::DenseMap<OpValue, int> operand_states_;
llvm::DenseMap<OpValue, int> result_states_;
llvm::DenseMap<BlockArgument *, int> arg_states_;
// This vector is to preserve the arguments order, so the newly inserted
// quantized ops for the arguments are deterministically ordered.
llvm::SmallVector<BlockArgument *, 4> args_;
};
#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc"
@ -656,10 +662,7 @@ bool QuantizationDriver::PropagateParams() {
}
void QuantizationDriver::Finalize() {
std::map<BlockArgument *, int> sorted_states(arg_states_.begin(),
arg_states_.end());
for (auto it : sorted_states) {
BlockArgument *arg = it.first;
for (auto *arg : args_) {
auto &state = GetArgQuantState(arg);
auto &requantize = GetArgRequantizeState(arg);
if (state.IsEmpty() ||