Export the stat_summarizer functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA and MLIR are using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 266439885
This commit is contained in:
parent
d72aca428b
commit
e4d8a30f03
@ -15,7 +15,7 @@ py_library(
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:_pywrap_stat_summarizer",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -28,7 +28,7 @@ tf_py_test(
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:_pywrap_stat_summarizer",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
|
@ -22,13 +22,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import, line-too-long
|
||||
from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer
|
||||
from tensorflow.python.pywrap_tensorflow import NewStatSummarizer
|
||||
from tensorflow.python.pywrap_tensorflow import StatSummarizer
|
||||
from tensorflow.python._pywrap_stat_summarizer import StatSummarizer
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = ['DeleteStatSummarizer', 'NewStatSummarizer',
|
||||
'StatSummarizer']
|
||||
_allowed_symbols = ['StatSummarizer']
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@ -19,7 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python import _pywrap_stat_summarizer
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -36,7 +36,8 @@ class StatSummarizerTest(test.TestCase):
|
||||
product = math_ops.matmul(matrix1, matrix2, name=r"product")
|
||||
|
||||
graph_def = graph.as_graph_def()
|
||||
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
|
||||
ss = _pywrap_stat_summarizer.StatSummarizer(
|
||||
graph_def.SerializeToString())
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
@ -69,8 +70,6 @@ class StatSummarizerTest(test.TestCase):
|
||||
# Test that a CDF summed to 100%
|
||||
self.assertRegexpMatches(output_string, r"100\.")
|
||||
|
||||
pywrap_tensorflow.DeleteStatSummarizer(ss)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -395,6 +395,24 @@ tf_python_pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_stat_summarizer",
|
||||
srcs = ["util/stat_summarizer_wrapper.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_stat_summarizer",
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpp_python_util",
|
||||
srcs = ["util/util.cc"],
|
||||
@ -703,6 +721,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_stat_summarizer",
|
||||
":_pywrap_utils",
|
||||
":common_shapes",
|
||||
":composite_tensor",
|
||||
@ -5001,7 +5020,6 @@ tf_py_wrap_cc(
|
||||
"util/port.i",
|
||||
"util/py_checkpoint_reader.i",
|
||||
"util/scoped_annotation.i",
|
||||
"util/stat_summarizer.i",
|
||||
"util/tfprof.i",
|
||||
"util/traceme.i",
|
||||
"util/transform_graph.i",
|
||||
@ -5096,8 +5114,9 @@ genrule(
|
||||
name = "pybind_symbol_target_libs_file",
|
||||
srcs = [
|
||||
":cpp_python_util", # util
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
|
||||
],
|
||||
outs = ["pybind_symbol_target_libs_file"],
|
||||
outs = ["pybind_symbol_target_libs_file.txt"],
|
||||
cmd = select({
|
||||
"//tensorflow:windows": """
|
||||
for SRC in $(SRCS); do
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
|
||||
%include "tensorflow/python/util/port.i"
|
||||
%include "tensorflow/python/util/py_checkpoint_reader.i"
|
||||
%include "tensorflow/python/util/stat_summarizer.i"
|
||||
%include "tensorflow/python/util/tfprof.i"
|
||||
|
||||
%include "tensorflow/python/lib/core/py_func.i"
|
||||
|
@ -1,78 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%include <std_string.i>
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore NewStatSummarizer;
|
||||
%unignore DeleteStatSummarizer;
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::StatSummarizer;
|
||||
%unignore tensorflow::StatSummarizer::StatSummarizer;
|
||||
%unignore tensorflow::StatSummarizer::~StatSummarizer;
|
||||
%unignore tensorflow::StatSummarizer::Initialize;
|
||||
%unignore tensorflow::StatSummarizer::InitializeStr;
|
||||
%unignore tensorflow::StatSummarizer::ProcessStepStats;
|
||||
%unignore tensorflow::StatSummarizer::ProcessStepStatsStr;
|
||||
%unignore tensorflow::StatSummarizer::PrintStepStats;
|
||||
%unignore tensorflow::StatSummarizer::GetOutputString;
|
||||
|
||||
|
||||
// TODO(ashankar): Remove the unused argument from the API.
|
||||
%{
|
||||
tensorflow::StatSummarizer* NewStatSummarizer(
|
||||
const string& unused) {
|
||||
return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions());
|
||||
}
|
||||
%}
|
||||
|
||||
%{
|
||||
void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) {
|
||||
delete ss;
|
||||
}
|
||||
%}
|
||||
|
||||
tensorflow::StatSummarizer* NewStatSummarizer(const string& unused);
|
||||
void DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
|
||||
|
||||
%extend tensorflow::StatSummarizer {
|
||||
void ProcessStepStatsStr(const string& step_stats_str) {
|
||||
tensorflow::StepStats step_stats;
|
||||
step_stats.ParseFromString(step_stats_str);
|
||||
$self->ProcessStepStats(step_stats);
|
||||
}
|
||||
}
|
||||
|
||||
%extend tensorflow::StatSummarizer {
|
||||
StatSummarizer() {
|
||||
tensorflow::StatSummarizer* ss = new tensorflow::StatSummarizer(
|
||||
tensorflow::StatSummarizerOptions());
|
||||
return ss;
|
||||
}
|
||||
}
|
||||
%include "tensorflow/core/util/stat_summarizer_options.h"
|
||||
%include "tensorflow/core/util/stat_summarizer.h"
|
||||
%unignoreall
|
49
tensorflow/python/util/stat_summarizer_wrapper.cc
Normal file
49
tensorflow/python/util/stat_summarizer_wrapper.cc
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/pytypes.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_stat_summarizer, m) {
|
||||
py::class_<tensorflow::StatSummarizer> stat_summ_class(m, "StatSummarizer",
|
||||
py::dynamic_attr());
|
||||
stat_summ_class
|
||||
.def(py::init([](std::string graph_def_serialized) {
|
||||
tensorflow::GraphDef proto;
|
||||
proto.ParseFromString(graph_def_serialized);
|
||||
return new tensorflow::StatSummarizer(proto);
|
||||
}))
|
||||
.def(py::init([]() {
|
||||
return new tensorflow::StatSummarizer(
|
||||
tensorflow::StatSummarizerOptions());
|
||||
}))
|
||||
.def("ProcessStepStats", &tensorflow::StatSummarizer::ProcessStepStats)
|
||||
.def("GetOutputString", &tensorflow::StatSummarizer::GetOutputString)
|
||||
.def("PrintStepStats", &tensorflow::StatSummarizer::PrintStepStats)
|
||||
.def("ProcessStepStatsStr", [](tensorflow::StatSummarizer& self,
|
||||
const std::string& step_stats_str) {
|
||||
tensorflow::StepStats step_stats;
|
||||
step_stats.ParseFromString(step_stats_str);
|
||||
self.ProcessStepStats(step_stats);
|
||||
});
|
||||
};
|
@ -17,3 +17,7 @@ tensorflow::swig::IsSequenceForData
|
||||
tensorflow::swig::FlattenForData
|
||||
tensorflow::swig::AssertSameStructureForData
|
||||
tensorflow::swig::RegisterType
|
||||
|
||||
[stream_executor_pimpl]
|
||||
stream_executor::StreamExecutor::EnablePeerAccessTo
|
||||
stream_executor::StreamExecutor::CanEnablePeerAccessTo
|
||||
|
Loading…
Reference in New Issue
Block a user