Remove experimental xla hlo quantization from open source (I)

PiperOrigin-RevId: 305818950
Change-Id: Ie09f6a81bafb100ca3dd3601fd4f9a8a732677ca
This commit is contained in:
Feng Liu 2020-04-09 21:06:08 -07:00 committed by TensorFlower Gardener
parent f2100b9b51
commit 03c15fee32
6 changed files with 40 additions and 125 deletions

View File

@@ -16,6 +16,12 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
filegroup(
name = "quantize_header",
srcs = ["quantize.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "tfcompile_lib",
srcs = [
@@ -27,6 +33,7 @@ cc_library(
"codegen.h",
"compile.h",
"flags.h",
"quantize.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
@@ -37,7 +44,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",

View File

@@ -24,7 +24,7 @@ limitations under the License.
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@@ -46,6 +46,14 @@ limitations under the License.
namespace tensorflow {
namespace tfcompile {
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
if (*quantize_xla) return false;
*quantize_xla = fn;
return true;
}
namespace {
// Compiles the XLA computation into executable code.
@@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
} else {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
if (flags.experimental_quantize) {
TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
if (flags.experimental_quantize && *quantize_xla) {
TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
computation.Snapshot());

View File

@@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
#include <functional>
#include <iostream>
#include <ostream>
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace xla_hlo {
namespace tensorflow {
namespace tfcompile {
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation);
using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
xla::XlaComputation* computation)>;
} // namespace xla_hlo
} // namespace mlir
// Set the static quantization function to the `fn` if it hasn't been set.
// Return false if the static function has been set.
bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_

View File

@@ -14,7 +14,7 @@ package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/aot/...",
"//learning/brain/experimental/mlir/quantization/...",
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
@@ -68,35 +68,6 @@ cc_library(
],
)
cc_library(
name = "quantize",
srcs = [
"quantize.cc",
],
hdrs = [
"quantize.h",
],
deps = [
":hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core/platform:status",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
gentbl(
name = "cpu_kernel_fusion_inc_gen",
tbl_outs = [

View File

@@ -1,78 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
namespace mlir {
namespace xla_hlo {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::quant::QuantizationDialect>();
return true;
}();
(void)init_once;
}
// Quantizes the model in the computation.
tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
xla::XlaComputation* computation) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
computation->Snapshot());
RegisterDialects();
MLIRContext context;
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
auto status = xla::ConvertHloToMlirHlo(
module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
if (!status.ok()) {
LOG(ERROR) << "Hlo module import failed: " << status;
return status;
}
PassManager pm(&context);
pm.addPass(createCanonicalizerPass());
pm.addPass(createInlinerPass());
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(CreateCpuKernelFusionPass());
mlir::StatusScopedDiagnosticHandler diag_handler(&context);
LogicalResult result = pm.run(module.get());
(void)result;
module->walk([&](quant::QuantizeRegionOp op) { op.dump(); });
return tensorflow::Status::OK();
}
} // namespace xla_hlo
} // namespace mlir

View File

@@ -8,9 +8,7 @@ glob_lit_tests(
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
tags_override = {
"fadd_quant.mlir": ["no_oss"], # TODO(b/150957738): to be fixed on oss.
},
exclude = ["fadd_quant.mlir"],
test_file_exts = ["mlir"],
)