Export the stat_summarizer functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA and MLIR are using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.

PiperOrigin-RevId: 266439885
This commit is contained in:
Hye Soo Yang 2019-08-30 13:03:44 -07:00 committed by TensorFlower Gardener
parent d72aca428b
commit e4d8a30f03
8 changed files with 81 additions and 92 deletions

View File

@ -15,7 +15,7 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:_pywrap_stat_summarizer",
"//tensorflow/python:util",
],
)
@ -28,7 +28,7 @@ tf_py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:_pywrap_stat_summarizer",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
],

View File

@ -22,13 +22,10 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import, line-too-long
from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer
from tensorflow.python.pywrap_tensorflow import NewStatSummarizer
from tensorflow.python.pywrap_tensorflow import StatSummarizer
from tensorflow.python._pywrap_stat_summarizer import StatSummarizer
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['DeleteStatSummarizer', 'NewStatSummarizer',
'StatSummarizer']
_allowed_symbols = ['StatSummarizer']
remove_undocumented(__name__, _allowed_symbols)

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_stat_summarizer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@ -36,7 +36,8 @@ class StatSummarizerTest(test.TestCase):
product = math_ops.matmul(matrix1, matrix2, name=r"product")
graph_def = graph.as_graph_def()
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
ss = _pywrap_stat_summarizer.StatSummarizer(
graph_def.SerializeToString())
with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
@ -69,8 +70,6 @@ class StatSummarizerTest(test.TestCase):
# Test that a CDF summed to 100%
self.assertRegexpMatches(output_string, r"100\.")
pywrap_tensorflow.DeleteStatSummarizer(ss)
if __name__ == "__main__":
test.main()

View File

@ -395,6 +395,24 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_pywrap_stat_summarizer",
srcs = ["util/stat_summarizer_wrapper.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_pywrap_stat_summarizer",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
"@pybind11",
],
)
cc_library(
name = "cpp_python_util",
srcs = ["util/util.cc"],
@ -703,6 +721,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":_pywrap_stat_summarizer",
":_pywrap_utils",
":common_shapes",
":composite_tensor",
@ -5001,7 +5020,6 @@ tf_py_wrap_cc(
"util/port.i",
"util/py_checkpoint_reader.i",
"util/scoped_annotation.i",
"util/stat_summarizer.i",
"util/tfprof.i",
"util/traceme.i",
"util/transform_graph.i",
@ -5096,8 +5114,9 @@ genrule(
name = "pybind_symbol_target_libs_file",
srcs = [
":cpp_python_util", # util
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
],
outs = ["pybind_symbol_target_libs_file"],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({
"//tensorflow:windows": """
for SRC in $(SRCS); do

View File

@ -21,7 +21,6 @@ limitations under the License.
%include "tensorflow/python/util/port.i"
%include "tensorflow/python/util/py_checkpoint_reader.i"
%include "tensorflow/python/util/stat_summarizer.i"
%include "tensorflow/python/util/tfprof.i"
%include "tensorflow/python/lib/core/py_func.i"

View File

@ -1,78 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%include <std_string.i>
%include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i"
%{
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/python/lib/core/py_func.h"
#include "tensorflow/core/framework/step_stats.pb.h"
%}
%ignoreall
%unignore NewStatSummarizer;
%unignore DeleteStatSummarizer;
%unignore tensorflow;
%unignore tensorflow::StatSummarizer;
%unignore tensorflow::StatSummarizer::StatSummarizer;
%unignore tensorflow::StatSummarizer::~StatSummarizer;
%unignore tensorflow::StatSummarizer::Initialize;
%unignore tensorflow::StatSummarizer::InitializeStr;
%unignore tensorflow::StatSummarizer::ProcessStepStats;
%unignore tensorflow::StatSummarizer::ProcessStepStatsStr;
%unignore tensorflow::StatSummarizer::PrintStepStats;
%unignore tensorflow::StatSummarizer::GetOutputString;
// TODO(ashankar): Remove the unused argument from the API.
%{
tensorflow::StatSummarizer* NewStatSummarizer(
const string& unused) {
return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions());
}
%}
%{
void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) {
delete ss;
}
%}
tensorflow::StatSummarizer* NewStatSummarizer(const string& unused);
void DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
%extend tensorflow::StatSummarizer {
void ProcessStepStatsStr(const string& step_stats_str) {
tensorflow::StepStats step_stats;
step_stats.ParseFromString(step_stats_str);
$self->ProcessStepStats(step_stats);
}
}
%extend tensorflow::StatSummarizer {
StatSummarizer() {
tensorflow::StatSummarizer* ss = new tensorflow::StatSummarizer(
tensorflow::StatSummarizerOptions());
return ss;
}
}
%include "tensorflow/core/util/stat_summarizer_options.h"
%include "tensorflow/core/util/stat_summarizer.h"
%unignoreall

View File

@ -0,0 +1,49 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/memory/memory.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/util/stat_summarizer.h"
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_stat_summarizer, m) {
py::class_<tensorflow::StatSummarizer> stat_summ_class(m, "StatSummarizer",
py::dynamic_attr());
stat_summ_class
.def(py::init([](std::string graph_def_serialized) {
tensorflow::GraphDef proto;
proto.ParseFromString(graph_def_serialized);
return new tensorflow::StatSummarizer(proto);
}))
.def(py::init([]() {
return new tensorflow::StatSummarizer(
tensorflow::StatSummarizerOptions());
}))
.def("ProcessStepStats", &tensorflow::StatSummarizer::ProcessStepStats)
.def("GetOutputString", &tensorflow::StatSummarizer::GetOutputString)
.def("PrintStepStats", &tensorflow::StatSummarizer::PrintStepStats)
.def("ProcessStepStatsStr", [](tensorflow::StatSummarizer& self,
const std::string& step_stats_str) {
tensorflow::StepStats step_stats;
step_stats.ParseFromString(step_stats_str);
self.ProcessStepStats(step_stats);
});
};

View File

@ -17,3 +17,7 @@ tensorflow::swig::IsSequenceForData
tensorflow::swig::FlattenForData
tensorflow::swig::AssertSameStructureForData
tensorflow::swig::RegisterType
[stream_executor_pimpl]
stream_executor::StreamExecutor::EnablePeerAccessTo
stream_executor::StreamExecutor::CanEnablePeerAccessTo