Export the stat_summarizer functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA and MLIR are using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 266439885
This commit is contained in:
parent
d72aca428b
commit
e4d8a30f03
@ -15,7 +15,7 @@ py_library(
|
|||||||
srcs = ["__init__.py"],
|
srcs = ["__init__.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:_pywrap_stat_summarizer",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -28,7 +28,7 @@ tf_py_test(
|
|||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:_pywrap_stat_summarizer",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
@ -22,13 +22,10 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||||
from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer
|
from tensorflow.python._pywrap_stat_summarizer import StatSummarizer
|
||||||
from tensorflow.python.pywrap_tensorflow import NewStatSummarizer
|
|
||||||
from tensorflow.python.pywrap_tensorflow import StatSummarizer
|
|
||||||
|
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
_allowed_symbols = ['DeleteStatSummarizer', 'NewStatSummarizer',
|
_allowed_symbols = ['StatSummarizer']
|
||||||
'StatSummarizer']
|
|
||||||
|
|
||||||
remove_undocumented(__name__, _allowed_symbols)
|
remove_undocumented(__name__, _allowed_symbols)
|
||||||
|
@ -19,7 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import _pywrap_stat_summarizer
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -36,7 +36,8 @@ class StatSummarizerTest(test.TestCase):
|
|||||||
product = math_ops.matmul(matrix1, matrix2, name=r"product")
|
product = math_ops.matmul(matrix1, matrix2, name=r"product")
|
||||||
|
|
||||||
graph_def = graph.as_graph_def()
|
graph_def = graph.as_graph_def()
|
||||||
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
|
ss = _pywrap_stat_summarizer.StatSummarizer(
|
||||||
|
graph_def.SerializeToString())
|
||||||
|
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
@ -69,8 +70,6 @@ class StatSummarizerTest(test.TestCase):
|
|||||||
# Test that a CDF summed to 100%
|
# Test that a CDF summed to 100%
|
||||||
self.assertRegexpMatches(output_string, r"100\.")
|
self.assertRegexpMatches(output_string, r"100\.")
|
||||||
|
|
||||||
pywrap_tensorflow.DeleteStatSummarizer(ss)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -395,6 +395,24 @@ tf_python_pybind_extension(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_stat_summarizer",
|
||||||
|
srcs = ["util/stat_summarizer_wrapper.cc"],
|
||||||
|
copts = [
|
||||||
|
"-fexceptions",
|
||||||
|
"-fno-strict-aliasing",
|
||||||
|
],
|
||||||
|
features = ["-use_header_modules"],
|
||||||
|
module_name = "_pywrap_stat_summarizer",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cpp_python_util",
|
name = "cpp_python_util",
|
||||||
srcs = ["util/util.cc"],
|
srcs = ["util/util.cc"],
|
||||||
@ -703,6 +721,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_stat_summarizer",
|
||||||
":_pywrap_utils",
|
":_pywrap_utils",
|
||||||
":common_shapes",
|
":common_shapes",
|
||||||
":composite_tensor",
|
":composite_tensor",
|
||||||
@ -5001,7 +5020,6 @@ tf_py_wrap_cc(
|
|||||||
"util/port.i",
|
"util/port.i",
|
||||||
"util/py_checkpoint_reader.i",
|
"util/py_checkpoint_reader.i",
|
||||||
"util/scoped_annotation.i",
|
"util/scoped_annotation.i",
|
||||||
"util/stat_summarizer.i",
|
|
||||||
"util/tfprof.i",
|
"util/tfprof.i",
|
||||||
"util/traceme.i",
|
"util/traceme.i",
|
||||||
"util/transform_graph.i",
|
"util/transform_graph.i",
|
||||||
@ -5096,8 +5114,9 @@ genrule(
|
|||||||
name = "pybind_symbol_target_libs_file",
|
name = "pybind_symbol_target_libs_file",
|
||||||
srcs = [
|
srcs = [
|
||||||
":cpp_python_util", # util
|
":cpp_python_util", # util
|
||||||
|
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||||
],
|
],
|
||||||
outs = ["pybind_symbol_target_libs_file"],
|
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||||
cmd = select({
|
cmd = select({
|
||||||
"//tensorflow:windows": """
|
"//tensorflow:windows": """
|
||||||
for SRC in $(SRCS); do
|
for SRC in $(SRCS); do
|
||||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
|||||||
|
|
||||||
%include "tensorflow/python/util/port.i"
|
%include "tensorflow/python/util/port.i"
|
||||||
%include "tensorflow/python/util/py_checkpoint_reader.i"
|
%include "tensorflow/python/util/py_checkpoint_reader.i"
|
||||||
%include "tensorflow/python/util/stat_summarizer.i"
|
|
||||||
%include "tensorflow/python/util/tfprof.i"
|
%include "tensorflow/python/util/tfprof.i"
|
||||||
|
|
||||||
%include "tensorflow/python/lib/core/py_func.i"
|
%include "tensorflow/python/lib/core/py_func.i"
|
||||||
|
@ -1,78 +0,0 @@
|
|||||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
%include <std_string.i>
|
|
||||||
%include "tensorflow/python/lib/core/strings.i"
|
|
||||||
%include "tensorflow/python/platform/base.i"
|
|
||||||
|
|
||||||
%{
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/util/stat_summarizer.h"
|
|
||||||
#include "tensorflow/python/lib/core/py_func.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
|
||||||
%}
|
|
||||||
|
|
||||||
%ignoreall
|
|
||||||
|
|
||||||
%unignore NewStatSummarizer;
|
|
||||||
%unignore DeleteStatSummarizer;
|
|
||||||
%unignore tensorflow;
|
|
||||||
%unignore tensorflow::StatSummarizer;
|
|
||||||
%unignore tensorflow::StatSummarizer::StatSummarizer;
|
|
||||||
%unignore tensorflow::StatSummarizer::~StatSummarizer;
|
|
||||||
%unignore tensorflow::StatSummarizer::Initialize;
|
|
||||||
%unignore tensorflow::StatSummarizer::InitializeStr;
|
|
||||||
%unignore tensorflow::StatSummarizer::ProcessStepStats;
|
|
||||||
%unignore tensorflow::StatSummarizer::ProcessStepStatsStr;
|
|
||||||
%unignore tensorflow::StatSummarizer::PrintStepStats;
|
|
||||||
%unignore tensorflow::StatSummarizer::GetOutputString;
|
|
||||||
|
|
||||||
|
|
||||||
// TODO(ashankar): Remove the unused argument from the API.
|
|
||||||
%{
|
|
||||||
tensorflow::StatSummarizer* NewStatSummarizer(
|
|
||||||
const string& unused) {
|
|
||||||
return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions());
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
|
|
||||||
%{
|
|
||||||
void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) {
|
|
||||||
delete ss;
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
|
|
||||||
tensorflow::StatSummarizer* NewStatSummarizer(const string& unused);
|
|
||||||
void DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
|
|
||||||
|
|
||||||
%extend tensorflow::StatSummarizer {
|
|
||||||
void ProcessStepStatsStr(const string& step_stats_str) {
|
|
||||||
tensorflow::StepStats step_stats;
|
|
||||||
step_stats.ParseFromString(step_stats_str);
|
|
||||||
$self->ProcessStepStats(step_stats);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
%extend tensorflow::StatSummarizer {
|
|
||||||
StatSummarizer() {
|
|
||||||
tensorflow::StatSummarizer* ss = new tensorflow::StatSummarizer(
|
|
||||||
tensorflow::StatSummarizerOptions());
|
|
||||||
return ss;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
%include "tensorflow/core/util/stat_summarizer_options.h"
|
|
||||||
%include "tensorflow/core/util/stat_summarizer.h"
|
|
||||||
%unignoreall
|
|
49
tensorflow/python/util/stat_summarizer_wrapper.cc
Normal file
49
tensorflow/python/util/stat_summarizer_wrapper.cc
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "include/pybind11/pytypes.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/util/stat_summarizer.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_stat_summarizer, m) {
|
||||||
|
py::class_<tensorflow::StatSummarizer> stat_summ_class(m, "StatSummarizer",
|
||||||
|
py::dynamic_attr());
|
||||||
|
stat_summ_class
|
||||||
|
.def(py::init([](std::string graph_def_serialized) {
|
||||||
|
tensorflow::GraphDef proto;
|
||||||
|
proto.ParseFromString(graph_def_serialized);
|
||||||
|
return new tensorflow::StatSummarizer(proto);
|
||||||
|
}))
|
||||||
|
.def(py::init([]() {
|
||||||
|
return new tensorflow::StatSummarizer(
|
||||||
|
tensorflow::StatSummarizerOptions());
|
||||||
|
}))
|
||||||
|
.def("ProcessStepStats", &tensorflow::StatSummarizer::ProcessStepStats)
|
||||||
|
.def("GetOutputString", &tensorflow::StatSummarizer::GetOutputString)
|
||||||
|
.def("PrintStepStats", &tensorflow::StatSummarizer::PrintStepStats)
|
||||||
|
.def("ProcessStepStatsStr", [](tensorflow::StatSummarizer& self,
|
||||||
|
const std::string& step_stats_str) {
|
||||||
|
tensorflow::StepStats step_stats;
|
||||||
|
step_stats.ParseFromString(step_stats_str);
|
||||||
|
self.ProcessStepStats(step_stats);
|
||||||
|
});
|
||||||
|
};
|
@ -17,3 +17,7 @@ tensorflow::swig::IsSequenceForData
|
|||||||
tensorflow::swig::FlattenForData
|
tensorflow::swig::FlattenForData
|
||||||
tensorflow::swig::AssertSameStructureForData
|
tensorflow::swig::AssertSameStructureForData
|
||||||
tensorflow::swig::RegisterType
|
tensorflow::swig::RegisterType
|
||||||
|
|
||||||
|
[stream_executor_pimpl]
|
||||||
|
stream_executor::StreamExecutor::EnablePeerAccessTo
|
||||||
|
stream_executor::StreamExecutor::CanEnablePeerAccessTo
|
||||||
|
Loading…
Reference in New Issue
Block a user