Remove experimental xla hlo quantization from open source (I)
PiperOrigin-RevId: 305818950 Change-Id: Ie09f6a81bafb100ca3dd3601fd4f9a8a732677ca
This commit is contained in:
parent
f2100b9b51
commit
03c15fee32
tensorflow/compiler
@ -16,6 +16,12 @@ cc_library(
|
|||||||
deps = ["//tensorflow/core:test_main"],
|
deps = ["//tensorflow/core:test_main"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "quantize_header",
|
||||||
|
srcs = ["quantize.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tfcompile_lib",
|
name = "tfcompile_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -27,6 +33,7 @@ cc_library(
|
|||||||
"codegen.h",
|
"codegen.h",
|
||||||
"compile.h",
|
"compile.h",
|
||||||
"flags.h",
|
"flags.h",
|
||||||
|
"quantize.h",
|
||||||
],
|
],
|
||||||
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
|
||||||
visibility = ["//tensorflow/python:__pkg__"],
|
visibility = ["//tensorflow/python:__pkg__"],
|
||||||
@ -37,7 +44,6 @@ cc_library(
|
|||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
|
|
||||||
"//tensorflow/compiler/tf2xla",
|
"//tensorflow/compiler/tf2xla",
|
||||||
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
|
||||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "llvm-c/Target.h"
|
#include "llvm-c/Target.h"
|
||||||
#include "tensorflow/compiler/aot/codegen.h"
|
#include "tensorflow/compiler/aot/codegen.h"
|
||||||
#include "tensorflow/compiler/aot/flags.h"
|
#include "tensorflow/compiler/aot/flags.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
#include "tensorflow/compiler/aot/quantize.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
@ -46,6 +46,14 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tfcompile {
|
namespace tfcompile {
|
||||||
|
|
||||||
|
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
|
||||||
|
|
||||||
|
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
|
||||||
|
if (*quantize_xla) return false;
|
||||||
|
*quantize_xla = fn;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Compiles the XLA computation into executable code.
|
// Compiles the XLA computation into executable code.
|
||||||
@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
|
|||||||
} else {
|
} else {
|
||||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||||
}
|
}
|
||||||
if (flags.experimental_quantize) {
|
|
||||||
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
|
if (flags.experimental_quantize && *quantize_xla) {
|
||||||
|
TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!flags.out_session_module.empty()) {
|
if (!flags.out_session_module.empty()) {
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||||
computation.Snapshot());
|
computation.Snapshot());
|
||||||
|
@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace tensorflow {
|
||||||
namespace xla_hlo {
|
namespace tfcompile {
|
||||||
|
|
||||||
// Quantizes the model in the computation.
|
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
|
||||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
xla::XlaComputation* computation)>;
|
||||||
xla::XlaComputation* computation);
|
|
||||||
|
|
||||||
} // namespace xla_hlo
|
// Set the static quantization function to the `fn` if it hasn't been set.
|
||||||
} // namespace mlir
|
// Return false if the static function has been set.
|
||||||
|
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
|
} // namespace tfcompile
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
|
@ -14,7 +14,7 @@ package_group(
|
|||||||
name = "friends",
|
name = "friends",
|
||||||
includes = ["//third_party/mlir:subpackages"],
|
includes = ["//third_party/mlir:subpackages"],
|
||||||
packages = [
|
packages = [
|
||||||
"//tensorflow/compiler/aot/...",
|
"//learning/brain/experimental/mlir/quantization/...",
|
||||||
"//tensorflow/compiler/mlir/...",
|
"//tensorflow/compiler/mlir/...",
|
||||||
"//tensorflow/compiler/mlir/lite/...",
|
"//tensorflow/compiler/mlir/lite/...",
|
||||||
],
|
],
|
||||||
@ -68,35 +68,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "quantize",
|
|
||||||
srcs = [
|
|
||||||
"quantize.cc",
|
|
||||||
],
|
|
||||||
hdrs = [
|
|
||||||
"quantize.h",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":hlo_xla_quantization_passes",
|
|
||||||
"//tensorflow/compiler/mlir/xla:hlo",
|
|
||||||
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
|
|
||||||
"//tensorflow/compiler/tf2xla",
|
|
||||||
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
|
|
||||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
|
||||||
"//tensorflow/core/platform:status",
|
|
||||||
"@llvm-project//mlir:Analysis",
|
|
||||||
"@llvm-project//mlir:IR",
|
|
||||||
"@llvm-project//mlir:Pass",
|
|
||||||
"@llvm-project//mlir:QuantOps",
|
|
||||||
"@llvm-project//mlir:StandardOps",
|
|
||||||
"@llvm-project//mlir:Transforms",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
gentbl(
|
gentbl(
|
||||||
name = "cpu_kernel_fusion_inc_gen",
|
name = "cpu_kernel_fusion_inc_gen",
|
||||||
tbl_outs = [
|
tbl_outs = [
|
||||||
|
@ -1,78 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
|
||||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
|
||||||
#include "mlir/IR/Function.h" // from @llvm-project
|
|
||||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
|
||||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
|
||||||
#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h"
|
|
||||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace xla_hlo {
|
|
||||||
|
|
||||||
static void RegisterDialects() {
|
|
||||||
static bool init_once = []() {
|
|
||||||
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
|
|
||||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
|
||||||
mlir::registerDialect<mlir::quant::QuantizationDialect>();
|
|
||||||
return true;
|
|
||||||
}();
|
|
||||||
(void)init_once;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Quantizes the model in the computation.
|
|
||||||
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
|
|
||||||
xla::XlaComputation* computation) {
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
|
|
||||||
computation->Snapshot());
|
|
||||||
|
|
||||||
RegisterDialects();
|
|
||||||
MLIRContext context;
|
|
||||||
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
|
|
||||||
auto status = xla::ConvertHloToMlirHlo(
|
|
||||||
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "Hlo module import failed: " << status;
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
PassManager pm(&context);
|
|
||||||
pm.addPass(createCanonicalizerPass());
|
|
||||||
pm.addPass(createInlinerPass());
|
|
||||||
pm.addPass(createSymbolDCEPass());
|
|
||||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
|
||||||
pm.addNestedPass<FuncOp>(CreateCpuKernelFusionPass());
|
|
||||||
|
|
||||||
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
|
|
||||||
LogicalResult result = pm.run(module.get());
|
|
||||||
(void)result;
|
|
||||||
|
|
||||||
module->walk([&](quant::QuantizeRegionOp op) { op.dump(); });
|
|
||||||
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla_hlo
|
|
||||||
} // namespace mlir
|
|
@ -8,9 +8,7 @@ glob_lit_tests(
|
|||||||
":test_utilities",
|
":test_utilities",
|
||||||
],
|
],
|
||||||
driver = "@llvm-project//mlir:run_lit.sh",
|
driver = "@llvm-project//mlir:run_lit.sh",
|
||||||
tags_override = {
|
exclude = ["fadd_quant.mlir"],
|
||||||
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
|
|
||||||
},
|
|
||||||
test_file_exts = ["mlir"],
|
test_file_exts = ["mlir"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user