Export model_analyzer functions from C++ to Python with pybind11 instead of swig.
PiperOrigin-RevId: 281514947 Change-Id: If73d1225e1d48bf3461458f32b4b34f6a6ce99a1
This commit is contained in:
parent
3d82b95d45
commit
2aa9a78d87
tensorflow/python
@ -337,6 +337,21 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_model_analyzer",
|
||||||
|
srcs = ["grappler/model_analyzer_wrapper.cc"],
|
||||||
|
hdrs = ["grappler/model_analyzer.h"],
|
||||||
|
module_name = "_pywrap_model_analyzer",
|
||||||
|
deps = [
|
||||||
|
":model_analyzer_lib",
|
||||||
|
":pybind11_status",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/grappler:grappler_item_builder",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "numpy_lib",
|
name = "numpy_lib",
|
||||||
srcs = ["lib/core/numpy.cc"],
|
srcs = ["lib/core/numpy.cc"],
|
||||||
@ -5300,7 +5315,6 @@ tf_py_wrap_cc(
|
|||||||
"grappler/cluster.i",
|
"grappler/cluster.i",
|
||||||
"grappler/cost_analyzer.i",
|
"grappler/cost_analyzer.i",
|
||||||
"grappler/item.i",
|
"grappler/item.i",
|
||||||
"grappler/model_analyzer.i",
|
|
||||||
"grappler/tf_optimizer.i",
|
"grappler/tf_optimizer.i",
|
||||||
"lib/core/bfloat16.i",
|
"lib/core/bfloat16.i",
|
||||||
"lib/core/strings.i",
|
"lib/core/strings.i",
|
||||||
@ -5403,6 +5417,7 @@ genrule(
|
|||||||
srcs = [
|
srcs = [
|
||||||
":cpp_python_util", # util
|
":cpp_python_util", # util
|
||||||
":py_func_lib", # py_func
|
":py_func_lib", # py_func
|
||||||
|
":model_analyzer_lib", # model_analyzer
|
||||||
"//tensorflow/core:util_port", # util_port
|
"//tensorflow/core:util_port", # util_port
|
||||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||||
@ -7199,7 +7214,7 @@ py_library(
|
|||||||
"grappler/model_analyzer.py",
|
"grappler/model_analyzer.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [":pywrap_tensorflow_internal"],
|
deps = [":_pywrap_model_analyzer"],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
|
@ -1,64 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/strings.i"
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
|
|
||||||
char* c_string;
|
|
||||||
Py_ssize_t py_size;
|
|
||||||
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
|
|
||||||
// Python has raised an error (likely TypeError or UnicodeEncodeError).
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!temp.ParseFromString(string(c_string, py_size))) {
|
|
||||||
PyErr_SetString(
|
|
||||||
PyExc_TypeError,
|
|
||||||
"The MetaGraphDef could not be parsed as a valid protocol buffer");
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
$1 = &temp;
|
|
||||||
}
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
|
||||||
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
|
||||||
#include "tensorflow/python/grappler/model_analyzer.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
%{
|
|
||||||
string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
|
|
||||||
bool assume_valid_feeds, bool debug) {
|
|
||||||
tensorflow::grappler::ItemConfig cfg;
|
|
||||||
cfg.apply_optimizations = false;
|
|
||||||
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
|
||||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
|
|
||||||
if (!item) {
|
|
||||||
return "Error: failed to preprocess metagraph: check your log file for errors";
|
|
||||||
}
|
|
||||||
|
|
||||||
string suffix;
|
|
||||||
tensorflow::grappler::ModelAnalyzer analyzer(*item);
|
|
||||||
|
|
||||||
std::stringstream os;
|
|
||||||
analyzer.GenerateReport(debug, assume_valid_feeds, os);
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
%}
|
|
||||||
|
|
||||||
string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
|
|
||||||
bool assume_valid_feeds, bool debug);
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
from tensorflow.python import _pywrap_model_analyzer as tf_wrap
|
||||||
|
|
||||||
|
|
||||||
def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
||||||
@ -32,7 +32,5 @@ def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
|||||||
Returns:
|
Returns:
|
||||||
A string containing the report.
|
A string containing the report.
|
||||||
"""
|
"""
|
||||||
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
|
return tf_wrap.GenerateModelReport(
|
||||||
assume_valid_feeds, debug)
|
metagraph.SerializeToString(), assume_valid_feeds, debug)
|
||||||
|
|
||||||
return ret_from_swig
|
|
||||||
|
54
tensorflow/python/grappler/model_analyzer_wrapper.cc
Normal file
54
tensorflow/python/grappler/model_analyzer_wrapper.cc
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/core/grappler/grappler_item_builder.h"
|
||||||
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
|
#include "tensorflow/python/grappler/model_analyzer.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_model_analyzer, m) {
|
||||||
|
m.def("GenerateModelReport",
|
||||||
|
[](const py::bytes& serialized_metagraph, bool assume_valid_feeds,
|
||||||
|
bool debug) -> py::bytes {
|
||||||
|
tensorflow::MetaGraphDef metagraph;
|
||||||
|
if (!metagraph.ParseFromString(serialized_metagraph)) {
|
||||||
|
return "The MetaGraphDef could not be parsed as a valid protocol "
|
||||||
|
"buffer";
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::grappler::ItemConfig cfg;
|
||||||
|
cfg.apply_optimizations = false;
|
||||||
|
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
|
||||||
|
tensorflow::grappler::GrapplerItemFromMetaGraphDef(
|
||||||
|
"metagraph", metagraph, cfg);
|
||||||
|
if (item == nullptr) {
|
||||||
|
return "Error: failed to preprocess metagraph: check your log file "
|
||||||
|
"for errors";
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::grappler::ModelAnalyzer analyzer(*item);
|
||||||
|
|
||||||
|
std::ostringstream os;
|
||||||
|
tensorflow::MaybeRaiseFromStatus(
|
||||||
|
analyzer.GenerateReport(debug, assume_valid_feeds, os));
|
||||||
|
return py::bytes(os.str());
|
||||||
|
});
|
||||||
|
}
|
@ -32,6 +32,5 @@ limitations under the License.
|
|||||||
%include "tensorflow/python/grappler/item.i"
|
%include "tensorflow/python/grappler/item.i"
|
||||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||||
%include "tensorflow/python/grappler/model_analyzer.i"
|
|
||||||
|
|
||||||
%include "tensorflow/compiler/mlir/python/mlir.i"
|
%include "tensorflow/compiler/mlir/python/mlir.i"
|
||||||
|
Loading…
Reference in New Issue
Block a user