Remove experimental xla hlo quantization from open source (I)
PiperOrigin-RevId: 305818950 Change-Id: Ie09f6a81bafb100ca3dd3601fd4f9a8a732677ca
This commit is contained in:
parent
f2100b9b51
commit
03c15fee32
@ -16,6 +16,12 @@ cc_library(
|
||||
deps = ["//tensorflow/core:test_main"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "quantize_header",
|
||||
srcs = ["quantize.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfcompile_lib",
|
||||
srcs = [
|
||||
@ -27,6 +33,7 @@ cc_library(
|
||||
"codegen.h",
|
||||
"compile.h",
|
||||
"flags.h",
|
||||
"quantize.h",
|
||||
],
|
||||
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
@ -37,7 +44,6 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "llvm-c/Target.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
#include "tensorflow/compiler/aot/quantize.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
@ -46,6 +46,14 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
|
||||
|
||||
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
|
||||
if (*quantize_xla) return false;
|
||||
*quantize_xla = fn;
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Compiles the XLA computation into executable code.
|
||||
@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
||||
} else {
|
||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||
}
|
||||
if (flags.experimental_quantize) {
|
||||
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
|
||||
|
||||
if (flags.experimental_quantize && *quantize_xla) {
|
||||
TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
|
||||
}
|
||||
|
||||
if (!flags.out_session_module.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||
computation.Snapshot());
|
||||
|
@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
||||
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
|
||||
// Quantizes the model in the computation.
|
||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
||||
xla::XlaComputation* computation);
|
||||
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
|
||||
xla::XlaComputation* computation)>;
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
// Set the static quantization function to the `fn` if it hasn't been set.
|
||||
// Return false if the static function has been set.
|
||||
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
@ -14,7 +14,7 @@ package_group(
|
||||
name = "friends",
|
||||
includes = ["//third_party/mlir:subpackages"],
|
||||
packages = [
|
||||
"//tensorflow/compiler/aot/...",
|
||||
"//learning/brain/experimental/mlir/quantization/...",
|
||||
"//tensorflow/compiler/mlir/...",
|
||||
"//tensorflow/compiler/mlir/lite/...",
|
||||
],
|
||||
@ -68,35 +68,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "quantize",
|
||||
srcs = [
|
||||
"quantize.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"quantize.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo_xla_quantization_passes",
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
|
||||
"//tensorflow/compiler/tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "cpu_kernel_fusion_inc_gen",
|
||||
tbl_outs = [
|
||||
|
@ -1,78 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::quant::QuantizationDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
}
|
||||
|
||||
// Quantizes the model in the computation.
|
||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
||||
xla::XlaComputation* computation) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
|
||||
computation->Snapshot());
|
||||
|
||||
RegisterDialects();
|
||||
MLIRContext context;
|
||||
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
|
||||
auto status = xla::ConvertHloToMlirHlo(
|
||||
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Hlo module import failed: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
PassManager pm(&context);
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createInlinerPass());
|
||||
pm.addPass(createSymbolDCEPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
pm.addNestedPass<FuncOp>(CreateCpuKernelFusionPass());
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
|
||||
LogicalResult result = pm.run(module.get());
|
||||
(void)result;
|
||||
|
||||
module->walk([&](quant::QuantizeRegionOp op) { op.dump(); });
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
@ -8,9 +8,7 @@ glob_lit_tests(
|
||||
":test_utilities",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags_override = {
|
||||
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
||||
},
|
||||
exclude = ["fadd_quant.mlir"],
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user