Import hlo module and run canonicalization passes before quantization
PiperOrigin-RevId: 301305074 Change-Id: I807753ce6964649f74ccb3f5cd7c61b99ff18638
This commit is contained in:
parent
666f21add8
commit
489126360d
@ -50,6 +50,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
@ -59,5 +60,8 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
@ -14,6 +14,14 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
|
||||
@ -23,6 +31,30 @@ namespace xla_hlo {
|
||||
// Quantizes the model in the computation.
|
||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
||||
xla::XlaComputation* computation) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
|
||||
computation->Snapshot());
|
||||
|
||||
MLIRContext context;
|
||||
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
|
||||
auto status = xla::ConvertHloToMlirHlo(
|
||||
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Hlo module import failed: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
PassManager pm(&context);
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createInlinerPass());
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
|
||||
LogicalResult result = pm.run(module.get());
|
||||
(void)result;
|
||||
|
||||
module->dump();
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
|
@ -3,8 +3,14 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
data = [
|
||||
":graph_config_files",
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||
},
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
@ -13,7 +19,17 @@ filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/aot:tfcompile",
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all the graph files that are used by the tests.
|
||||
filegroup(
|
||||
name = "graph_config_files",
|
||||
srcs = glob(
|
||||
["**/*.pbtxt"],
|
||||
),
|
||||
)
|
||||
|
@ -0,0 +1,15 @@
|
||||
# RUN: not tfcompile --graph=%s.pbtxt --config=%s.config.pbtxt --quantize --cpp_class="::test::fadd_quant" 2>&1 | FileCheck %s -dump-input-on-failure
|
||||
|
||||
# TODO(fengliuai): update this file with the progress of the implementation
|
||||
// CHECK: func @main
|
||||
// CHECK: %cst = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<1.270000e+02> : tensor<f32>
|
||||
// CHECK: %cst_1 = constant dense<8> : tensor<i32>
|
||||
// CHECK: %cst_2 = constant dense<false> : tensor<i1>
|
||||
// CHECK: %0 = "xla_hlo.custom_call"(%arg0, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.9"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %1 = "xla_hlo.custom_call"(%arg1, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.14"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %2 = xla_hlo.add %0, %1 {name = "add.15"} : tensor<2x4xf32>
|
||||
// CHECK: %3 = "xla_hlo.custom_call"(%2, %cst, %cst_0, %cst_1, %cst_2) {backend_config = "", call_target_name = "fake_quant_with_min_max_vars", has_side_effect = false, name = "custom-call.20"} : (tensor<2x4xf32>, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i1>) -> tensor<2x4xf32>
|
||||
// CHECK: %4 = "xla_hlo.tuple"(%3) {name = "tuple.22"} : (tensor<2x4xf32>) -> tuple<tensor<2x4xf32>>
|
||||
// CHECK: return %4 : tuple<tensor<2x4xf32>>
|
||||
// CHECK: }
|
@ -0,0 +1,26 @@
|
||||
feed {
|
||||
id { node_name: "input0" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "input1" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
|
||||
fetch {
|
||||
id { node_name: "Add/FakeQuantWithMinMaxVars" }
|
||||
shape {
|
||||
dim { size: 2 }
|
||||
dim { size: 4 }
|
||||
}
|
||||
}
|
||||
|
||||
conversion_options {
|
||||
custom_fake_quant_op_calls: true
|
||||
}
|
@ -0,0 +1,218 @@
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "Add"
|
||||
input: "Add/FakeQuantWithMinMaxVars/min"
|
||||
input: "Add/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "Add/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Add"
|
||||
op: "Add"
|
||||
input: "input0/FakeQuantWithMinMaxVars"
|
||||
input: "input1/FakeQuantWithMinMaxVars"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "input0"
|
||||
input: "input0/FakeQuantWithMinMaxVars/min"
|
||||
input: "input0/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input0/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input0"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars"
|
||||
op: "FakeQuantWithMinMaxVars"
|
||||
input: "input1"
|
||||
input: "input1/FakeQuantWithMinMaxVars/min"
|
||||
input: "input1/FakeQuantWithMinMaxVars/max"
|
||||
attr: {
|
||||
key: "num_bits"
|
||||
value: {
|
||||
i: 8
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "narrow_range"
|
||||
value: {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars/min"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node: {
|
||||
name: "input1/FakeQuantWithMinMaxVars/max"
|
||||
op: "Const"
|
||||
attr: {
|
||||
key: "value"
|
||||
value: {
|
||||
tensor: {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape: {
|
||||
}
|
||||
float_val: 127.0
|
||||
}
|
||||
}
|
||||
}
|
||||
attr: {
|
||||
key: "dtype"
|
||||
value: {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "input1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
Loading…
Reference in New Issue
Block a user