Add InitMlir class to initialize using TF's InitMain and InitLLVM
Helper class that initializes both LLVM and TF. Pass strings before the separator (--) to TF's InitMain (none where there is no separator). This could be further enhanced to better support help flag. PiperOrigin-RevId: 264420162
This commit is contained in:
parent
9c6b1d6898
commit
cf83c25242
@ -38,6 +38,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
"@llvm//:support",
|
"@llvm//:support",
|
||||||
"@local_config_mlir//:AffineDialectRegistration",
|
"@local_config_mlir//:AffineDialectRegistration",
|
||||||
"@local_config_mlir//:MlirOptLib",
|
"@local_config_mlir//:MlirOptLib",
|
||||||
@ -50,6 +51,16 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "init_mlir",
|
||||||
|
srcs = ["init_mlir.cc"],
|
||||||
|
hdrs = ["init_mlir.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@llvm//:support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_binary(
|
tf_cc_binary(
|
||||||
name = "tf-opt",
|
name = "tf-opt",
|
||||||
deps = [
|
deps = [
|
||||||
|
45
tensorflow/compiler/mlir/init_mlir.cc
Normal file
45
tensorflow/compiler/mlir/init_mlir.cc
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) {
|
||||||
|
constexpr char kSeparator[] = "--";
|
||||||
|
|
||||||
|
// Find index of separator between two sets of flags.
|
||||||
|
int pass_remainder = 1;
|
||||||
|
bool split = false;
|
||||||
|
for (int i = 0; i < *argc; ++i) {
|
||||||
|
if (llvm::StringRef((*argv)[i]) == kSeparator) {
|
||||||
|
pass_remainder = i;
|
||||||
|
*argc -= (i + 1);
|
||||||
|
split = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::port::InitMain((*argv)[0], &pass_remainder, argv);
|
||||||
|
if (split) {
|
||||||
|
*argc += pass_remainder;
|
||||||
|
(*argv)[1] = (*argv)[0];
|
||||||
|
++*argv;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
40
tensorflow/compiler/mlir/init_mlir.h
Normal file
40
tensorflow/compiler/mlir/init_mlir.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
|
||||||
|
|
||||||
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
#include "llvm/Support/InitLLVM.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Initializer to perform both InitLLVM and TF"s InitMain initialization.
|
||||||
|
// InitMain also performs flag parsing and '--' is used to separate flags passed
|
||||||
|
// to it: Flags before the first '--' are parsed by InitMain and argc and argv
|
||||||
|
// progressed to the flags post. If there is no separator, then no flags are
|
||||||
|
// parsed by InitMain and argc/argv left unadjusted.
|
||||||
|
// TODO(jpienaar): The way help flag is handled could be improved.
|
||||||
|
class InitMlir {
|
||||||
|
public:
|
||||||
|
InitMlir(int *argc, char ***argv);
|
||||||
|
|
||||||
|
private:
|
||||||
|
llvm::InitLLVM init_llvm_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
|
@ -468,6 +468,7 @@ tf_cc_binary(
|
|||||||
":tf_tfl_passes",
|
":tf_tfl_passes",
|
||||||
":tf_tfl_translate_cl_options",
|
":tf_tfl_translate_cl_options",
|
||||||
":tf_to_tfl_flatbuffer",
|
":tf_to_tfl_flatbuffer",
|
||||||
|
"//tensorflow/compiler/mlir:init_mlir",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
@ -485,6 +486,7 @@ tf_cc_binary(
|
|||||||
deps = [
|
deps = [
|
||||||
":flatbuffer_translate_lib",
|
":flatbuffer_translate_lib",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/core/platform/default/build_config:base",
|
"//tensorflow/core/platform/default/build_config:base",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/delegates/flex:delegate",
|
"//tensorflow/lite/delegates/flex:delegate",
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||||
|
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
|
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
|
||||||
@ -100,7 +101,7 @@ static int PrintFunctionResultMapping(const std::string &result,
|
|||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
// TODO(jpienaar): Revise the command line option parsing here.
|
// TODO(jpienaar): Revise the command line option parsing here.
|
||||||
llvm::InitLLVM y(argc, argv);
|
tensorflow::InitMlir y(&argc, &argv);
|
||||||
|
|
||||||
// TODO(antiagainst): We are pulling in multiple transformations as follows.
|
// TODO(antiagainst): We are pulling in multiple transformations as follows.
|
||||||
// Each transformation has its own set of command-line options; options of one
|
// Each transformation has its own set of command-line options; options of one
|
||||||
@ -111,14 +112,9 @@ int main(int argc, char **argv) {
|
|||||||
// We need to disable duplicated ones to provide a cleaner command-line option
|
// We need to disable duplicated ones to provide a cleaner command-line option
|
||||||
// interface. That also means we need to relay the value set in one option to
|
// interface. That also means we need to relay the value set in one option to
|
||||||
// all its aliases.
|
// all its aliases.
|
||||||
|
|
||||||
llvm::cl::ParseCommandLineOptions(
|
llvm::cl::ParseCommandLineOptions(
|
||||||
argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");
|
argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");
|
||||||
|
|
||||||
// TODO(ashwinm): Enable command line parsing for both sides.
|
|
||||||
int fake_argc = 1;
|
|
||||||
tensorflow::port::InitMain(argv[0], &fake_argc, &argv);
|
|
||||||
|
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
llvm::SourceMgr source_mgr;
|
llvm::SourceMgr source_mgr;
|
||||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||||
|
@ -304,7 +304,6 @@ tf_native_cc_binary(
|
|||||||
"operator_writer_gen.cc",
|
"operator_writer_gen.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@llvm//:config",
|
|
||||||
"@llvm//:support",
|
"@llvm//:support",
|
||||||
"@llvm//:tablegen",
|
"@llvm//:tablegen",
|
||||||
"@local_config_mlir//:TableGen",
|
"@local_config_mlir//:TableGen",
|
||||||
|
Loading…
Reference in New Issue
Block a user