Export graph_analyzer functions from C++ to Python with pybind11 instead of swig.

PiperOrigin-RevId: 270729160
This commit is contained in:
Peter Buchlovsky 2019-09-23 12:02:29 -07:00 committed by TensorFlower Gardener
parent da9971c3c1
commit 0949d6a59a
5 changed files with 21 additions and 13 deletions

View File

@ -4996,7 +4996,6 @@ tf_py_wrap_cc(
"framework/python_op_gen.i",
"grappler/cluster.i",
"grappler/cost_analyzer.i",
"grappler/graph_analyzer.i",
"grappler/item.i",
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
@ -5112,6 +5111,7 @@ genrule(
":cpp_python_util", # util
"//tensorflow/core:util_port", # util_port
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
],
outs = ["pybind_symbol_target_libs_file.txt"],
@ -7015,6 +7015,17 @@ py_library(
],
)
tf_python_pybind_extension(
name = "_pywrap_graph_analyzer",
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],
module_name = "_pywrap_graph_analyzer",
deps = [
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
"@com_google_absl//absl/strings",
"@pybind11",
],
)
py_binary(
name = "graph_analyzer",
srcs = [
@ -7023,8 +7034,8 @@ py_binary(
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":_pywrap_graph_analyzer",
":framework_for_generated_wrappers",
":pywrap_tensorflow_internal",
],
)

View File

@ -25,7 +25,7 @@ from __future__ import print_function
import argparse
import sys
from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python import _pywrap_graph_analyzer as tf_wrap
from tensorflow.python.platform import app

View File

@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,14 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
%}
%{
void GraphAnalyzer(const string& file_path, int n) {
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
PYBIND11_MODULE(_pywrap_graph_analyzer_tool, m) {
m.def("GraphAnalyzer",
&tensorflow::grappler::graph_analyzer::GraphAnalyzerTool);
}
%}
void GraphAnalyzer(const string& file_path, int n);

View File

@ -49,7 +49,6 @@ limitations under the License.
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/python/grappler/graph_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
%include "tensorflow/python/util/traceme.i"

View File

@ -38,3 +38,5 @@ tensorflow::tfprof::Profile
tensorflow::tfprof::PrintModelAnalysis
tensorflow::tfprof::SerializeToString
[graph_analyzer_tool] # graph_analyzer
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool