From cf83c252421dcdf31ef8bfe5d464b61ecf7a294b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar <jpienaar@google.com> Date: Tue, 20 Aug 2019 10:58:46 -0700 Subject: [PATCH] Add InitMlir class to initialize using TF's InitMain and InitLLVM Helper class that initializes both LLVM and TF. Pass strings before the separator (--) to TF's InitMain (none where there is no separator). This could be further enhanced to better support help flag. PiperOrigin-RevId: 264420162 --- tensorflow/compiler/mlir/BUILD | 11 +++++ tensorflow/compiler/mlir/init_mlir.cc | 45 +++++++++++++++++++ tensorflow/compiler/mlir/init_mlir.h | 40 +++++++++++++++++ tensorflow/compiler/mlir/lite/BUILD | 2 + .../compiler/mlir/lite/tf_tfl_translate.cc | 8 +--- tensorflow/compiler/mlir/xla/BUILD | 1 - 6 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 tensorflow/compiler/mlir/init_mlir.cc create mode 100644 tensorflow/compiler/mlir/init_mlir.h diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 0c35466b392..e875ed254f6 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", "@llvm//:support", "@local_config_mlir//:AffineDialectRegistration", "@local_config_mlir//:MlirOptLib", @@ -50,6 +51,16 @@ cc_library( ], ) +cc_library( + name = "init_mlir", + srcs = ["init_mlir.cc"], + hdrs = ["init_mlir.h"], + deps = [ + "//tensorflow/core:lib", + "@llvm//:support", + ], +) + tf_cc_binary( name = "tf-opt", deps = [ diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc new file mode 100644 index 00000000000..54f8a57d8a6 --- /dev/null +++ b/tensorflow/compiler/mlir/init_mlir.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/init_mlir.h" + +#include "tensorflow/core/platform/init_main.h" + +namespace tensorflow { + +InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) { + constexpr char kSeparator[] = "--"; + + // Find index of separator between two sets of flags. + int pass_remainder = 1; + bool split = false; + for (int i = 0; i < *argc; ++i) { + if (llvm::StringRef((*argv)[i]) == kSeparator) { + pass_remainder = i; + *argc -= (i + 1); + split = true; + break; + } + } + + tensorflow::port::InitMain((*argv)[0], &pass_remainder, argv); + if (split) { + *argc += pass_remainder; + (*argv)[1] = (*argv)[0]; + ++*argv; + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/init_mlir.h b/tensorflow/compiler/mlir/init_mlir.h new file mode 100644 index 00000000000..91020c1758b --- /dev/null +++ b/tensorflow/compiler/mlir/init_mlir.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" + +namespace tensorflow { + +// Initializer to perform both InitLLVM and TF"s InitMain initialization. +// InitMain also performs flag parsing and '--' is used to separate flags passed +// to it: Flags before the first '--' are parsed by InitMain and argc and argv +// progressed to the flags post. If there is no separator, then no flags are +// parsed by InitMain and argc/argv left unadjusted. +// TODO(jpienaar): The way help flag is handled could be improved. +class InitMlir { + public: + InitMlir(int *argc, char ***argv); + + private: + llvm::InitLLVM init_llvm_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 433a85f4b08..5216d237d83 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -468,6 +468,7 @@ tf_cc_binary( ":tf_tfl_passes", ":tf_tfl_translate_cl_options", ":tf_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -485,6 +486,7 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_lib", "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform/default/build_config:base", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index be1496b6edd..445535d52f9 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" @@ -100,7 +101,7 @@ static int PrintFunctionResultMapping(const std::string &result, int main(int argc, char **argv) { // TODO(jpienaar): Revise the command line option parsing here. - llvm::InitLLVM y(argc, argv); + tensorflow::InitMlir y(&argc, &argv); // TODO(antiagainst): We are pulling in multiple transformations as follows. // Each transformation has its own set of command-line options; options of one @@ -111,14 +112,9 @@ int main(int argc, char **argv) { // We need to disable duplicated ones to provide a cleaner command-line option // interface. That also means we need to relay the value set in one option to // all its aliases. - llvm::cl::ParseCommandLineOptions( argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n"); - // TODO(ashwinm): Enable command line parsing for both sides. - int fake_argc = 1; - tensorflow::port::InitMain(argv[0], &fake_argc, &argv); - MLIRContext context; llvm::SourceMgr source_mgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index a2f04cce9ce..546d9811729 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -304,7 +304,6 @@ tf_native_cc_binary( "operator_writer_gen.cc", ], deps = [ - "@llvm//:config", "@llvm//:support", "@llvm//:tablegen", "@local_config_mlir//:TableGen",