From 03c15fee32663306201005cda7088965ad0e646b Mon Sep 17 00:00:00 2001
From: Feng Liu <fengliuai@google.com>
Date: Thu, 9 Apr 2020 21:06:08 -0700
Subject: [PATCH] Remove experimental xla hlo quantization from open source (I)

PiperOrigin-RevId: 305818950
Change-Id: Ie09f6a81bafb100ca3dd3601fd4f9a8a732677ca
---
 tensorflow/compiler/aot/BUILD                 |  8 +-
 tensorflow/compiler/aot/compile.cc            | 16 +++-
 .../lite/quantization/xla => aot}/quantize.h  | 28 ++++---
 .../compiler/mlir/lite/quantization/xla/BUILD | 31 +-------
 .../mlir/lite/quantization/xla/quantize.cc    | 78 -------------------
 .../mlir/lite/quantization/xla/tests/BUILD    |  4 +-
 6 files changed, 40 insertions(+), 125 deletions(-)
 rename tensorflow/compiler/{mlir/lite/quantization/xla => aot}/quantize.h (55%)
 delete mode 100644 tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc

diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 7f1590ff75d..fd4ae10595b 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -16,6 +16,12 @@ cc_library(
     deps = ["//tensorflow/core:test_main"],
 )
 
+filegroup(
+    name = "quantize_header",
+    srcs = ["quantize.h"],
+    visibility = ["//visibility:public"],
+)
+
 cc_library(
     name = "tfcompile_lib",
     srcs = [
@@ -27,6 +33,7 @@ cc_library(
         "codegen.h",
         "compile.h",
         "flags.h",
+        "quantize.h",
     ],
     defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
     visibility = ["//tensorflow/python:__pkg__"],
@@ -37,7 +44,6 @@ cc_library(
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
-        "//tensorflow/compiler/mlir/lite/quantization/xla:quantize",
         "//tensorflow/compiler/tf2xla",
         "//tensorflow/compiler/tf2xla:mlir_tf2xla",
         "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index f83cd45f9f3..a2cba5cdf9e 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -24,7 +24,7 @@ limitations under the License.
 #include "llvm-c/Target.h"
 #include "tensorflow/compiler/aot/codegen.h"
 #include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
+#include "tensorflow/compiler/aot/quantize.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
@@ -46,6 +46,14 @@ limitations under the License.
 namespace tensorflow {
 namespace tfcompile {
 
+static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
+
+bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
+  if (*quantize_xla) return false;
+  *quantize_xla = fn;
+  return true;
+}
+
 namespace {
 
 // Compiles the XLA computation into executable code.
@@ -116,9 +124,11 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
   } else {
     return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
   }
-  if (flags.experimental_quantize) {
-    TF_RETURN_IF_ERROR(mlir::xla_hlo::XlaQuantize(config, &computation));
+
+  if (flags.experimental_quantize && *quantize_xla) {
+    TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
   }
+
   if (!flags.out_session_module.empty()) {
     TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                         computation.Snapshot());
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h b/tensorflow/compiler/aot/quantize.h
similarity index 55%
rename from tensorflow/compiler/mlir/lite/quantization/xla/quantize.h
rename to tensorflow/compiler/aot/quantize.h
index 2ec5dbb02ce..add05bd0422 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.h
+++ b/tensorflow/compiler/aot/quantize.h
@@ -13,21 +13,29 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
-#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
+#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
+#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
+
+#include <functional>
+#include <iostream>
+#include <ostream>
 
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 
-namespace mlir {
-namespace xla_hlo {
+namespace tensorflow {
+namespace tfcompile {
 
-// Quantizes the model in the computation.
-tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
-                               xla::XlaComputation* computation);
+using QuantizeXlaFn = std::function<Status(const tf2xla::Config& config,
+                                           xla::XlaComputation* computation)>;
 
-}  // namespace xla_hlo
-}  // namespace mlir
+// Set the static quantization function to the `fn` if it hasn't been set.
+// Return false if the static function has been set.
+bool RegisterQuantizeFn(const QuantizeXlaFn& fn);
 
-#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_QUANTIZE_H_
+}  // namespace tfcompile
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
index 74e81d7f291..791ec0aa93e 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/BUILD
@@ -14,7 +14,7 @@ package_group(
     name = "friends",
     includes = ["//third_party/mlir:subpackages"],
     packages = [
-        "//tensorflow/compiler/aot/...",
+        "//learning/brain/experimental/mlir/quantization/...",
         "//tensorflow/compiler/mlir/...",
         "//tensorflow/compiler/mlir/lite/...",
     ],
@@ -68,35 +68,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "quantize",
-    srcs = [
-        "quantize.cc",
-    ],
-    hdrs = [
-        "quantize.h",
-    ],
-    deps = [
-        ":hlo_xla_quantization_passes",
-        "//tensorflow/compiler/mlir/xla:hlo",
-        "//tensorflow/compiler/mlir/xla:hlo_to_mlir_hlo",
-        "//tensorflow/compiler/tf2xla",
-        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
-        "//tensorflow/compiler/tf2xla:tf2xla_util",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
-        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
-        "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/core/platform:status",
-        "@llvm-project//mlir:Analysis",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:QuantOps",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:Transforms",
-    ],
-)
-
 gentbl(
     name = "cpu_kernel_fusion_inc_gen",
     tbl_outs = [
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc b/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
deleted file mode 100644
index 223a55d23d5..00000000000
--- a/tensorflow/compiler/mlir/lite/quantization/xla/quantize.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/compiler/mlir/lite/quantization/xla/quantize.h"
-
-#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
-#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
-#include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/Function.h"  // from @llvm-project
-#include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/Module.h"  // from @llvm-project
-#include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "mlir/Pass/PassManager.h"  // from @llvm-project
-#include "mlir/Transforms/Passes.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/lite/quantization/xla/passes.h"
-#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
-#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
-#include "tensorflow/compiler/tf2xla/tf2xla.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-
-namespace mlir {
-namespace xla_hlo {
-
-static void RegisterDialects() {
-  static bool init_once = []() {
-    mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
-    mlir::registerDialect<mlir::StandardOpsDialect>();
-    mlir::registerDialect<mlir::quant::QuantizationDialect>();
-    return true;
-  }();
-  (void)init_once;
-}
-
-// Quantizes the model in the computation.
-tensorflow::Status XlaQuantize(const tensorflow::tf2xla::Config& config,
-                               xla::XlaComputation* computation) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> snapshot,
-                      computation->Snapshot());
-
-  RegisterDialects();
-  MLIRContext context;
-  OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
-  auto status = xla::ConvertHloToMlirHlo(
-      module.get(), snapshot->mutable_hlo()->mutable_hlo_module());
-  if (!status.ok()) {
-    LOG(ERROR) << "Hlo module import failed: " << status;
-    return status;
-  }
-
-  PassManager pm(&context);
-  pm.addPass(createCanonicalizerPass());
-  pm.addPass(createInlinerPass());
-  pm.addPass(createSymbolDCEPass());
-  pm.addNestedPass<FuncOp>(createCSEPass());
-  pm.addNestedPass<FuncOp>(CreateCpuKernelFusionPass());
-
-  mlir::StatusScopedDiagnosticHandler diag_handler(&context);
-  LogicalResult result = pm.run(module.get());
-  (void)result;
-
-  module->walk([&](quant::QuantizeRegionOp op) { op.dump(); });
-
-  return tensorflow::Status::OK();
-}
-
-}  // namespace xla_hlo
-}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
index 4b6b4212567..b6c156e0ded 100644
--- a/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD
@@ -8,9 +8,7 @@ glob_lit_tests(
         ":test_utilities",
     ],
     driver = "@llvm-project//mlir:run_lit.sh",
-    tags_override = {
-        "fadd_quant.mlir": ["no_oss"],  # TODO(b/150957738): to be fixed on oss.
-    },
+    exclude = ["fadd_quant.mlir"],
     test_file_exts = ["mlir"],
 )