diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 2861c14b32b..39149583918 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -510,6 +510,12 @@ Operation* BuildVariableOp(const tflite::TensorT& tensor, return op.getOperation(); } auto op = builder.create(loc, value); + if (!tensor.quantization->min.empty()) { + if (auto stats_op = + ConvertMinMaxToStatsOp(tensor, builder, op.getResult())) { + return stats_op; + } + } return op.getOperation(); } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json index ca6c278722d..03d8de6f175 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json @@ -1,10 +1,14 @@ // RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s -// CHECK: effective_hidden_scale_intermediate = tensor>> -// CHECK: input_to_cell_intermediate = tensor>> -// CHECK: input_to_forget_intermediate = tensor>> -// CHECK: input_to_input_intermediate = tensor>> -// CHECK: input_to_output_intermediate = tensor>> +// CHECK-DAG: %[[input_18:.*]] = "quant.stats"({{.*}}) {layerStats = dense<[-8.000000e-01, 1.600000e+00]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK-DAG: %[[input_19:.*]] = "quant.stats"({{.*}}) {layerStats = dense<[-2.000000e+00, 4.000000e+00]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + +// CHECK: "tfl.unidirectional_sequence_lstm"({{.*}}, %[[input_18]], %[[input_19]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) +// CHECK-SAME: effective_hidden_scale_intermediate = tensor>> +// CHECK-SAME: input_to_cell_intermediate = tensor>> +// CHECK-SAME: input_to_forget_intermediate = tensor>> +// CHECK-SAME: input_to_input_intermediate = tensor>> +// CHECK-SAME: input_to_output_intermediate = tensor>> { "version": 3, @@ -110,8 +114,8 @@ "name": "input_activation_state18", "is_variable": true, "quantization": { - "min": [-0.9], - "max": [0.9] + "min": [-0.8], + "max": [1.6] } }, { @@ -119,8 +123,8 @@ "name": "input_cell_state19", "is_variable": true, "quantization": { - "min": [-0.8], - "max": [0.8] + "min": [-2.0], + "max": [4.0] } }, {