Export the MLIR classes and functions from C++ to Python with pybind11 instead of SWIG. This is part of a larger effort to deprecate SWIG and, through modularization, eventually break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
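After this change the same functions are reached through a dedicated pybind11-backed module rather than the monolithic wrapper. A minimal before/after sketch (assuming graphdef_str holds a GraphDef in text or binary form; the module and function names come from the diff below):

# Before: MLIR shims were reached through the monolithic SWIG wrapper.
from tensorflow.python import pywrap_tensorflow
mlir_txt = pywrap_tensorflow.import_graphdef(graphdef_str, 'tf-standard-pipeline')

# After: the same shims live in a dedicated pybind11-backed module.
from tensorflow.python import pywrap_mlir
mlir_txt = pywrap_mlir.import_graphdef(graphdef_str, 'tf-standard-pipeline')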

PiperOrigin-RevId: 292076160
Change-Id: I62bf3aac988c3ce4e18aa01ee49d8aa9ffde383d
Amit Patankar 2020-01-28 21:53:48 -08:00 committed by TensorFlower Gardener
parent 941950947d
commit 9d41ea557a
13 changed files with 413 additions and 271 deletions
Changed directories:
tensorflow/compiler/mlir/python
tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model
tensorflow/python
tensorflow/tools/def_file_filter

tensorflow/compiler/mlir/python/BUILD

@@ -3,9 +3,29 @@ package(
licenses = ["notice"], # Apache 2.0
)
exports_files(
["mlir.i"],
visibility = [
"//tensorflow/python:__subpackages__",
cc_library(
name = "mlir",
srcs = ["mlir.cc"],
hdrs = ["mlir.h"],
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:import_utils",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
],
)
filegroup(
name = "pywrap_mlir_hdrs",
srcs = [
"mlir.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)

tensorflow/compiler/mlir/python/mlir.cc

@@ -0,0 +1,157 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "mlir/Parser.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
namespace tensorflow {
std::string ImportGraphDef(const std::string &proto,
const std::string &pass_pipeline,
TF_Status *status) {
GraphDef graphdef;
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
GraphDebugInfo debug_info;
GraphImportConfig specs;
mlir::MLIRContext context;
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";
}
// Run the pass_pipeline on the module if not empty.
if (!pass_pipeline.empty()) {
mlir::PassManager pm(&context);
std::string error;
llvm::raw_string_ostream error_stream(error);
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());
return "// error";
}
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
if (failed(pm.run(*module.ValueOrDie()))) {
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
return "// error";
}
}
return MlirModuleToString(*module.ConsumeValueOrDie());
}
std::string ExperimentalConvertSavedModelToMlir(
const std::string &saved_model_path, const std::string &exported_names_str,
bool show_debug_info, TF_Status *status) {
// Load the saved model into a SavedModelV2Bundle.
tensorflow::SavedModelV2Bundle bundle;
auto load_status =
tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
if (!load_status.ok()) {
Set_TF_Status_from_Status(status, load_status);
return "// error";
}
// Convert the SavedModelV2Bundle to an MLIR module.
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
auto module_or = ConvertSavedModelToMlir(
&bundle, &context, absl::Span<std::string>(exported_names));
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
}
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
std::string ExperimentalConvertSavedModelV1ToMlir(
const std::string &saved_model_path, const std::string &tags,
bool show_debug_info, TF_Status *status) {
// Load the saved model into a SavedModelBundle.
std::unordered_set<string> tag_set =
absl::StrSplit(tags, ',', absl::SkipEmpty());
tensorflow::SavedModelBundle bundle;
auto load_status =
tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
if (!load_status.ok()) {
Set_TF_Status_from_Status(status, load_status);
return "// error";
}
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
}
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
const std::string &pass_pipeline,
bool show_debug_info,
TF_Status *status) {
mlir::MLIRContext context;
mlir::OwningModuleRef module;
{
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
module = mlir::parseSourceString(mlir_txt, &context);
if (!module) {
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
return "// error";
}
}
// Run the pass_pipeline on the module.
mlir::PassManager pm(&context);
std::string error;
llvm::raw_string_ostream error_stream(error);
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());
return "// error";
}
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
if (failed(pm.run(*module))) {
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
return "// error";
}
return MlirModuleToString(*module, show_debug_info);
}
} // namespace tensorflow
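Each wrapper above reports failure by setting the supplied TF_Status and returning the placeholder string "// error"; the pybind11 layer (mlir_wrapper.cc, further down) then turns a non-OK status into a Python exception. A sketch of the resulting Python-side behavior, assuming a pass name that does not exist:

from tensorflow.python import pywrap_mlir

try:
  # 'no-such-pass' makes mlir::parsePassPipeline fail, so the C++ side sets
  # TF_INVALID_ARGUMENT and the binding raises it as an exception.
  pywrap_mlir.experimental_run_pass_pipeline('module {}', 'no-such-pass', False)
except Exception as e:  # tf.errors.InvalidArgumentError in practice
  print(e)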

tensorflow/compiler/mlir/python/mlir.h

@@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Functions for importing GraphDefs and SavedModels to MLIR and for running
// pass pipelines over MLIR modules, exposed to the Python bindings.
// Migrated from the previous SWIG file (mlir.i) authored by aminim@.
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
#include <string>
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbtxt, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
std::string ImportGraphDef(const std::string &proto,
const std::string &pass_pipeline, TF_Status *status);
// Load a SavedModel and return a textual MLIR string corresponding to it.
//
// Args:
// saved_model_path: File path from which to load the SavedModel.
// exported_names_str: Comma-separated list of names to export.
// Empty means "export all".
//
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelToMlir(
const std::string &saved_model_path, const std::string &exported_names_str,
bool show_debug_info, TF_Status *status);
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
//
// Args:
// saved_model_path: File path from which to load the SavedModel.
// tags: Tags to identify MetaGraphDef that need to be loaded.
//
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlir(
const std::string &saved_model_path, const std::string &tags,
bool show_debug_info, TF_Status *status);
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
const std::string &pass_pipeline,
bool show_debug_info,
TF_Status *status);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
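To illustrate the exported_names_str convention documented above, via the Python shims (the SavedModel path and function names are placeholders):

# Empty string: export everything in the SavedModel.
pywrap_mlir.experimental_convert_saved_model_to_mlir('/tmp/saved_model', '', False)

# Comma-separated list: export only the named functions.
pywrap_mlir.experimental_convert_saved_model_to_mlir('/tmp/saved_model', 'f1,f2', False)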

tensorflow/compiler/mlir/python/mlir.i

@@ -1,252 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/platform/base.i"
%{
#include "mlir/Parser.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
namespace tensorflow {
namespace swig {
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
// Load a .pbtxt, convert to MLIR, and (optionally) optimize the module before
// returning it as a string.
// This is an early experimental API, ideally we should return a wrapper object
// around a Python binding to the MLIR module.
string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Status* status) {
GraphDef graphdef;
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
GraphDebugInfo debug_info;
GraphImportConfig specs;
mlir::MLIRContext context;
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";
}
// Run the pass_pipeline on the module if not empty.
if (!pass_pipeline.empty()) {
mlir::PassManager pm(&context);
std::string error;
llvm::raw_string_ostream error_stream(error);
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());
return "// error";
}
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
if (failed(pm.run(*module.ValueOrDie()))) {
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
return "// error";
}
}
return MlirModuleToString(*module.ConsumeValueOrDie());
}
// Load a SavedModel and return a textual MLIR string corresponding to it.
//
// Args:
// saved_model_path: File path from which to load the SavedModel.
// exported_names_str: Comma-separated list of names to export.
// Empty means "export all".
//
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
string ExperimentalConvertSavedModelToMlir(
const string &saved_model_path,
const string &exported_names_str,
bool show_debug_info,
TF_Status* status) {
// Load the saved model into a SavedModelV2Bundle.
tensorflow::SavedModelV2Bundle bundle;
auto load_status = tensorflow::SavedModelV2Bundle::Load(
saved_model_path, &bundle);
if (!load_status.ok()) {
Set_TF_Status_from_Status(status, load_status);
return "// error";
}
// Convert the SavedModelV2Bundle to an MLIR module.
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
auto module_or = ConvertSavedModelToMlir(&bundle, &context,
absl::Span<std::string>(exported_names));
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
}
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
//
// Args:
// saved_model_path: File path from which to load the SavedModel.
// tags: Tags to identify MetaGraphDef that need to be loaded.
//
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
string ExperimentalConvertSavedModelV1ToMlir(
const string &saved_model_path,
const string &tags,
bool show_debug_info,
TF_Status* status) {
// Load the saved model into a SavedModelBundle.
std::unordered_set<string> tag_set
= absl::StrSplit(tags, ',', absl::SkipEmpty());
tensorflow::SavedModelBundle bundle;
auto load_status = tensorflow::LoadSavedModel(
{}, {},
saved_model_path, tag_set, &bundle);
if (!load_status.ok()) {
Set_TF_Status_from_Status(status, load_status);
return "// error";
}
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
}
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
string ExperimentalRunPassPipeline(
const string &mlir_txt,
const string &pass_pipeline,
bool show_debug_info,
TF_Status* status) {
mlir::MLIRContext context;
mlir::OwningModuleRef module;
{
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
module = mlir::parseSourceString(mlir_txt, &context);
if (!module) {
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
return "// error";
}
}
// Run the pass_pipeline on the module.
mlir::PassManager pm(&context);
std::string error;
llvm::raw_string_ostream error_stream(error);
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
("Invalid pass_pipeline: " + error_stream.str()).c_str());
return "// error";
}
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
if (failed(pm.run(*module))) {
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
return "// error";
}
return MlirModuleToString(*module, show_debug_info);
}
} // namespace swig
} // namespace tensorflow
%}
%ignoreall
%unignore tensorflow;
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir;
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
// Wrap this function
namespace tensorflow {
namespace swig {
static string ImportGraphDef(const string &graphdef,
const string &pass_pipeline,
TF_Status* status);
static string ExperimentalConvertSavedModelToMlir(
const string &saved_model_path,
const string &exported_names,
bool show_debug_info,
TF_Status* status);
static string ExperimentalConvertSavedModelV1ToMlir(
const string &saved_model_path,
const string &tags,
bool show_debug_info,
TF_Status* status);
static string ExperimentalRunPassPipeline(
const string &mlir_txt,
const string &pass_pipeline,
bool show_debug_info,
TF_Status* status);
} // namespace swig
} // namespace tensorflow
%insert("python") %{
def import_graphdef(graphdef, pass_pipeline):
return ImportGraphDef(str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8')).decode('utf-8');
def experimental_convert_saved_model_to_mlir(saved_model_path,
exported_names,
show_debug_info):
return ExperimentalConvertSavedModelToMlir(
str(saved_model_path).encode('utf-8'),
str(exported_names).encode('utf-8'),
show_debug_info
).decode('utf-8');
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
tags, show_debug_info):
return ExperimentalConvertSavedModelV1ToMlir(
str(saved_model_path).encode('utf-8'),
str(tags).encode('utf-8'),
show_debug_info
).decode('utf-8');
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
return ExperimentalRunPassPipeline(
mlir_txt.encode('utf-8'),
pass_pipeline.encode('utf-8'),
show_debug_info
).decode('utf-8');
%}
%unignoreall

tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py

@@ -29,7 +29,7 @@ from absl import flags
from absl import logging
import tensorflow.compat.v2 as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import
# Use /tmp to make debugging the tests easier (see README.md)
flags.DEFINE_string('save_model_path', '',
@@ -84,13 +84,13 @@ def do_test(create_module_fn, exported_names=None, show_debug_info=False):
tf.saved_model.save(
create_module_fn(), save_model_path, options=save_options)
logging.info('Saved model to: %s', save_model_path)
mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
mlir = pywrap_mlir.experimental_convert_saved_model_to_mlir(
save_model_path, ','.join(exported_names), show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
# The canonicalization shouldn't affect these tests in any way.
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
mlir, 'canonicalize', show_debug_info)
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
show_debug_info)
print(mlir)
app.run(app_main)

tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py

@@ -28,7 +28,7 @@ from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import
# Use /tmp to make debugging the tests easier (see README.md)
flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
@@ -80,14 +80,15 @@ def do_test(signature_def_map, show_debug_info=False):
builder.save()
logging.info('Saved model to: %s', save_model_path)
mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
# The canonicalization shouldn't affect these tests in any way.
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
mlir, 'tf-standard-pipeline', show_debug_info)
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
'tf-standard-pipeline',
show_debug_info)
print(mlir)
app.run(app_main)

tensorflow/python/BUILD

@@ -1130,6 +1130,7 @@ py_library(
":platform",
":pywrap_tensorflow",
":pywrap_tfe",
":pywrap_mlir",
":random_seed",
":sparse_tensor",
":tensor_spec",
@@ -5544,7 +5545,6 @@ tf_py_wrap_cc(
"grappler/tf_optimizer.i",
"lib/core/strings.i",
"platform/base.i",
"//tensorflow/compiler/mlir/python:mlir.i",
],
# add win_def_file for pywrap_tensorflow
win_def_file = select({
@@ -5572,7 +5572,6 @@ tf_py_wrap_cc(
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/compiler/mlir:passes",
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
@@ -5593,6 +5592,7 @@ tf_py_wrap_cc(
"//tensorflow/lite/toco/python:toco_python_api",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//tensorflow/core/util/tensor_bundle",
"//tensorflow/compiler/mlir/python:mlir",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps()) + if_ngraph([
"@ngraph_tf//:ngraph_tf",
@@ -5631,6 +5631,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
"//tensorflow/core/common_runtime/eager:context", # tfe
"//tensorflow/core/profiler/lib:profiler_session", # tfe
"//tensorflow/c:tf_status_helper", # tfe
"//tensorflow/compiler/mlir/python:mlir", # mlir
]
# Filter the DEF file to reduce the number of symbols to 64K or less.
@@ -7680,6 +7681,34 @@ py_library(
],
)
py_library(
name = "pywrap_mlir",
srcs = ["pywrap_mlir.py"],
visibility = ["//visibility:public"],
deps = [
":_pywrap_mlir",
":pywrap_tensorflow",
],
)
tf_python_pybind_extension(
name = "_pywrap_mlir",
srcs = ["mlir_wrapper.cc"],
hdrs = [
"lib/core/safe_ptr.h",
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
"//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
],
module_name = "_pywrap_mlir",
deps = [
":pybind11_lib",
":pybind11_status",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
py_library(
name = "pywrap_tfe",
srcs = ["pywrap_tfe.py"],

tensorflow/python/compiler/mlir/BUILD

@@ -10,7 +10,7 @@ py_library(
srcs = ["mlir.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:pywrap_mlir",
"//tensorflow/python:util",
],
)

tensorflow/python/compiler/mlir/mlir.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as import_graphdef
from tensorflow.python import pywrap_mlir
from tensorflow.python.util.tf_export import tf_export
@@ -38,4 +38,4 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
Raises a RuntimeError on error.
"""
return import_graphdef.import_graphdef(graph_def, pass_pipeline)
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
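End to end, the public tf.mlir API is unchanged by the migration; only the wrapper module underneath it moves. A small usage sketch:

import tensorflow as tf

with tf.Graph().as_default() as g:
  tf.constant(42.0, name='answer')

mlir_txt = tf.mlir.experimental.convert_graph_def(
    g.as_graph_def(), pass_pipeline='tf-standard-pipeline')
print(mlir_txt)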

tensorflow/python/mlir_wrapper.cc

@@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/python/mlir.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
PYBIND11_MODULE(_pywrap_mlir, m) {
m.def("ImportGraphDef",
[](const std::string &graphdef, const std::string &pass_pipeline) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output =
tensorflow::ImportGraphDef(graphdef, pass_pipeline, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ExperimentalConvertSavedModelToMlir",
[](const std::string &saved_model_path,
const std::string &exported_names, bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output = tensorflow::ExperimentalConvertSavedModelToMlir(
saved_model_path, exported_names, show_debug_info, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ExperimentalConvertSavedModelV1ToMlir",
[](const std::string &saved_model_path, const std::string &tags,
bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output =
tensorflow::ExperimentalConvertSavedModelV1ToMlir(
saved_model_path, tags, show_debug_info, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("ExperimentalRunPassPipeline",
[](const std::string &mlir_txt, const std::string &pass_pipeline,
bool show_debug_info) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
std::string output = tensorflow::ExperimentalRunPassPipeline(
mlir_txt, pass_pipeline, show_debug_info, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
};

tensorflow/python/pywrap_mlir.py

@@ -0,0 +1,49 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python module for MLIR functions exported by pybind11."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
from tensorflow.python import pywrap_tensorflow
from tensorflow.python._pywrap_mlir import *
def import_graphdef(graphdef, pass_pipeline):
return ImportGraphDef(
str(graphdef).encode('utf-8'),
pass_pipeline.encode('utf-8'))
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
show_debug_info):
return ExperimentalConvertSavedModelToMlir(
str(saved_model_path).encode('utf-8'),
str(exported_names).encode('utf-8'), show_debug_info)
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
show_debug_info):
return ExperimentalConvertSavedModelV1ToMlir(
str(saved_model_path).encode('utf-8'),
str(tags).encode('utf-8'), show_debug_info)
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
return ExperimentalRunPassPipeline(
mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
show_debug_info)
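Note that unlike the SWIG shims, these do not decode the result: pybind11 converts the returned std::string to a Python str automatically. A usage sketch mirroring the tests above:

from tensorflow.python import pywrap_mlir

mlir_txt = pywrap_mlir.experimental_run_pass_pipeline(
    'module {}', 'canonicalize', False)
print(mlir_txt)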

tensorflow/python/tensorflow.i

@@ -24,8 +24,6 @@ limitations under the License.
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/compiler/mlir/python/mlir.i"
// TODO(slebedev): This is a temporary workaround for projects implicitly
// relying on TensorFlow exposing tensorflow::Status.
%unignoreall

tensorflow/tools/def_file_filter/symbols_pybind.txt

@@ -189,3 +189,9 @@ tensorflow::Set_TF_Status_from_Status
[context] # tfe
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
[mlir] # mlir
tensorflow::ExperimentalRunPassPipeline
tensorflow::ExperimentalConvertSavedModelV1ToMlir
tensorflow::ExperimentalConvertSavedModelToMlir
tensorflow::ImportGraphDef