Export graph_analyzer functions from C++ to Python with pybind11 instead of swig.
PiperOrigin-RevId: 270729160
This commit is contained in:
parent
da9971c3c1
commit
0949d6a59a
tensorflow
python
tools/def_file_filter
@ -4996,7 +4996,6 @@ tf_py_wrap_cc(
|
||||
"framework/python_op_gen.i",
|
||||
"grappler/cluster.i",
|
||||
"grappler/cost_analyzer.i",
|
||||
"grappler/graph_analyzer.i",
|
||||
"grappler/item.i",
|
||||
"grappler/model_analyzer.i",
|
||||
"grappler/tf_optimizer.i",
|
||||
@ -5112,6 +5111,7 @@ genrule(
|
||||
":cpp_python_util", # util
|
||||
"//tensorflow/core:util_port", # util_port
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||
],
|
||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||
@ -7015,6 +7015,17 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_graph_analyzer",
|
||||
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],
|
||||
module_name = "_pywrap_graph_analyzer",
|
||||
deps = [
|
||||
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "graph_analyzer",
|
||||
srcs = [
|
||||
@ -7023,8 +7034,8 @@ py_binary(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_graph_analyzer",
|
||||
":framework_for_generated_wrappers",
|
||||
":pywrap_tensorflow_internal",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -25,7 +25,7 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
||||
from tensorflow.python import _pywrap_graph_analyzer as tf_wrap
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,14 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%{
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
|
||||
%}
|
||||
|
||||
%{
|
||||
void GraphAnalyzer(const string& file_path, int n) {
|
||||
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
|
||||
PYBIND11_MODULE(_pywrap_graph_analyzer_tool, m) {
|
||||
m.def("GraphAnalyzer",
|
||||
&tensorflow::grappler::graph_analyzer::GraphAnalyzerTool);
|
||||
}
|
||||
%}
|
||||
|
||||
void GraphAnalyzer(const string& file_path, int n);
|
@ -49,7 +49,6 @@ limitations under the License.
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||
%include "tensorflow/python/grappler/graph_analyzer.i"
|
||||
%include "tensorflow/python/grappler/model_analyzer.i"
|
||||
|
||||
%include "tensorflow/python/util/traceme.i"
|
||||
|
@ -38,3 +38,5 @@ tensorflow::tfprof::Profile
|
||||
tensorflow::tfprof::PrintModelAnalysis
|
||||
tensorflow::tfprof::SerializeToString
|
||||
|
||||
[graph_analyzer_tool] # graph_analyzer
|
||||
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
|
||||
|
Loading…
Reference in New Issue
Block a user