From 2aa9a78d87235fa839135b26afe9cc2362e7339e Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Wed, 20 Nov 2019 07:22:26 -0800 Subject: [PATCH] Export model_analyzer functions from C++ to Python with pybind11 instead of swig. PiperOrigin-RevId: 281514947 Change-Id: If73d1225e1d48bf3461458f32b4b34f6a6ce99a1 --- tensorflow/python/BUILD | 19 +++++- tensorflow/python/grappler/model_analyzer.i | 64 ------------------- tensorflow/python/grappler/model_analyzer.py | 8 +-- .../python/grappler/model_analyzer_wrapper.cc | 54 ++++++++++++++++ tensorflow/python/tensorflow.i | 1 - 5 files changed, 74 insertions(+), 72 deletions(-) delete mode 100644 tensorflow/python/grappler/model_analyzer.i create mode 100644 tensorflow/python/grappler/model_analyzer_wrapper.cc diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 986a60e0f88..a18a9786179 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -337,6 +337,21 @@ cc_library( ], ) +tf_python_pybind_extension( + name = "_pywrap_model_analyzer", + srcs = ["grappler/model_analyzer_wrapper.cc"], + hdrs = ["grappler/model_analyzer.h"], + module_name = "_pywrap_model_analyzer", + deps = [ + ":model_analyzer_lib", + ":pybind11_status", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item_builder", + "@pybind11", + ], +) + cc_library( name = "numpy_lib", srcs = ["lib/core/numpy.cc"], @@ -5300,7 +5315,6 @@ tf_py_wrap_cc( "grappler/cluster.i", "grappler/cost_analyzer.i", "grappler/item.i", - "grappler/model_analyzer.i", "grappler/tf_optimizer.i", "lib/core/bfloat16.i", "lib/core/strings.i", @@ -5403,6 +5417,7 @@ genrule( srcs = [ ":cpp_python_util", # util ":py_func_lib", # py_func + ":model_analyzer_lib", # model_analyzer "//tensorflow/core:util_port", # util_port "//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer @@ -7199,7 +7214,7 @@ py_library( "grappler/model_analyzer.py", ], srcs_version = "PY2AND3", - deps = [":pywrap_tensorflow_internal"], + deps = [":_pywrap_model_analyzer"], ) tf_py_test( diff --git a/tensorflow/python/grappler/model_analyzer.i b/tensorflow/python/grappler/model_analyzer.i deleted file mode 100644 index 4955780764b..00000000000 --- a/tensorflow/python/grappler/model_analyzer.i +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%include "tensorflow/python/lib/core/strings.i" -%include "tensorflow/python/platform/base.i" - -%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { - char* c_string; - Py_ssize_t py_size; - if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { - // Python has raised an error (likely TypeError or UnicodeEncodeError). - SWIG_fail; - } - - if (!temp.ParseFromString(string(c_string, py_size))) { - PyErr_SetString( - PyExc_TypeError, - "The MetaGraphDef could not be parsed as a valid protocol buffer"); - SWIG_fail; - } - $1 = &temp; -} - -%{ -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/grappler_item_builder.h" -#include "tensorflow/python/grappler/model_analyzer.h" -%} - -%{ -string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, - bool assume_valid_feeds, bool debug) { - tensorflow::grappler::ItemConfig cfg; - cfg.apply_optimizations = false; - std::unique_ptr item = - tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg); - if (!item) { - return "Error: failed to preprocess metagraph: check your log file for errors"; - } - - string suffix; - tensorflow::grappler::ModelAnalyzer analyzer(*item); - - std::stringstream os; - analyzer.GenerateReport(debug, assume_valid_feeds, os); - return os.str(); -} - -%} - -string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, - bool assume_valid_feeds, bool debug); diff --git a/tensorflow/python/grappler/model_analyzer.py b/tensorflow/python/grappler/model_analyzer.py index 417194a11af..57d8486acf3 100644 --- a/tensorflow/python/grappler/model_analyzer.py +++ b/tensorflow/python/grappler/model_analyzer.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import pywrap_tensorflow as tf_wrap +from tensorflow.python import _pywrap_model_analyzer as tf_wrap def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False): @@ -32,7 +32,5 @@ def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False): Returns: A string containing the report. """ - ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(), - assume_valid_feeds, debug) - - return ret_from_swig + return tf_wrap.GenerateModelReport( + metagraph.SerializeToString(), assume_valid_feeds, debug) diff --git a/tensorflow/python/grappler/model_analyzer_wrapper.cc b/tensorflow/python/grappler/model_analyzer_wrapper.cc new file mode 100644 index 00000000000..d9699a69a8d --- /dev/null +++ b/tensorflow/python/grappler/model_analyzer_wrapper.cc @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "include/pybind11/pybind11.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/python/grappler/model_analyzer.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_pywrap_model_analyzer, m) { + m.def("GenerateModelReport", + [](const py::bytes& serialized_metagraph, bool assume_valid_feeds, + bool debug) -> py::bytes { + tensorflow::MetaGraphDef metagraph; + if (!metagraph.ParseFromString(serialized_metagraph)) { + return "The MetaGraphDef could not be parsed as a valid protocol " + "buffer"; + } + + tensorflow::grappler::ItemConfig cfg; + cfg.apply_optimizations = false; + std::unique_ptr item = + tensorflow::grappler::GrapplerItemFromMetaGraphDef( + "metagraph", metagraph, cfg); + if (item == nullptr) { + return "Error: failed to preprocess metagraph: check your log file " + "for errors"; + } + + tensorflow::grappler::ModelAnalyzer analyzer(*item); + + std::ostringstream os; + tensorflow::MaybeRaiseFromStatus( + analyzer.GenerateReport(debug, assume_valid_feeds, os)); + return py::bytes(os.str()); + }); +} diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 7985b9cd9d6..413b5126e77 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -32,6 +32,5 @@ limitations under the License. %include "tensorflow/python/grappler/item.i" %include "tensorflow/python/grappler/tf_optimizer.i" %include "tensorflow/python/grappler/cost_analyzer.i" -%include "tensorflow/python/grappler/model_analyzer.i" %include "tensorflow/compiler/mlir/python/mlir.i"