Add InitMlir class to initialize using TF's InitMain and InitLLVM
Helper class that initializes both LLVM and TF. Pass strings before the separator (--) to TF's InitMain (none where there is no separator). This could be further enhanced to better support help flag. PiperOrigin-RevId: 264420162
This commit is contained in:
parent
9c6b1d6898
commit
cf83c25242
tensorflow/compiler/mlir
@ -38,6 +38,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:AffineDialectRegistration",
|
||||
"@local_config_mlir//:MlirOptLib",
|
||||
@ -50,6 +51,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "init_mlir",
|
||||
srcs = ["init_mlir.cc"],
|
||||
hdrs = ["init_mlir.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm//:support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-opt",
|
||||
deps = [
|
||||
|
45
tensorflow/compiler/mlir/init_mlir.cc
Normal file
45
tensorflow/compiler/mlir/init_mlir.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) {
|
||||
constexpr char kSeparator[] = "--";
|
||||
|
||||
// Find index of separator between two sets of flags.
|
||||
int pass_remainder = 1;
|
||||
bool split = false;
|
||||
for (int i = 0; i < *argc; ++i) {
|
||||
if (llvm::StringRef((*argv)[i]) == kSeparator) {
|
||||
pass_remainder = i;
|
||||
*argc -= (i + 1);
|
||||
split = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::port::InitMain((*argv)[0], &pass_remainder, argv);
|
||||
if (split) {
|
||||
*argc += pass_remainder;
|
||||
(*argv)[1] = (*argv)[0];
|
||||
++*argv;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
40
tensorflow/compiler/mlir/init_mlir.h
Normal file
40
tensorflow/compiler/mlir/init_mlir.h
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Initializer to perform both InitLLVM and TF"s InitMain initialization.
|
||||
// InitMain also performs flag parsing and '--' is used to separate flags passed
|
||||
// to it: Flags before the first '--' are parsed by InitMain and argc and argv
|
||||
// progressed to the flags post. If there is no separator, then no flags are
|
||||
// parsed by InitMain and argc/argv left unadjusted.
|
||||
// TODO(jpienaar): The way help flag is handled could be improved.
|
||||
class InitMlir {
|
||||
public:
|
||||
InitMlir(int *argc, char ***argv);
|
||||
|
||||
private:
|
||||
llvm::InitLLVM init_llvm_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_
|
@ -468,6 +468,7 @@ tf_cc_binary(
|
||||
":tf_tfl_passes",
|
||||
":tf_tfl_translate_cl_options",
|
||||
":tf_to_tfl_flatbuffer",
|
||||
"//tensorflow/compiler/mlir:init_mlir",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/lite:framework",
|
||||
@ -485,6 +486,7 @@ tf_cc_binary(
|
||||
deps = [
|
||||
":flatbuffer_translate_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform/default/build_config:base",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/init_mlir.h"
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h"
|
||||
@ -100,7 +101,7 @@ static int PrintFunctionResultMapping(const std::string &result,
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// TODO(jpienaar): Revise the command line option parsing here.
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
tensorflow::InitMlir y(&argc, &argv);
|
||||
|
||||
// TODO(antiagainst): We are pulling in multiple transformations as follows.
|
||||
// Each transformation has its own set of command-line options; options of one
|
||||
@ -111,14 +112,9 @@ int main(int argc, char **argv) {
|
||||
// We need to disable duplicated ones to provide a cleaner command-line option
|
||||
// interface. That also means we need to relay the value set in one option to
|
||||
// all its aliases.
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(
|
||||
argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n");
|
||||
|
||||
// TODO(ashwinm): Enable command line parsing for both sides.
|
||||
int fake_argc = 1;
|
||||
tensorflow::port::InitMain(argv[0], &fake_argc, &argv);
|
||||
|
||||
MLIRContext context;
|
||||
llvm::SourceMgr source_mgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||
|
@ -304,7 +304,6 @@ tf_native_cc_binary(
|
||||
"operator_writer_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:config",
|
||||
"@llvm//:support",
|
||||
"@llvm//:tablegen",
|
||||
"@local_config_mlir//:TableGen",
|
||||
|
Loading…
Reference in New Issue
Block a user