Export graph_analyzer functions from C++ to Python with pybind11 instead of swig.
PiperOrigin-RevId: 270729160
This commit is contained in:
parent
da9971c3c1
commit
0949d6a59a
@ -4996,7 +4996,6 @@ tf_py_wrap_cc(
|
|||||||
"framework/python_op_gen.i",
|
"framework/python_op_gen.i",
|
||||||
"grappler/cluster.i",
|
"grappler/cluster.i",
|
||||||
"grappler/cost_analyzer.i",
|
"grappler/cost_analyzer.i",
|
||||||
"grappler/graph_analyzer.i",
|
|
||||||
"grappler/item.i",
|
"grappler/item.i",
|
||||||
"grappler/model_analyzer.i",
|
"grappler/model_analyzer.i",
|
||||||
"grappler/tf_optimizer.i",
|
"grappler/tf_optimizer.i",
|
||||||
@ -5112,6 +5111,7 @@ genrule(
|
|||||||
":cpp_python_util", # util
|
":cpp_python_util", # util
|
||||||
"//tensorflow/core:util_port", # util_port
|
"//tensorflow/core:util_port", # util_port
|
||||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||||
|
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", # graph_analyzer
|
||||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||||
],
|
],
|
||||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||||
@ -7015,6 +7015,17 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_graph_analyzer",
|
||||||
|
srcs = ["grappler/graph_analyzer_tool_wrapper.cc"],
|
||||||
|
module_name = "_pywrap_graph_analyzer",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "graph_analyzer",
|
name = "graph_analyzer",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -7023,8 +7034,8 @@ py_binary(
|
|||||||
python_version = "PY2",
|
python_version = "PY2",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_graph_analyzer",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":pywrap_tensorflow_internal",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
from tensorflow.python import _pywrap_graph_analyzer as tf_wrap
|
||||||
from tensorflow.python.platform import app
|
from tensorflow.python.platform import app
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -13,14 +13,10 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
%{
|
#include "include/pybind11/pybind11.h"
|
||||||
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
|
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
|
||||||
%}
|
|
||||||
|
|
||||||
%{
|
PYBIND11_MODULE(_pywrap_graph_analyzer_tool, m) {
|
||||||
void GraphAnalyzer(const string& file_path, int n) {
|
m.def("GraphAnalyzer",
|
||||||
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
|
&tensorflow::grappler::graph_analyzer::GraphAnalyzerTool);
|
||||||
}
|
}
|
||||||
%}
|
|
||||||
|
|
||||||
void GraphAnalyzer(const string& file_path, int n);
|
|
@ -49,7 +49,6 @@ limitations under the License.
|
|||||||
%include "tensorflow/python/grappler/item.i"
|
%include "tensorflow/python/grappler/item.i"
|
||||||
%include "tensorflow/python/grappler/tf_optimizer.i"
|
%include "tensorflow/python/grappler/tf_optimizer.i"
|
||||||
%include "tensorflow/python/grappler/cost_analyzer.i"
|
%include "tensorflow/python/grappler/cost_analyzer.i"
|
||||||
%include "tensorflow/python/grappler/graph_analyzer.i"
|
|
||||||
%include "tensorflow/python/grappler/model_analyzer.i"
|
%include "tensorflow/python/grappler/model_analyzer.i"
|
||||||
|
|
||||||
%include "tensorflow/python/util/traceme.i"
|
%include "tensorflow/python/util/traceme.i"
|
||||||
|
@ -38,3 +38,5 @@ tensorflow::tfprof::Profile
|
|||||||
tensorflow::tfprof::PrintModelAnalysis
|
tensorflow::tfprof::PrintModelAnalysis
|
||||||
tensorflow::tfprof::SerializeToString
|
tensorflow::tfprof::SerializeToString
|
||||||
|
|
||||||
|
[graph_analyzer_tool] # graph_analyzer
|
||||||
|
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
|
||||||
|
Loading…
x
Reference in New Issue
Block a user