Export the MLIR classes and functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 292076160 Change-Id: I62bf3aac988c3ce4e18aa01ee49d8aa9ffde383d
This commit is contained in:
parent
941950947d
commit
9d41ea557a
tensorflow
compiler/mlir
python
tools/def_file_filter
@ -3,9 +3,29 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(
|
||||
["mlir.i"],
|
||||
visibility = [
|
||||
"//tensorflow/python:__subpackages__",
|
||||
cc_library(
|
||||
name = "mlir",
|
||||
srcs = ["mlir.cc"],
|
||||
hdrs = ["mlir.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:import_utils",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_mlir_hdrs",
|
||||
srcs = [
|
||||
"mlir.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
157
tensorflow/compiler/mlir/python/mlir.cc
Normal file
157
tensorflow/compiler/mlir/python/mlir.cc
Normal file
@ -0,0 +1,157 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Parser.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
std::string ImportGraphDef(const std::string &proto,
|
||||
const std::string &pass_pipeline,
|
||||
TF_Status *status) {
|
||||
GraphDef graphdef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
GraphImportConfig specs;
|
||||
mlir::MLIRContext context;
|
||||
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
||||
if (!module.ok()) {
|
||||
Set_TF_Status_from_Status(status, module.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Run the pass_pipeline on the module if not empty.
|
||||
if (!pass_pipeline.empty()) {
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
|
||||
if (failed(pm.run(*module.ValueOrDie()))) {
|
||||
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
return MlirModuleToString(*module.ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelToMlir(
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
// Load the saved model into a SavedModelV2Bundle.
|
||||
|
||||
tensorflow::SavedModelV2Bundle bundle;
|
||||
auto load_status =
|
||||
tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
|
||||
if (!load_status.ok()) {
|
||||
Set_TF_Status_from_Status(status, load_status);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Convert the SavedModelV2Bundle to an MLIR module.
|
||||
|
||||
std::vector<string> exported_names =
|
||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelToMlir(
|
||||
&bundle, &context, absl::Span<std::string>(exported_names));
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool show_debug_info, TF_Status *status) {
|
||||
// Load the saved model into a SavedModelBundle.
|
||||
|
||||
std::unordered_set<string> tag_set =
|
||||
absl::StrSplit(tags, ',', absl::SkipEmpty());
|
||||
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
auto load_status =
|
||||
tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
|
||||
if (!load_status.ok()) {
|
||||
Set_TF_Status_from_Status(status, load_status);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info,
|
||||
TF_Status *status) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
{
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
module = mlir::parseSourceString(mlir_txt, &context);
|
||||
if (!module) {
|
||||
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
|
||||
// Run the pass_pipeline on the module.
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
if (failed(pm.run(*module))) {
|
||||
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
return MlirModuleToString(*module, show_debug_info);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
67
tensorflow/compiler/mlir/python/mlir.h
Normal file
67
tensorflow/compiler/mlir/python/mlir.h
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Functions for getting information about kernels registered in the binary.
|
||||
// Migrated from previous SWIG file (mlir.i) authored by aminim@.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
|
||||
// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
|
||||
// returning it as a string.
|
||||
// This is an early experimental API, ideally we should return a wrapper object
|
||||
// around a Python binding to the MLIR module.
|
||||
std::string ImportGraphDef(const std::string &proto,
|
||||
const std::string &pass_pipeline, TF_Status *status);
|
||||
|
||||
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
// Args:
|
||||
// saved_model_path: File path from which to load the SavedModel.
|
||||
// exported_names_str: Comma-separated list of names to export.
|
||||
// Empty means "export all".
|
||||
//
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
std::string ExperimentalConvertSavedModelToMlir(
|
||||
const std::string &saved_model_path, const std::string &exported_names_str,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
// Args:
|
||||
// saved_model_path: File path from which to load the SavedModel.
|
||||
// tags: Tags to identify MetaGraphDef that need to be loaded.
|
||||
//
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
std::string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const std::string &saved_model_path, const std::string &tags,
|
||||
bool show_debug_info, TF_Status *status);
|
||||
|
||||
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
|
||||
const std::string &pass_pipeline,
|
||||
bool show_debug_info,
|
||||
TF_Status *status);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_
|
@ -1,252 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
// Simple wrapper to support tf.mlir.experimental.convert_graph_def.
|
||||
// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
|
||||
// returning it as a string.
|
||||
// This is an early experimental API, ideally we should return a wrapper object
|
||||
// around a Python binding to the MLIR module.
|
||||
string ImportGraphDef(const string &proto, const string &pass_pipeline, TF_Status* status) {
|
||||
GraphDef graphdef;
|
||||
auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
GraphImportConfig specs;
|
||||
mlir::MLIRContext context;
|
||||
auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
|
||||
if (!module.ok()) {
|
||||
Set_TF_Status_from_Status(status, module.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Run the pass_pipeline on the module if not empty.
|
||||
if (!pass_pipeline.empty()) {
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(&context);
|
||||
if (failed(pm.run(*module.ValueOrDie()))) {
|
||||
Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
return MlirModuleToString(*module.ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
// Load a SavedModel and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
// Args:
|
||||
// saved_model_path: File path from which to load the SavedModel.
|
||||
// exported_names_str: Comma-separated list of names to export.
|
||||
// Empty means "export all".
|
||||
//
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
string ExperimentalConvertSavedModelToMlir(
|
||||
const string &saved_model_path,
|
||||
const string &exported_names_str,
|
||||
bool show_debug_info,
|
||||
TF_Status* status) {
|
||||
// Load the saved model into a SavedModelV2Bundle.
|
||||
|
||||
tensorflow::SavedModelV2Bundle bundle;
|
||||
auto load_status = tensorflow::SavedModelV2Bundle::Load(
|
||||
saved_model_path, &bundle);
|
||||
if (!load_status.ok()) {
|
||||
Set_TF_Status_from_Status(status, load_status);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Convert the SavedModelV2Bundle to an MLIR module.
|
||||
|
||||
std::vector<string> exported_names =
|
||||
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelToMlir(&bundle, &context,
|
||||
absl::Span<std::string>(exported_names));
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
||||
//
|
||||
// Args:
|
||||
// saved_model_path: File path from which to load the SavedModel.
|
||||
// tags: Tags to identify MetaGraphDef that need to be loaded.
|
||||
//
|
||||
// Returns:
|
||||
// A string of textual MLIR representing the raw imported SavedModel.
|
||||
string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const string &saved_model_path,
|
||||
const string &tags,
|
||||
bool show_debug_info,
|
||||
TF_Status* status) {
|
||||
// Load the saved model into a SavedModelBundle.
|
||||
|
||||
std::unordered_set<string> tag_set
|
||||
= absl::StrSplit(tags, ',', absl::SkipEmpty());
|
||||
|
||||
tensorflow::SavedModelBundle bundle;
|
||||
auto load_status = tensorflow::LoadSavedModel(
|
||||
{}, {},
|
||||
saved_model_path, tag_set, &bundle);
|
||||
if (!load_status.ok()) {
|
||||
Set_TF_Status_from_Status(status, load_status);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
// Convert the SavedModelBundle to an MLIR module.
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
|
||||
if (!module_or.status().ok()) {
|
||||
Set_TF_Status_from_Status(status, module_or.status());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||
}
|
||||
|
||||
|
||||
string ExperimentalRunPassPipeline(
|
||||
const string &mlir_txt,
|
||||
const string &pass_pipeline,
|
||||
bool show_debug_info,
|
||||
TF_Status* status) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
{
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
module = mlir::parseSourceString(mlir_txt, &context);
|
||||
if (!module) {
|
||||
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
}
|
||||
|
||||
// Run the pass_pipeline on the module.
|
||||
mlir::PassManager pm(&context);
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
("Invalid pass_pipeline: " + error_stream.str()).c_str());
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
|
||||
if (failed(pm.run(*module))) {
|
||||
Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
|
||||
return "// error";
|
||||
}
|
||||
return MlirModuleToString(*module, show_debug_info);
|
||||
}
|
||||
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::swig;
|
||||
%unignore tensorflow::swig::ImportGraphDef;
|
||||
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
|
||||
%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir;
|
||||
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
|
||||
|
||||
// Wrap this function
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
static string ImportGraphDef(const string &graphdef,
|
||||
const string &pass_pipeline,
|
||||
TF_Status* status);
|
||||
static string ExperimentalConvertSavedModelToMlir(
|
||||
const string &saved_model_path,
|
||||
const string &exported_names,
|
||||
bool show_debug_info,
|
||||
TF_Status* status);
|
||||
static string ExperimentalConvertSavedModelV1ToMlir(
|
||||
const string &saved_model_path,
|
||||
const string &tags,
|
||||
bool show_debug_info,
|
||||
TF_Status* status);
|
||||
static string ExperimentalRunPassPipeline(
|
||||
const string &mlir_txt,
|
||||
const string &pass_pipeline,
|
||||
bool show_debug_info,
|
||||
TF_Status* status);
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
%insert("python") %{
|
||||
def import_graphdef(graphdef, pass_pipeline):
|
||||
return ImportGraphDef(str(graphdef).encode('utf-8'), pass_pipeline.encode('utf-8')).decode('utf-8');
|
||||
|
||||
def experimental_convert_saved_model_to_mlir(saved_model_path,
|
||||
exported_names,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelToMlir(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(exported_names).encode('utf-8'),
|
||||
show_debug_info
|
||||
).decode('utf-8');
|
||||
|
||||
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
|
||||
tags, show_debug_info):
|
||||
return ExperimentalConvertSavedModelV1ToMlir(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(tags).encode('utf-8'),
|
||||
show_debug_info
|
||||
).decode('utf-8');
|
||||
|
||||
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
||||
return ExperimentalRunPassPipeline(
|
||||
mlir_txt.encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'),
|
||||
show_debug_info
|
||||
).decode('utf-8');
|
||||
%}
|
||||
|
||||
%unignoreall
|
@ -29,7 +29,7 @@ from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import
|
||||
|
||||
# Use /tmp to make debugging the tests easier (see README.md)
|
||||
flags.DEFINE_string('save_model_path', '',
|
||||
@ -84,13 +84,13 @@ def do_test(create_module_fn, exported_names=None, show_debug_info=False):
|
||||
tf.saved_model.save(
|
||||
create_module_fn(), save_model_path, options=save_options)
|
||||
logging.info('Saved model to: %s', save_model_path)
|
||||
mlir = pywrap_tensorflow.experimental_convert_saved_model_to_mlir(
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_to_mlir(
|
||||
save_model_path, ','.join(exported_names), show_debug_info)
|
||||
# We don't strictly need this, but it serves as a handy sanity check
|
||||
# for that API, which is otherwise a bit annoying to test.
|
||||
# The canonicalization shouldn't affect these tests in any way.
|
||||
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
|
||||
mlir, 'canonicalize', show_debug_info)
|
||||
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize',
|
||||
show_debug_info)
|
||||
print(mlir)
|
||||
|
||||
app.run(app_main)
|
||||
|
@ -28,7 +28,7 @@ from absl import flags
|
||||
from absl import logging
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import
|
||||
|
||||
# Use /tmp to make debugging the tests easier (see README.md)
|
||||
flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
|
||||
@ -80,14 +80,15 @@ def do_test(signature_def_map, show_debug_info=False):
|
||||
builder.save()
|
||||
|
||||
logging.info('Saved model to: %s', save_model_path)
|
||||
mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
|
||||
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
|
||||
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||
show_debug_info)
|
||||
# We don't strictly need this, but it serves as a handy sanity check
|
||||
# for that API, which is otherwise a bit annoying to test.
|
||||
# The canonicalization shouldn't affect these tests in any way.
|
||||
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
|
||||
mlir, 'tf-standard-pipeline', show_debug_info)
|
||||
mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir,
|
||||
'tf-standard-pipeline',
|
||||
show_debug_info)
|
||||
print(mlir)
|
||||
|
||||
app.run(app_main)
|
||||
|
@ -1130,6 +1130,7 @@ py_library(
|
||||
":platform",
|
||||
":pywrap_tensorflow",
|
||||
":pywrap_tfe",
|
||||
":pywrap_mlir",
|
||||
":random_seed",
|
||||
":sparse_tensor",
|
||||
":tensor_spec",
|
||||
@ -5544,7 +5545,6 @@ tf_py_wrap_cc(
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/strings.i",
|
||||
"platform/base.i",
|
||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||
],
|
||||
# add win_def_file for pywrap_tensorflow
|
||||
win_def_file = select({
|
||||
@ -5572,7 +5572,6 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/compiler/mlir:passes",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
@ -5593,6 +5592,7 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/lite/toco/python:toco_python_api",
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
"//tensorflow/compiler/mlir/python:mlir",
|
||||
] + (tf_additional_lib_deps() +
|
||||
tf_additional_plugin_deps()) + if_ngraph([
|
||||
"@ngraph_tf//:ngraph_tf",
|
||||
@ -5631,6 +5631,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
||||
"//tensorflow/core/common_runtime/eager:context", # tfe
|
||||
"//tensorflow/core/profiler/lib:profiler_session", # tfe
|
||||
"//tensorflow/c:tf_status_helper", # tfe
|
||||
"//tensorflow/compiler/mlir/python:mlir", # mlir
|
||||
]
|
||||
|
||||
# Filter the DEF file to reduce the number of symbols to 64K or less.
|
||||
@ -7680,6 +7681,34 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pywrap_mlir",
|
||||
srcs = ["pywrap_mlir.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":_pywrap_mlir",
|
||||
":pywrap_tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_mlir",
|
||||
srcs = ["mlir_wrapper.cc"],
|
||||
hdrs = [
|
||||
"lib/core/safe_ptr.h",
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c/eager:headers",
|
||||
"//tensorflow/compiler/mlir/python:pywrap_mlir_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_mlir",
|
||||
deps = [
|
||||
":pybind11_lib",
|
||||
":pybind11_status",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pywrap_tfe",
|
||||
srcs = ["pywrap_tfe.py"],
|
||||
|
@ -10,7 +10,7 @@ py_library(
|
||||
srcs = ["mlir.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:pywrap_mlir",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as import_graphdef
|
||||
from tensorflow.python import pywrap_mlir
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -38,4 +38,4 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
|
||||
Raises a RuntimeError on error.
|
||||
|
||||
"""
|
||||
return import_graphdef.import_graphdef(graph_def, pass_pipeline)
|
||||
return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
|
||||
|
67
tensorflow/python/mlir_wrapper.cc
Normal file
67
tensorflow/python/mlir_wrapper.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/compiler/mlir/python/mlir.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
|
||||
PYBIND11_MODULE(_pywrap_mlir, m) {
|
||||
m.def("ImportGraphDef",
|
||||
[](const std::string &graphdef, const std::string &pass_pipeline) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output =
|
||||
tensorflow::ImportGraphDef(graphdef, pass_pipeline, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelToMlir",
|
||||
[](const std::string &saved_model_path,
|
||||
const std::string &exported_names, bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output = tensorflow::ExperimentalConvertSavedModelToMlir(
|
||||
saved_model_path, exported_names, show_debug_info, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalConvertSavedModelV1ToMlir",
|
||||
[](const std::string &saved_model_path, const std::string &tags,
|
||||
bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output =
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlir(
|
||||
saved_model_path, tags, show_debug_info, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
|
||||
m.def("ExperimentalRunPassPipeline",
|
||||
[](const std::string &mlir_txt, const std::string &pass_pipeline,
|
||||
bool show_debug_info) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
std::string output = tensorflow::ExperimentalRunPassPipeline(
|
||||
mlir_txt, pass_pipeline, show_debug_info, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return output;
|
||||
});
|
||||
};
|
49
tensorflow/python/pywrap_mlir.py
Normal file
49
tensorflow/python/pywrap_mlir.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Python module for MLIR functions exported by pybind11."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python._pywrap_mlir import *
|
||||
|
||||
|
||||
def import_graphdef(graphdef, pass_pipeline):
|
||||
return ImportGraphDef(
|
||||
str(graphdef).encode('utf-8'),
|
||||
pass_pipeline.encode('utf-8'))
|
||||
|
||||
|
||||
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelToMlir(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(exported_names).encode('utf-8'), show_debug_info)
|
||||
|
||||
|
||||
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
|
||||
show_debug_info):
|
||||
return ExperimentalConvertSavedModelV1ToMlir(
|
||||
str(saved_model_path).encode('utf-8'),
|
||||
str(tags).encode('utf-8'), show_debug_info)
|
||||
|
||||
|
||||
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
||||
return ExperimentalRunPassPipeline(
|
||||
mlir_txt.encode('utf-8'), pass_pipeline.encode('utf-8'),
|
||||
show_debug_info)
|
@ -24,8 +24,6 @@ limitations under the License.
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||
|
||||
%include "tensorflow/compiler/mlir/python/mlir.i"
|
||||
|
||||
// TODO(slebedev): This is a temporary workaround for projects implicitly
|
||||
// relying on TensorFlow exposing tensorflow::Status.
|
||||
%unignoreall
|
||||
|
@ -189,3 +189,9 @@ tensorflow::Set_TF_Status_from_Status
|
||||
|
||||
[context] # tfe
|
||||
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||
|
||||
[mlir] # mlir
|
||||
tensorflow::ExperimentalRunPassPipeline
|
||||
tensorflow::ExperimentalConvertSavedModelV1ToMlir
|
||||
tensorflow::ExperimentalConvertSavedModelToMlir
|
||||
tensorflow::ImportGraphDef
|
||||
|
Loading…
Reference in New Issue
Block a user