diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 7f965e8ac80..dd9e39933fe 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -15,7 +15,7 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:_pywrap_stat_summarizer", "//tensorflow/python:util", ], ) @@ -28,7 +28,7 @@ tf_py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:_pywrap_stat_summarizer", "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], diff --git a/tensorflow/contrib/stat_summarizer/__init__.py b/tensorflow/contrib/stat_summarizer/__init__.py index 53d5548863a..a1012f542ea 100644 --- a/tensorflow/contrib/stat_summarizer/__init__.py +++ b/tensorflow/contrib/stat_summarizer/__init__.py @@ -22,13 +22,10 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import, line-too-long -from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer -from tensorflow.python.pywrap_tensorflow import NewStatSummarizer -from tensorflow.python.pywrap_tensorflow import StatSummarizer +from tensorflow.python._pywrap_stat_summarizer import StatSummarizer from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['DeleteStatSummarizer', 'NewStatSummarizer', - 'StatSummarizer'] +_allowed_symbols = ['StatSummarizer'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py index e6a0b305670..542189f3f23 100644 --- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py +++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 -from tensorflow.python import pywrap_tensorflow +from tensorflow.python import _pywrap_stat_summarizer from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -36,7 +36,8 @@ class StatSummarizerTest(test.TestCase): product = math_ops.matmul(matrix1, matrix2, name=r"product") graph_def = graph.as_graph_def() - ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString()) + ss = _pywrap_stat_summarizer.StatSummarizer( + graph_def.SerializeToString()) with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) @@ -69,8 +70,6 @@ class StatSummarizerTest(test.TestCase): # Test that a CDF summed to 100% self.assertRegexpMatches(output_string, r"100\.") - pywrap_tensorflow.DeleteStatSummarizer(ss) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b167b64197e..a28297194b0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -395,6 +395,24 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_pywrap_stat_summarizer", + srcs = ["util/stat_summarizer_wrapper.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_stat_summarizer", + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/memory", + "@pybind11", + ], +) + cc_library( name = "cpp_python_util", srcs = ["util/util.cc"], @@ -703,6 +721,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":_pywrap_stat_summarizer", ":_pywrap_utils", ":common_shapes", ":composite_tensor", @@ -5001,7 +5020,6 @@ tf_py_wrap_cc( "util/port.i", "util/py_checkpoint_reader.i", "util/scoped_annotation.i", - "util/stat_summarizer.i", "util/tfprof.i", "util/traceme.i", "util/transform_graph.i", @@ -5096,8 +5114,9 @@ genrule( name = "pybind_symbol_target_libs_file", srcs = [ ":cpp_python_util", # util + "//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer ], - outs = ["pybind_symbol_target_libs_file"], + outs = ["pybind_symbol_target_libs_file.txt"], cmd = select({ "//tensorflow:windows": """ for SRC in $(SRCS); do diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index c331601bebd..5891947a817 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -21,7 +21,6 @@ limitations under the License. %include "tensorflow/python/util/port.i" %include "tensorflow/python/util/py_checkpoint_reader.i" -%include "tensorflow/python/util/stat_summarizer.i" %include "tensorflow/python/util/tfprof.i" %include "tensorflow/python/lib/core/py_func.i" diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i deleted file mode 100644 index a5a7984d914..00000000000 --- a/tensorflow/python/util/stat_summarizer.i +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -%include -%include "tensorflow/python/lib/core/strings.i" -%include "tensorflow/python/platform/base.i" - -%{ -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/stat_summarizer.h" -#include "tensorflow/python/lib/core/py_func.h" - -#include "tensorflow/core/framework/step_stats.pb.h" -%} - -%ignoreall - -%unignore NewStatSummarizer; -%unignore DeleteStatSummarizer; -%unignore tensorflow; -%unignore tensorflow::StatSummarizer; -%unignore tensorflow::StatSummarizer::StatSummarizer; -%unignore tensorflow::StatSummarizer::~StatSummarizer; -%unignore tensorflow::StatSummarizer::Initialize; -%unignore tensorflow::StatSummarizer::InitializeStr; -%unignore tensorflow::StatSummarizer::ProcessStepStats; -%unignore tensorflow::StatSummarizer::ProcessStepStatsStr; -%unignore tensorflow::StatSummarizer::PrintStepStats; -%unignore tensorflow::StatSummarizer::GetOutputString; - - -// TODO(ashankar): Remove the unused argument from the API. -%{ -tensorflow::StatSummarizer* NewStatSummarizer( - const string& unused) { - return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions()); -} -%} - -%{ -void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) { - delete ss; -} -%} - -tensorflow::StatSummarizer* NewStatSummarizer(const string& unused); -void DeleteStatSummarizer(tensorflow::StatSummarizer* ss); - -%extend tensorflow::StatSummarizer { - void ProcessStepStatsStr(const string& step_stats_str) { - tensorflow::StepStats step_stats; - step_stats.ParseFromString(step_stats_str); - $self->ProcessStepStats(step_stats); -} -} - -%extend tensorflow::StatSummarizer { - StatSummarizer() { - tensorflow::StatSummarizer* ss = new tensorflow::StatSummarizer( - tensorflow::StatSummarizerOptions()); - return ss; -} -} -%include "tensorflow/core/util/stat_summarizer_options.h" -%include "tensorflow/core/util/stat_summarizer.h" -%unignoreall diff --git a/tensorflow/python/util/stat_summarizer_wrapper.cc b/tensorflow/python/util/stat_summarizer_wrapper.cc new file mode 100644 index 00000000000..f46ddc518e0 --- /dev/null +++ b/tensorflow/python/util/stat_summarizer_wrapper.cc @@ -0,0 +1,49 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/util/stat_summarizer.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_pywrap_stat_summarizer, m) { + py::class_ stat_summ_class(m, "StatSummarizer", + py::dynamic_attr()); + stat_summ_class + .def(py::init([](std::string graph_def_serialized) { + tensorflow::GraphDef proto; + proto.ParseFromString(graph_def_serialized); + return new tensorflow::StatSummarizer(proto); + })) + .def(py::init([]() { + return new tensorflow::StatSummarizer( + tensorflow::StatSummarizerOptions()); + })) + .def("ProcessStepStats", &tensorflow::StatSummarizer::ProcessStepStats) + .def("GetOutputString", &tensorflow::StatSummarizer::GetOutputString) + .def("PrintStepStats", &tensorflow::StatSummarizer::PrintStepStats) + .def("ProcessStepStatsStr", [](tensorflow::StatSummarizer& self, + const std::string& step_stats_str) { + tensorflow::StepStats step_stats; + step_stats.ParseFromString(step_stats_str); + self.ProcessStepStats(step_stats); + }); +}; diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 4dcc8abaa8d..7bcb43bd510 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -17,3 +17,7 @@ tensorflow::swig::IsSequenceForData tensorflow::swig::FlattenForData tensorflow::swig::AssertSameStructureForData tensorflow::swig::RegisterType + +[stream_executor_pimpl] +stream_executor::StreamExecutor::EnablePeerAccessTo +stream_executor::StreamExecutor::CanEnablePeerAccessTo