From cf83c252421dcdf31ef8bfe5d464b61ecf7a294b Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar@google.com>
Date: Tue, 20 Aug 2019 10:58:46 -0700
Subject: [PATCH] Add InitMlir class to initialize using TF's InitMain and
 InitLLVM

Helper class that initializes both LLVM and TF. Pass strings before the separator (--) to TF's InitMain (none where there is no separator).

This could be further enhanced to better support help flag.

PiperOrigin-RevId: 264420162
---
 tensorflow/compiler/mlir/BUILD                | 11 +++++
 tensorflow/compiler/mlir/init_mlir.cc         | 45 +++++++++++++++++++
 tensorflow/compiler/mlir/init_mlir.h          | 40 +++++++++++++++++
 tensorflow/compiler/mlir/lite/BUILD           |  2 +
 .../compiler/mlir/lite/tf_tfl_translate.cc    |  8 +---
 tensorflow/compiler/mlir/xla/BUILD            |  1 -
 6 files changed, 100 insertions(+), 7 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/init_mlir.cc
 create mode 100644 tensorflow/compiler/mlir/init_mlir.h

diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 0c35466b392..e875ed254f6 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -38,6 +38,7 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
         "//tensorflow/core:lib",
+        "//tensorflow/core/platform:logging",
         "@llvm//:support",
         "@local_config_mlir//:AffineDialectRegistration",
         "@local_config_mlir//:MlirOptLib",
@@ -50,6 +51,16 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "init_mlir",
+    srcs = ["init_mlir.cc"],
+    hdrs = ["init_mlir.h"],
+    deps = [
+        "//tensorflow/core:lib",
+        "@llvm//:support",
+    ],
+)
+
 tf_cc_binary(
     name = "tf-opt",
     deps = [
diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc
new file mode 100644
index 00000000000..54f8a57d8a6
--- /dev/null
+++ b/tensorflow/compiler/mlir/init_mlir.cc
@@ -0,0 +1,45 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/init_mlir.h"
+
+#include "tensorflow/core/platform/init_main.h"
+
+namespace tensorflow {
+
+InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) {
+  constexpr char kSeparator[] = "--";
+
+  // Find index of separator between two sets of flags.
+  int pass_remainder = 1;
+  bool split = false;
+  for (int i = 0; i < *argc; ++i) {
+    if (llvm::StringRef((*argv)[i]) == kSeparator) {
+      pass_remainder = i;
+      *argc -= (i + 1);
+      split = true;
+      break;
+    }
+  }
+
+  tensorflow::port::InitMain((*argv)[0], &pass_remainder, argv);
+  if (split) {
+    *argc += pass_remainder;
+    (*argv)[1] = (*argv)[0];
+    ++*argv;
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/init_mlir.h b/tensorflow/compiler/mlir/init_mlir.h
new file mode 100644
index 00000000000..91020c1758b
--- /dev/null
+++ b/tensorflow/compiler/mlir/init_mlir.h
@@ -0,0 +1,40 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
+#define TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+
+namespace tensorflow {
+
+// Initializer to perform both InitLLVM and TF"s InitMain initialization.
+// InitMain also performs flag parsing and '--' is used to separate flags passed
+// to it: Flags before the first '--' are parsed by InitMain and argc and argv
+// progressed to the flags post. If there is no separator, then no flags are
+// parsed by InitMain and argc/argv left unadjusted.
+// TODO(jpienaar): The way help flag is handled could be improved.
+class InitMlir {
+ public:
+  InitMlir(int *argc, char ***argv);
+
+ private:
+  llvm::InitLLVM init_llvm_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 433a85f4b08..5216d237d83 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -468,6 +468,7 @@ tf_cc_binary(
         ":tf_tfl_passes",
         ":tf_tfl_translate_cl_options",
         ":tf_to_tfl_flatbuffer",
+        "//tensorflow/compiler/mlir:init_mlir",
         "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
@@ -485,6 +486,7 @@ tf_cc_binary(
     deps = [
         ":flatbuffer_translate_lib",
         "//tensorflow/core:lib",
+        "//tensorflow/core/platform:logging",
         "//tensorflow/core/platform/default/build_config:base",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/delegates/flex:delegate",
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index be1496b6edd..445535d52f9 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Module.h"  // TF:local_config_mlir
 #include "mlir/Support/FileUtilities.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/init_mlir.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
 #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
@@ -100,7 +101,7 @@ static int PrintFunctionResultMapping(const std::string &result,
 
 int main(int argc, char **argv) {
   // TODO(jpienaar): Revise the command line option parsing here.
-  llvm::InitLLVM y(argc, argv);
+  tensorflow::InitMlir y(&argc, &argv);
 
   // TODO(antiagainst): We are pulling in multiple transformations as follows.
   // Each transformation has its own set of command-line options; options of one
@@ -111,14 +112,9 @@ int main(int argc, char **argv) {
   // We need to disable duplicated ones to provide a cleaner command-line option
   // interface. That also means we need to relay the value set in one option to
   // all its aliases.
-
   llvm::cl::ParseCommandLineOptions(
       argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");
 
-  // TODO(ashwinm): Enable command line parsing for both sides.
-  int fake_argc = 1;
-  tensorflow::port::InitMain(argv[0], &fake_argc, &argv);
-
   MLIRContext context;
   llvm::SourceMgr source_mgr;
   mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index a2f04cce9ce..546d9811729 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -304,7 +304,6 @@ tf_native_cc_binary(
         "operator_writer_gen.cc",
     ],
     deps = [
-        "@llvm//:config",
         "@llvm//:support",
         "@llvm//:tablegen",
         "@local_config_mlir//:TableGen",