From 0949d6a59a4f5bd92c0c18b6ccdc8c1e130b58b8 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky Date: Mon, 23 Sep 2019 12:02:29 -0700 Subject: [PATCH] Export graph_analyzer functions from C++ to Python with pybind11 instead of swig. PiperOrigin-RevId: 270729160 --- tensorflow/python/BUILD | 15 +++++++++++++-- tensorflow/python/grappler/graph_analyzer.py | 2 +- ..._analyzer.i => graph_analyzer_tool_wrapper.cc} | 14 +++++--------- tensorflow/python/tensorflow.i | 1 - .../tools/def_file_filter/symbols_pybind.txt | 2 ++ 5 files changed, 21 insertions(+), 13 deletions(-) rename tensorflow/python/grappler/{graph_analyzer.i => graph_analyzer_tool_wrapper.cc} (72%) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index ee30b992d7f..5ed987dbacb 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4996,7 +4996,6 @@ tf_py_wrap_cc( "framework/python_op_gen.i", "grappler/cluster.i", "grappler/cost_analyzer.i", - "grappler/graph_analyzer.i", "grappler/item.i", "grappler/model_analyzer.i", "grappler/tf_optimizer.i", @@ -5112,6 +5111,7 @@ genrule( ":cpp_python_util", # util "//tensorflow/core:util_port", # util_port "//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer + "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer "//tensorflow/core/profiler/internal:print_model_analysis", # tfprof ], outs = ["pybind_symbol_target_libs_file.txt"], @@ -7015,6 +7015,17 @@ py_library( ], ) +tf_python_pybind_extension( + name = "_pywrap_graph_analyzer", + srcs = ["grappler/graph_analyzer_tool_wrapper.cc"], + module_name = "_pywrap_graph_analyzer", + deps = [ + "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", + "@com_google_absl//absl/strings", + "@pybind11", + ], +) + py_binary( name = "graph_analyzer", srcs = [ @@ -7023,8 +7034,8 @@ py_binary( python_version = "PY2", srcs_version = "PY2AND3", deps = [ + ":_pywrap_graph_analyzer", ":framework_for_generated_wrappers", - ":pywrap_tensorflow_internal", ], ) diff --git a/tensorflow/python/grappler/graph_analyzer.py b/tensorflow/python/grappler/graph_analyzer.py index ec5544e38e7..c46a74ea64c 100644 --- a/tensorflow/python/grappler/graph_analyzer.py +++ b/tensorflow/python/grappler/graph_analyzer.py @@ -25,7 +25,7 @@ from __future__ import print_function import argparse import sys -from tensorflow.python import pywrap_tensorflow as tf_wrap +from tensorflow.python import _pywrap_graph_analyzer as tf_wrap from tensorflow.python.platform import app diff --git a/tensorflow/python/grappler/graph_analyzer.i b/tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc similarity index 72% rename from tensorflow/python/grappler/graph_analyzer.i rename to tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc index cc7b5358eb6..2d82942d55f 100644 --- a/tensorflow/python/grappler/graph_analyzer.i +++ b/tensorflow/python/grappler/graph_analyzer_tool_wrapper.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -%{ +#include "include/pybind11/pybind11.h" #include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h" -%} -%{ -void GraphAnalyzer(const string& file_path, int n) { - tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n); +PYBIND11_MODULE(_pywrap_graph_analyzer_tool, m) { + m.def("GraphAnalyzer", + &tensorflow::grappler::graph_analyzer::GraphAnalyzerTool); } -%} - -void GraphAnalyzer(const string& file_path, int n); diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 4ba9bcddd8e..9a5a1a3e0a9 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -49,7 +49,6 @@ limitations under the License. %include "tensorflow/python/grappler/item.i" %include "tensorflow/python/grappler/tf_optimizer.i" %include "tensorflow/python/grappler/cost_analyzer.i" -%include "tensorflow/python/grappler/graph_analyzer.i" %include "tensorflow/python/grappler/model_analyzer.i" %include "tensorflow/python/util/traceme.i" diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 02c79d10e26..896a7923bc0 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -38,3 +38,5 @@ tensorflow::tfprof::Profile tensorflow::tfprof::PrintModelAnalysis tensorflow::tfprof::SerializeToString +[graph_analyzer_tool] # graph_analyzer +tensorflow::grappler::graph_analyzer::GraphAnalyzerTool