Export model_analyzer functions from C++ to Python with pybind11 instead of swig.
PiperOrigin-RevId: 281514947 Change-Id: If73d1225e1d48bf3461458f32b4b34f6a6ce99a1
This commit is contained in:
parent
3d82b95d45
commit
2aa9a78d87
tensorflow/python
@ -337,6 +337,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_model_analyzer",
|
||||
srcs = ["grappler/model_analyzer_wrapper.cc"],
|
||||
hdrs = ["grappler/model_analyzer.h"],
|
||||
module_name = "_pywrap_model_analyzer",
|
||||
deps = [
|
||||
":model_analyzer_lib",
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item_builder",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "numpy_lib",
|
||||
srcs = ["lib/core/numpy.cc"],
|
||||
@ -5300,7 +5315,6 @@ tf_py_wrap_cc(
|
||||
"grappler/cluster.i",
|
||||
"grappler/cost_analyzer.i",
|
||||
"grappler/item.i",
|
||||
"grappler/model_analyzer.i",
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/bfloat16.i",
|
||||
"lib/core/strings.i",
|
||||
@ -5403,6 +5417,7 @@ genrule(
|
||||
srcs = [
|
||||
":cpp_python_util", # util
|
||||
":py_func_lib", # py_func
|
||||
":model_analyzer_lib", # model_analyzer
|
||||
"//tensorflow/core:util_port", # util_port
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
@ -7199,7 +7214,7 @@ py_library(
|
||||
"grappler/model_analyzer.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":pywrap_tensorflow_internal"],
|
||||
deps = [":_pywrap_model_analyzer"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
|
@ -1,64 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
|
||||
char* c_string;
|
||||
Py_ssize_t py_size;
|
||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
||||
SWIG_fail;
|
||||
}
|
||||
|
||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"The MetaGraphDef could not be parsed as a valid protocol buffer");
|
||||
SWIG_fail;
|
||||
}
|
||||
$1 = &temp;
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/python/grappler/model_analyzer.h"
|
||||
%}
|
||||
|
||||
%{
|
||||
string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
|
||||
bool assume_valid_feeds, bool debug) {
|
||||
tensorflow::grappler::ItemConfig cfg;
|
||||
cfg.apply_optimizations = false;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
|
||||
if (!item) {
|
||||
return "Error: failed to preprocess metagraph: check your log file for errors";
|
||||
}
|
||||
|
||||
string suffix;
|
||||
tensorflow::grappler::ModelAnalyzer analyzer(*item);
|
||||
|
||||
std::stringstream os;
|
||||
analyzer.GenerateReport(debug, assume_valid_feeds, os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
|
||||
bool assume_valid_feeds, bool debug);
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
||||
from tensorflow.python import _pywrap_model_analyzer as tf_wrap
|
||||
|
||||
|
||||
def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
||||
@ -32,7 +32,5 @@ def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
||||
Returns:
|
||||
A string containing the report.
|
||||
"""
|
||||
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
|
||||
assume_valid_feeds, debug)
|
||||
|
||||
return ret_from_swig
|
||||
return tf_wrap.GenerateModelReport(
|
||||
metagraph.SerializeToString(), assume_valid_feeds, debug)
|
||||
|
54
tensorflow/python/grappler/model_analyzer_wrapper.cc
Normal file
54
tensorflow/python/grappler/model_analyzer_wrapper.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/python/grappler/model_analyzer.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_model_analyzer, m) {
|
||||
m.def("GenerateModelReport",
|
||||
[](const py::bytes& serialized_metagraph, bool assume_valid_feeds,
|
||||
bool debug) -> py::bytes {
|
||||
tensorflow::MetaGraphDef metagraph;
|
||||
if (!metagraph.ParseFromString(serialized_metagraph)) {
|
||||
return "The MetaGraphDef could not be parsed as a valid protocol "
|
||||
"buffer";
|
||||
}
|
||||
|
||||
tensorflow::grappler::ItemConfig cfg;
|
||||
cfg.apply_optimizations = false;
|
||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
|
||||
"metagraph", metagraph, cfg);
|
||||
if (item == nullptr) {
|
||||
return "Error: failed to preprocess metagraph: check your log file "
|
||||
"for errors";
|
||||
}
|
||||
|
||||
tensorflow::grappler::ModelAnalyzer analyzer(*item);
|
||||
|
||||
std::ostringstream os;
|
||||
tensorflow::MaybeRaiseFromStatus(
|
||||
analyzer.GenerateReport(debug, assume_valid_feeds, os));
|
||||
return py::bytes(os.str());
|
||||
});
|
||||
}
|
@ -32,6 +32,5 @@ limitations under the License.
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||
%include "tensorflow/python/grappler/model_analyzer.i"
|
||||
|
||||
%include "tensorflow/compiler/mlir/python/mlir.i"
|
||||
|
Loading…
Reference in New Issue
Block a user