diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 7f1590ff75d..fd4ae10595b 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -16,6 +16,12 @@ cc_library( deps = ["//tensorflow/core:test_main"], ) +filegroup( + name = "quantize_header", + srcs = ["quantize.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "tfcompile_lib", srcs = [ @@ -27,6 +33,7 @@ cc_library( "codegen.h", "compile.h", "flags.h", + "quantize.h", ], defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]), visibility = ["//tensorflow/python:__pkg__"], @@ -37,7 +44,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/mlir/lite/quantization/xla:quantize", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:mlir_tf2xla", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index f83cd45f9f3..a2cba5cdf9e 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm-c/Target.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/flags.h" -#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" +#include "tensorflow/compiler/aot/quantize.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -46,6 +46,14 @@ limitations under the License. namespace tensorflow { namespace tfcompile { +static llvm::ManagedStatic quantize_xla; + +bool RegisterQuantizeFn(const QuantizeXlaFn& fn) { + if (*quantize_xla) return false; + *quantize_xla = fn; + return true; +} + namespace { // Compiles the XLA computation into executable code. @@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, } else { return errors::Unknown("Unknown mlir_components ", flags.mlir_components); } - if (flags.experimental_quantize) { - TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation)); + + if (flags.experimental_quantize && *quantize_xla) { + TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation)); } + if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h b/tensorflow/compiler/aot/quantize.h similarity index 55% rename from tensorflow/compiler/mlir/lite/quantization/xla/quantize.h rename to tensorflow/compiler/aot/quantize.h index 2ec5dbb02ce..add05bd0422 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h +++ b/tensorflow/compiler/aot/quantize.h @@ -13,21 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ +#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ + +#include +#include +#include #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -namespace mlir { -namespace xla_hlo { +namespace tensorflow { +namespace tfcompile { -// Quantizes the model in the computation. -tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, - xla::XlaComputation* computation); +using QuantizeXlaFn = std::function; -} // namespace xla_hlo -} // namespace mlir +// Set the static quantization function to the `fn` if it hasn't been set. +// Return false if the static function has been set. +bool RegisterQuantizeFn(const QuantizeXlaFn& fn); -#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_ +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD index 74e81d7f291..791ec0aa93e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD @@ -14,7 +14,7 @@ package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], packages = [ - "//tensorflow/compiler/aot/...", + "//learning/brain/experimental/mlir/quantization/...", "//tensorflow/compiler/mlir/...", "//tensorflow/compiler/mlir/lite/...", ], @@ -68,35 +68,6 @@ cc_library( ], ) -cc_library( - name = "quantize", - srcs = [ - "quantize.cc", - ], - hdrs = [ - "quantize.h", - ], - deps = [ - ":hlo_xla_quantization_passes", - "//tensorflow/compiler/mlir/xla:hlo", - "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo", - "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", - "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core/platform:status", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Transforms", - ], -) - gentbl( name = "cpu_kernel_fusion_inc_gen", tbl_outs = [ diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc deleted file mode 100644 index 223a55d23d5..00000000000 --- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h" - -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h" -#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/tf2xla/tf2xla.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" - -namespace mlir { -namespace xla_hlo { - -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; -} - -// Quantizes the model in the computation. -tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config, - xla::XlaComputation* computation) { - TF_ASSIGN_OR_RETURN(std::unique_ptr snapshot, - computation->Snapshot()); - - RegisterDialects(); - MLIRContext context; - OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context)); - auto status = xla::ConvertHloToMlirHlo( - module.get(), snapshot->mutable_hlo()->mutable_hlo_module()); - if (!status.ok()) { - LOG(ERROR) << "Hlo module import failed: " << status; - return status; - } - - PassManager pm(&context); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createInlinerPass()); - pm.addPass(createSymbolDCEPass()); - pm.addNestedPass(createCSEPass()); - pm.addNestedPass(CreateCpuKernelFusionPass()); - - mlir::StatusScopedDiagnosticHandler diag_handler(&context); - LogicalResult result = pm.run(module.get()); - (void)result; - - module->walk([&](quant::QuantizeRegionOp op) { op.dump(); }); - - return tensorflow::Status::OK(); -} - -} // namespace xla_hlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD index 4b6b4212567..b6c156e0ded 100644 --- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD @@ -8,9 +8,7 @@ glob_lit_tests( ":test_utilities", ], driver = "@llvm-project//mlir:run_lit.sh", - tags_override = { - "fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss. - }, + exclude = ["fadd_quant.mlir"], test_file_exts = ["mlir"], )