diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index a423d8bb4cc..64be2c70a11 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -482,8 +482,8 @@ static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { // Prepare a NamedTuple type to hold the outputs, if there are multiple if (num_outs > 1) { - const string tuple_type_prefix = - strings::StrCat("_", op_def.name(), "Output = collections.namedtuple("); + const string tuple_type_prefix = strings::StrCat( + "_", op_def.name(), "Output = _collections.namedtuple("); const string tuple_type_suffix = strings::StrCat( "\"", op_def.name(), "\", ", lower_op_name_outputs, ")"); strings::Appendf( @@ -656,18 +656,18 @@ string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, This file is MACHINE GENERATED! Do not edit. """ -import collections +import collections as _collections -from google.protobuf import text_format +from google.protobuf import text_format as _text_format -from tensorflow.core.framework import op_def_pb2 +from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 # Needed to trigger the call to _set_call_cpp_shape_fn. -from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import common_shapes as _common_shapes -from tensorflow.python.framework import op_def_registry -from tensorflow.python.framework import ops -from tensorflow.python.framework import op_def_library +from tensorflow.python.framework import op_def_registry as _op_def_registry +from tensorflow.python.framework import ops as _ops +from tensorflow.python.framework import op_def_library as _op_def_library )"); // We'll make a copy of ops that filters out descriptions. @@ -699,7 +699,7 @@ from tensorflow.python.framework import op_def_library GetPythonOp(op_def, is_hidden, lower_case_name)); if (!require_shapes) { - strings::Appendf(&result, "ops.RegisterShape(\"%s\")(None)\n", + strings::Appendf(&result, "_ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str()); } @@ -709,10 +709,10 @@ from tensorflow.python.framework import op_def_library } strings::Appendf(&result, R"(def _InitOpDefLibrary(): - op_list = op_def_pb2.OpList() - text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list) - op_def_registry.register_op_list(op_list) - op_def_lib = op_def_library.OpDefLibrary() + op_list = _op_def_pb2.OpList() + _text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list) + _op_def_registry.register_op_list(op_list) + op_def_lib = _op_def_library.OpDefLibrary() op_def_lib.add_op_list(op_list) return op_def_lib diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index e8c7ddcd49b..b83693303ac 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -105,6 +105,7 @@ import sys import numpy as np import six +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/python/ops/sdca_ops.py b/tensorflow/python/ops/sdca_ops.py index 84a207336fe..3876bc96421 100644 --- a/tensorflow/python/ops/sdca_ops.py +++ b/tensorflow/python/ops/sdca_ops.py @@ -27,16 +27,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import ops + # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_sdca_ops import * # pylint: enable=wildcard-import -# pylint: disable=anomalous-backslash-in-string,protected-access +from tensorflow.python.util.all_util import remove_undocumented + ops.NotDifferentiable("SdcaFprint") ops.NotDifferentiable("SdcaOptimizer") ops.NotDifferentiable("SdcaShrinkL1") + + +remove_undocumented(__name__) diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index b6b06f9eb91..0563b370ea0 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -64,18 +64,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=g-bad-import-order from tensorflow.python.client import device_lib as _device_lib from tensorflow.python.framework import test_util as _test_util from tensorflow.python.platform import googletest as _googletest from tensorflow.python.util.all_util import remove_undocumented # pylint: disable=unused-import -from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase from tensorflow.python.framework.test_util import assert_equal_graph_def +from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase from tensorflow.python.ops.gradient_checker import compute_gradient_error from tensorflow.python.ops.gradient_checker import compute_gradient -# pylint: enable=unused-import +# pylint: enable=unused-import,g-bad-import-order import sys if sys.version_info.major == 2: diff --git a/tensorflow/python/user_ops/user_ops.py b/tensorflow/python/user_ops/user_ops.py index fce89ec91bd..17dbab706c9 100644 --- a/tensorflow/python/user_ops/user_ops.py +++ b/tensorflow/python/user_ops/user_ops.py @@ -19,10 +19,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops import gen_user_ops -from tensorflow.python.ops.gen_user_ops import * +from tensorflow.python.ops import gen_user_ops as _gen_user_ops + +# go/tf-wildcard-import +from tensorflow.python.ops.gen_user_ops import * # pylint: disable=wildcard-import def my_fact(): """Example of overriding the generated code for an Op.""" - return gen_user_ops._fact() + return _gen_user_ops._fact() # pylint: disable=protected-access diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py index 00771573c4c..08f33657510 100644 --- a/tensorflow/python/util/all_util.py +++ b/tensorflow/python/util/all_util.py @@ -83,8 +83,7 @@ def reveal_undocumented(symbol_name, target_module=None): def remove_undocumented(module_name, allowed_exception_list=None, doc_string_modules=None): - """Removes symbols in a module that are not referenced by a docstring that - contributes to documentation. + """Removes symbols in a module that are not referenced by a docstring. Args: module_name: the name of the module (usually `__name__`). @@ -100,7 +99,7 @@ def remove_undocumented(module_name, allowed_exception_list=None, """ current_symbols = set(dir(_sys.modules[module_name])) should_have = make_all(module_name, doc_string_modules) - should_have += allowed_exception_list + should_have += allowed_exception_list or [] extra_symbols = current_symbols - set(should_have) target_module = _sys.modules[module_name] for extra_symbol in extra_symbols: diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index f1d43134b85..96ae9583d73 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -9,6 +9,22 @@ package( default_visibility = ["//tensorflow:__subpackages__"], ) +py_library( + name = "public_api", + srcs = ["public_api.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "public_api_test", + srcs = ["public_api_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":public_api", + "//tensorflow/python:platform_test", + ], +) + py_library( name = "traverse", srcs = ["traverse.py"], diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py new file mode 100644 index 00000000000..5d70cb7b767 --- /dev/null +++ b/tensorflow/tools/common/public_api.py @@ -0,0 +1,78 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Visitor restricting traversal to only the public tensorflow API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + + +class PublicAPIVisitor(object): + """Visitor to use with `traverse` to visit exactly the public TF API.""" + + def __init__(self, visitor): + """Constructor. + + `visitor` should be a callable suitable as a visitor for `traverse`. It will + be called only for members of the public TensorFlow API. + + Args: + visitor: A visitor to call for the public API. + """ + self._visitor = visitor + + # Modules/classes we do not want to descend into if we hit them. Usually, + # sytem modules exposed through platforms for compatibility reasons. + # Each entry maps a module path to a name to ignore in traversal. + _do_not_descend_map = { + # TODO(drpng): This can be removed once sealed off. + '': ['platform', 'pywrap_tensorflow'], + + # Some implementations have this internal module that we shouldn't expose. + 'flags': ['cpp_flags'], + + # Everything below here is legitimate. + 'app': 'flags', # It'll stay, but it's not officially part of the API + 'test': ['mock'], # Imported for compatibility between py2/3. + } + + def _isprivate(self, name): + """Return whether a name is private.""" + return name.startswith('_') + + def _do_not_descend(self, path, name): + """Safely queries if a specific fully qualified name should be excluded.""" + return (path in self._do_not_descend_map and + name in self._do_not_descend_map[path]) + + def __call__(self, path, parent, children): + """Visitor interface, see `traverse` for details.""" + if inspect.ismodule(parent) and len(path.split('.')) > 10: + raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a ' + 'problem with an accidental public import.' % path) + + # Remove things that are not visible. + for name, child in list(children): + if self._isprivate(name): + children.remove((name, child)) + + self._visitor(path, parent, children) + + # Remove things that are visible, but which should not be descended into. + for name, child in list(children): + if self._do_not_descend(path, name): + children.remove((name, child)) diff --git a/tensorflow/tools/common/public_api_test.py b/tensorflow/tools/common/public_api_test.py new file mode 100644 index 00000000000..93a3bcc2740 --- /dev/null +++ b/tensorflow/tools/common/public_api_test.py @@ -0,0 +1,68 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.tools.common.public_api.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import googletest +from tensorflow.tools.common import public_api + + +class PublicApiTest(googletest.TestCase): + + class TestVisitor(object): + + def __init__(self): + self.symbols = set() + self.last_parent = None + self.last_children = None + + def __call__(self, path, parent, children): + self.symbols.add(path) + self.last_parent = parent + self.last_children = list(children) # Make a copy to preserve state. + + def test_call_forward(self): + visitor = self.TestVisitor() + children = [('name1', 'thing1'), ('name2', 'thing2')] + public_api.PublicAPIVisitor(visitor)('test', 'dummy', children) + self.assertEqual(set(['test']), visitor.symbols) + self.assertEqual('dummy', visitor.last_parent) + self.assertEqual([('name1', 'thing1'), ('name2', 'thing2')], + visitor.last_children) + + def test_private_child_removal(self): + visitor = self.TestVisitor() + children = [('name1', 'thing1'), ('_name2', 'thing2')] + public_api.PublicAPIVisitor(visitor)('test', 'dummy', children) + # Make sure the private symbols are removed before the visitor is called. + self.assertEqual([('name1', 'thing1')], visitor.last_children) + self.assertEqual([('name1', 'thing1')], children) + + def test_no_descent_child_removal(self): + visitor = self.TestVisitor() + children = [('name1', 'thing1'), ('mock', 'thing2')] + public_api.PublicAPIVisitor(visitor)('test', 'dummy', children) + # Make sure not-to-be-descended-into symbols are removed after the visitor + # is called. + self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')], + visitor.last_children) + self.assertEqual([('name1', 'thing1')], children) + + +if __name__ == '__main__': + googletest.main()