diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 06c23df0ae4..1a46acdcbe0 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -141,11 +141,11 @@ func @QuantizeConcatResToAll(tensor<2xf32>, tensor<2xf32>) -> tensor<2x2x!quant. %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %1 : tensor<2x2x!quant.uniform> -// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> +// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<2x!quant.uniform>) -> tensor<2xf32> -// CHECK: %2 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> +// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> // CHECK: %3 = "tfl.dequantize"(%2) : (tensor<2x!quant.uniform>) -> tensor<2xf32> -// CHECK: %4 = "tfl.concatenation"(%1, %3) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32> +// CHECK: %4 = "tfl.concatenation"(%3, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32> // CHECK: %5 = "tfl.quantize"(%4) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> // CHECK: return %5 : tensor<2x2x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 752c03ee6a9..ea88c10b6bf 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -138,9 +138,9 @@ func @QuantizeConcat(tensor<2xf32>, tensor<2xf32>) -> tensor<2x2x!quant.uniform< %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %1 : tensor<2x2x!quant.uniform> -// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} -// CHECK: %1 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform>} -// CHECK: %2 = "tfl.concatenation"(%0, %1) {axis = 0 : i32, fused_activation_function = "NONE"} +// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform>} +// CHECK: %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform>} +// CHECK: %2 = "tfl.concatenation"(%1, %0) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: return %2 : tensor<2x2x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc b/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc index 165a157aecd..7ceb0f5c86e 100644 --- a/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc @@ -267,7 +267,9 @@ class QuantizationDriver { } cached.first->second = InitializeState(op, index, in, /*as_result=*/false); if (is_argument) { - arg_states_[llvm::cast(in)] = cached.first->second; + auto *arg = llvm::cast(in); + arg_states_[arg] = cached.first->second; + args_.push_back(arg); } } @@ -299,11 +301,15 @@ class QuantizationDriver { // the values from `operand_states_` and `result_state_`. std::unordered_map rescale_states_; - // Maps of indexes to the propagation state vector from the ops results and - // op operands. Both maps are unmodified after initialization. + // Maps of indexes to the propagation state vector from the ops operands, + // results and arguments. llvm::DenseMap operand_states_; llvm::DenseMap result_states_; llvm::DenseMap arg_states_; + + // This vector is to preserve the arguments order, so the newly inserted + // quantized ops for the arguments are deterministically ordered. + llvm::SmallVector args_; }; #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" @@ -656,10 +662,7 @@ bool QuantizationDriver::PropagateParams() { } void QuantizationDriver::Finalize() { - std::map sorted_states(arg_states_.begin(), - arg_states_.end()); - for (auto it : sorted_states) { - BlockArgument *arg = it.first; + for (auto *arg : args_) { auto &state = GetArgQuantState(arg); auto &requantize = GetArgRequantizeState(arg); if (state.IsEmpty() ||