Import hlo module and run canonicalization passes before quantization

PiperOrigin-RevId: 301305074
Change-Id: I807753ce6964649f74ccb3f5cd7c61b99ff18638
Feng Liu 2020-03-16 22:10:54 -07:00 committed by TensorFlower Gardener
parent 666f21add8
commit 489126360d
6 changed files with 312 additions and 1 deletion

tensorflow/compiler/mlir/lite/quantization/xla/BUILD

@@ -50,6 +50,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:mlir_tf2xla",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@@ -59,5 +60,8 @@ cc_library(
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/core/platform:status",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Transforms",
    ],
)

tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc

@@ -14,6 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -23,6 +31,30 @@ namespace xla_hlo {
// Quantizes the model in the computation. // Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) { xla::XlaComputation* computation) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
                      computation->Snapshot());
  MLIRContext context;
  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
  // Import the serialized HLO proto from the snapshot into an MLIR module.
  auto status = xla::ConvertHloToMlirHlo(
      module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
  if (!status.ok()) {
    LOG(ERROR) << "Hlo module import failed: " << status;
    return status;
  }

  // Run cleanup passes over the imported module before quantization.
  PassManager pm(&context);
  pm.addPass(createCanonicalizerPass());
  pm.addPass(createInlinerPass());
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(createCSEPass());
  mlir::StatusScopedDiagnosticHandler diag_handler(&context);
  LogicalResult result = pm.run(module.get());
  (void)result;

  module->dump();

  return tensorflow::Status::OK();
}
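For orientation, a minimal sketch of how this entry point might be driven. The caller name is hypothetical, the config and computation are assumed to be produced elsewhere (e.g., by the tf2xla bridge), and the mlir::xla_hlo qualification follows the namespace visible in the hunk above; only XlaQuantize and its header come from this change.

#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"

// Hypothetical caller: imports the computation's HLO snapshot into MLIR,
// runs the canonicalization pipeline above, and (for now) dumps the module.
tensorflow::Status RunXlaQuantize(const tensorflow::tf2xla::Config& config,
                                  xla::XlaComputation* computation) {
  return mlir::xla_hlo::XlaQuantize(config, computation);
}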

tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD

@@ -3,8 +3,14 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
    data = [
        ":graph_config_files",
        ":test_utilities",
    ],
    driver = "@llvm-project//mlir:run_lit.sh",
    tags_override = {
        "fadd_quant.mlir": ["no_oss"],  # TODO(b/150957738): to be fixed on oss.
    },
    test_file_exts = ["mlir"],
)
@@ -13,7 +19,17 @@ filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/aot:tfcompile",
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
    ],
)
# Bundle together all the graph files that are used by the tests.
filegroup(
    name = "graph_config_files",
    srcs = glob(
        ["**/*.pbtxt"],
    ),
)

tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir

@@ -0,0 +1,15 @@
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
# TODO(fengliuai): update this file with the progress of the implementation
// CHECK: func @main
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
// CHECK: return %4 : tuple<tensor<2x4xf32>>
// CHECK: }
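For reference, a minimal sketch of the arithmetic that the fake_quant_with_min_max_vars custom call stands for, assuming standard FakeQuantWithMinMaxVars semantics with the test's min=0, max=127, num_bits=8; the real op also nudges min/max so that zero is exactly representable, which is omitted here for brevity.

#include <algorithm>
#include <cmath>

// Simplified fake quantization: clamp to [min, max], snap to one of the
// 2^num_bits quantization steps, and map back to float.
float FakeQuant(float x, float min = 0.0f, float max = 127.0f,
                int num_bits = 8, bool narrow_range = false) {
  const float quant_min = narrow_range ? 1.0f : 0.0f;
  const float quant_max = static_cast<float>((1 << num_bits) - 1);
  const float scale = (max - min) / (quant_max - quant_min);
  const float clamped = std::min(std::max(x, min), max);
  return std::round((clamped - min) / scale) * scale + min;
}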

tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.config.pbtxt

@@ -0,0 +1,26 @@
feed {
  id { node_name: "input0" }
  shape {
    dim { size: 2 }
    dim { size: 4 }
  }
}
feed {
  id { node_name: "input1" }
  shape {
    dim { size: 2 }
    dim { size: 4 }
  }
}
fetch {
  id { node_name: "Add/FakeQuantWithMinMaxVars" }
  shape {
    dim { size: 2 }
    dim { size: 4 }
  }
}
conversion_options {
  custom_fake_quant_op_calls: true
}
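The same feeds, fetch, and conversion option can also be assembled programmatically. A hedged sketch, assuming the generated proto header path tensorflow/compiler/tf2xla/tf2xla.pb.h (the fetch shape is omitted for brevity); the field names mirror the pbtxt above.

#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"

// Builds a Config equivalent to the pbtxt above: two 2x4 float feeds, one
// fetch, and FakeQuant ops kept as custom calls for the quantizer to match.
tensorflow::tf2xla::Config MakeFaddQuantConfig() {
  tensorflow::tf2xla::Config config;
  for (const char* name : {"input0", "input1"}) {
    auto* feed = config.add_feed();
    feed->mutable_id()->set_node_name(name);
    auto* shape = feed->mutable_shape();
    shape->add_dim()->set_size(2);
    shape->add_dim()->set_size(4);
  }
  config.add_fetch()->mutable_id()->set_node_name(
      "Add/FakeQuantWithMinMaxVars");
  config.mutable_conversion_options()->set_custom_fake_quant_op_calls(true);
  return config;
}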

tensorflow/compiler/mlir/lite/quantization/xla/tests/fadd_quant.mlir.pbtxt

@@ -0,0 +1,218 @@
node: {
  name: "Add/FakeQuantWithMinMaxVars"
  op: "FakeQuantWithMinMaxVars"
  input: "Add"
  input: "Add/FakeQuantWithMinMaxVars/min"
  input: "Add/FakeQuantWithMinMaxVars/max"
  attr: {
    key: "num_bits"
    value: {
      i: 8
    }
  }
  attr: {
    key: "narrow_range"
    value: {
      b: false
    }
  }
}
node: {
  name: "Add/FakeQuantWithMinMaxVars/min"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 0.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "Add/FakeQuantWithMinMaxVars/max"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 127.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Add"
  op: "Add"
  input: "input0/FakeQuantWithMinMaxVars"
  input: "input1/FakeQuantWithMinMaxVars"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input0/FakeQuantWithMinMaxVars"
  op: "FakeQuantWithMinMaxVars"
  input: "input0"
  input: "input0/FakeQuantWithMinMaxVars/min"
  input: "input0/FakeQuantWithMinMaxVars/max"
  attr: {
    key: "num_bits"
    value: {
      i: 8
    }
  }
  attr: {
    key: "narrow_range"
    value: {
      b: false
    }
  }
}
node: {
  name: "input0/FakeQuantWithMinMaxVars/min"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 0.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input0/FakeQuantWithMinMaxVars/max"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 127.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node {
  name: "input0"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input1/FakeQuantWithMinMaxVars"
  op: "FakeQuantWithMinMaxVars"
  input: "input1"
  input: "input1/FakeQuantWithMinMaxVars/min"
  input: "input1/FakeQuantWithMinMaxVars/max"
  attr: {
    key: "num_bits"
    value: {
      i: 8
    }
  }
  attr: {
    key: "narrow_range"
    value: {
      b: false
    }
  }
}
node: {
  name: "input1/FakeQuantWithMinMaxVars/min"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 0.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node: {
  name: "input1/FakeQuantWithMinMaxVars/max"
  op: "Const"
  attr: {
    key: "value"
    value: {
      tensor: {
        dtype: DT_FLOAT
        tensor_shape: {
        }
        float_val: 127.0
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_FLOAT
    }
  }
}
node {
  name: "input1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
}