Export model_analyzer functions from C++ to Python with pybind11 instead of swig.

PiperOrigin-RevId: 281514947
Change-Id: If73d1225e1d48bf3461458f32b4b34f6a6ce99a1
This commit is contained in:
Peter Buchlovsky 2019-11-20 07:22:26 -08:00 committed by TensorFlower Gardener
parent 3d82b95d45
commit 2aa9a78d87
5 changed files with 74 additions and 72 deletions

View File

@ -337,6 +337,21 @@ cc_library(
],
)
tf_python_pybind_extension(
name = "_pywrap_model_analyzer",
srcs = ["grappler/model_analyzer_wrapper.cc"],
hdrs = ["grappler/model_analyzer.h"],
module_name = "_pywrap_model_analyzer",
deps = [
":model_analyzer_lib",
":pybind11_status",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item_builder",
"@pybind11",
],
)
cc_library(
name = "numpy_lib",
srcs = ["lib/core/numpy.cc"],
@ -5300,7 +5315,6 @@ tf_py_wrap_cc(
"grappler/cluster.i",
"grappler/cost_analyzer.i",
"grappler/item.i",
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
"lib/core/bfloat16.i",
"lib/core/strings.i",
@ -5403,6 +5417,7 @@ genrule(
srcs = [
":cpp_python_util", # util
":py_func_lib", # py_func
":model_analyzer_lib", # model_analyzer
"//tensorflow/core:util_port", # util_port
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
@ -7199,7 +7214,7 @@ py_library(
"grappler/model_analyzer.py",
],
srcs_version = "PY2AND3",
deps = [":pywrap_tensorflow_internal"],
deps = [":_pywrap_model_analyzer"],
)
tf_py_test(

View File

@ -1,64 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
// Python has raised an error (likely TypeError or UnicodeEncodeError).
SWIG_fail;
}
if (!temp.ParseFromString(string(c_string, py_size))) {
PyErr_SetString(
PyExc_TypeError,
"The MetaGraphDef could not be parsed as a valid protocol buffer");
SWIG_fail;
}
$1 = &temp;
}
%{
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/python/grappler/model_analyzer.h"
%}
%{
string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
bool assume_valid_feeds, bool debug) {
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
if (!item) {
return "Error: failed to preprocess metagraph: check your log file for errors";
}
string suffix;
tensorflow::grappler::ModelAnalyzer analyzer(*item);
std::stringstream os;
analyzer.GenerateReport(debug, assume_valid_feeds, os);
return os.str();
}
%}
string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
bool assume_valid_feeds, bool debug);

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python import _pywrap_model_analyzer as tf_wrap
def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
@ -32,7 +32,5 @@ def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
Returns:
A string containing the report.
"""
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
assume_valid_feeds, debug)
return ret_from_swig
return tf_wrap.GenerateModelReport(
metagraph.SerializeToString(), assume_valid_feeds, debug)

View File

@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <sstream>
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/python/grappler/model_analyzer.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_model_analyzer, m) {
m.def("GenerateModelReport",
[](const py::bytes& serialized_metagraph, bool assume_valid_feeds,
bool debug) -> py::bytes {
tensorflow::MetaGraphDef metagraph;
if (!metagraph.ParseFromString(serialized_metagraph)) {
return "The MetaGraphDef could not be parsed as a valid protocol "
"buffer";
}
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
"metagraph", metagraph, cfg);
if (item == nullptr) {
return "Error: failed to preprocess metagraph: check your log file "
"for errors";
}
tensorflow::grappler::ModelAnalyzer analyzer(*item);
std::ostringstream os;
tensorflow::MaybeRaiseFromStatus(
analyzer.GenerateReport(debug, assume_valid_feeds, os));
return py::bytes(os.str());
});
}

View File

@ -32,6 +32,5 @@ limitations under the License.
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
%include "tensorflow/compiler/mlir/python/mlir.i"