Fix a bad comment.
Change: 143889993
This commit is contained in:
parent
367078ab8c
commit
2691c1260f
tensorflow
python
framework
ops
platform
user_ops
util
tools/common
@ -482,8 +482,8 @@ static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
|
||||
|
||||
// Prepare a NamedTuple type to hold the outputs, if there are multiple
|
||||
if (num_outs > 1) {
|
||||
const string tuple_type_prefix =
|
||||
strings::StrCat("_", op_def.name(), "Output = collections.namedtuple(");
|
||||
const string tuple_type_prefix = strings::StrCat(
|
||||
"_", op_def.name(), "Output = _collections.namedtuple(");
|
||||
const string tuple_type_suffix = strings::StrCat(
|
||||
"\"", op_def.name(), "\", ", lower_op_name_outputs, ")");
|
||||
strings::Appendf(
|
||||
@ -656,18 +656,18 @@ string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
|
||||
This file is MACHINE GENERATED! Do not edit.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import collections as _collections
|
||||
|
||||
from google.protobuf import text_format
|
||||
from google.protobuf import text_format as _text_format
|
||||
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
|
||||
|
||||
# Needed to trigger the call to _set_call_cpp_shape_fn.
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import common_shapes as _common_shapes
|
||||
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import op_def_library
|
||||
from tensorflow.python.framework import op_def_registry as _op_def_registry
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import op_def_library as _op_def_library
|
||||
)");
|
||||
|
||||
// We'll make a copy of ops that filters out descriptions.
|
||||
@ -699,7 +699,7 @@ from tensorflow.python.framework import op_def_library
|
||||
GetPythonOp(op_def, is_hidden, lower_case_name));
|
||||
|
||||
if (!require_shapes) {
|
||||
strings::Appendf(&result, "ops.RegisterShape(\"%s\")(None)\n",
|
||||
strings::Appendf(&result, "_ops.RegisterShape(\"%s\")(None)\n",
|
||||
op_def.name().c_str());
|
||||
}
|
||||
|
||||
@ -709,10 +709,10 @@ from tensorflow.python.framework import op_def_library
|
||||
}
|
||||
|
||||
strings::Appendf(&result, R"(def _InitOpDefLibrary():
|
||||
op_list = op_def_pb2.OpList()
|
||||
text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
|
||||
op_def_registry.register_op_list(op_list)
|
||||
op_def_lib = op_def_library.OpDefLibrary()
|
||||
op_list = _op_def_pb2.OpList()
|
||||
_text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
|
||||
_op_def_registry.register_op_list(op_list)
|
||||
op_def_lib = _op_def_library.OpDefLibrary()
|
||||
op_def_lib.add_op_list(op_list)
|
||||
return op_def_lib
|
||||
|
||||
|
@ -105,6 +105,7 @@ import sys
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
@ -27,16 +27,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_sdca_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
# pylint: disable=anomalous-backslash-in-string,protected-access
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
ops.NotDifferentiable("SdcaFprint")
|
||||
ops.NotDifferentiable("SdcaOptimizer")
|
||||
ops.NotDifferentiable("SdcaShrinkL1")
|
||||
|
||||
|
||||
remove_undocumented(__name__)
|
||||
|
@ -64,18 +64,19 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python.client import device_lib as _device_lib
|
||||
from tensorflow.python.framework import test_util as _test_util
|
||||
from tensorflow.python.platform import googletest as _googletest
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
|
||||
from tensorflow.python.framework.test_util import assert_equal_graph_def
|
||||
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
|
||||
|
||||
from tensorflow.python.ops.gradient_checker import compute_gradient_error
|
||||
from tensorflow.python.ops.gradient_checker import compute_gradient
|
||||
# pylint: enable=unused-import
|
||||
# pylint: enable=unused-import,g-bad-import-order
|
||||
|
||||
import sys
|
||||
if sys.version_info.major == 2:
|
||||
|
@ -19,10 +19,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops import gen_user_ops
|
||||
from tensorflow.python.ops.gen_user_ops import *
|
||||
from tensorflow.python.ops import gen_user_ops as _gen_user_ops
|
||||
|
||||
# go/tf-wildcard-import
|
||||
from tensorflow.python.ops.gen_user_ops import * # pylint: disable=wildcard-import
|
||||
|
||||
|
||||
def my_fact():
|
||||
"""Example of overriding the generated code for an Op."""
|
||||
return gen_user_ops._fact()
|
||||
return _gen_user_ops._fact() # pylint: disable=protected-access
|
||||
|
@ -83,8 +83,7 @@ def reveal_undocumented(symbol_name, target_module=None):
|
||||
|
||||
def remove_undocumented(module_name, allowed_exception_list=None,
|
||||
doc_string_modules=None):
|
||||
"""Removes symbols in a module that are not referenced by a docstring that
|
||||
contributes to documentation.
|
||||
"""Removes symbols in a module that are not referenced by a docstring.
|
||||
|
||||
Args:
|
||||
module_name: the name of the module (usually `__name__`).
|
||||
@ -100,7 +99,7 @@ def remove_undocumented(module_name, allowed_exception_list=None,
|
||||
"""
|
||||
current_symbols = set(dir(_sys.modules[module_name]))
|
||||
should_have = make_all(module_name, doc_string_modules)
|
||||
should_have += allowed_exception_list
|
||||
should_have += allowed_exception_list or []
|
||||
extra_symbols = current_symbols - set(should_have)
|
||||
target_module = _sys.modules[module_name]
|
||||
for extra_symbol in extra_symbols:
|
||||
|
@ -9,6 +9,22 @@ package(
|
||||
default_visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "public_api",
|
||||
srcs = ["public_api.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "public_api_test",
|
||||
srcs = ["public_api_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":public_api",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "traverse",
|
||||
srcs = ["traverse.py"],
|
||||
|
78
tensorflow/tools/common/public_api.py
Normal file
78
tensorflow/tools/common/public_api.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Visitor restricting traversal to only the public tensorflow API."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
class PublicAPIVisitor(object):
|
||||
"""Visitor to use with `traverse` to visit exactly the public TF API."""
|
||||
|
||||
def __init__(self, visitor):
|
||||
"""Constructor.
|
||||
|
||||
`visitor` should be a callable suitable as a visitor for `traverse`. It will
|
||||
be called only for members of the public TensorFlow API.
|
||||
|
||||
Args:
|
||||
visitor: A visitor to call for the public API.
|
||||
"""
|
||||
self._visitor = visitor
|
||||
|
||||
# Modules/classes we do not want to descend into if we hit them. Usually,
|
||||
# sytem modules exposed through platforms for compatibility reasons.
|
||||
# Each entry maps a module path to a name to ignore in traversal.
|
||||
_do_not_descend_map = {
|
||||
# TODO(drpng): This can be removed once sealed off.
|
||||
'': ['platform', 'pywrap_tensorflow'],
|
||||
|
||||
# Some implementations have this internal module that we shouldn't expose.
|
||||
'flags': ['cpp_flags'],
|
||||
|
||||
# Everything below here is legitimate.
|
||||
'app': 'flags', # It'll stay, but it's not officially part of the API
|
||||
'test': ['mock'], # Imported for compatibility between py2/3.
|
||||
}
|
||||
|
||||
def _isprivate(self, name):
|
||||
"""Return whether a name is private."""
|
||||
return name.startswith('_')
|
||||
|
||||
def _do_not_descend(self, path, name):
|
||||
"""Safely queries if a specific fully qualified name should be excluded."""
|
||||
return (path in self._do_not_descend_map and
|
||||
name in self._do_not_descend_map[path])
|
||||
|
||||
def __call__(self, path, parent, children):
|
||||
"""Visitor interface, see `traverse` for details."""
|
||||
if inspect.ismodule(parent) and len(path.split('.')) > 10:
|
||||
raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a '
|
||||
'problem with an accidental public import.' % path)
|
||||
|
||||
# Remove things that are not visible.
|
||||
for name, child in list(children):
|
||||
if self._isprivate(name):
|
||||
children.remove((name, child))
|
||||
|
||||
self._visitor(path, parent, children)
|
||||
|
||||
# Remove things that are visible, but which should not be descended into.
|
||||
for name, child in list(children):
|
||||
if self._do_not_descend(path, name):
|
||||
children.remove((name, child))
|
68
tensorflow/tools/common/public_api_test.py
Normal file
68
tensorflow/tools/common/public_api_test.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.tools.common.public_api."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.tools.common import public_api
|
||||
|
||||
|
||||
class PublicApiTest(googletest.TestCase):
|
||||
|
||||
class TestVisitor(object):
|
||||
|
||||
def __init__(self):
|
||||
self.symbols = set()
|
||||
self.last_parent = None
|
||||
self.last_children = None
|
||||
|
||||
def __call__(self, path, parent, children):
|
||||
self.symbols.add(path)
|
||||
self.last_parent = parent
|
||||
self.last_children = list(children) # Make a copy to preserve state.
|
||||
|
||||
def test_call_forward(self):
|
||||
visitor = self.TestVisitor()
|
||||
children = [('name1', 'thing1'), ('name2', 'thing2')]
|
||||
public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
|
||||
self.assertEqual(set(['test']), visitor.symbols)
|
||||
self.assertEqual('dummy', visitor.last_parent)
|
||||
self.assertEqual([('name1', 'thing1'), ('name2', 'thing2')],
|
||||
visitor.last_children)
|
||||
|
||||
def test_private_child_removal(self):
|
||||
visitor = self.TestVisitor()
|
||||
children = [('name1', 'thing1'), ('_name2', 'thing2')]
|
||||
public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
|
||||
# Make sure the private symbols are removed before the visitor is called.
|
||||
self.assertEqual([('name1', 'thing1')], visitor.last_children)
|
||||
self.assertEqual([('name1', 'thing1')], children)
|
||||
|
||||
def test_no_descent_child_removal(self):
|
||||
visitor = self.TestVisitor()
|
||||
children = [('name1', 'thing1'), ('mock', 'thing2')]
|
||||
public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
|
||||
# Make sure not-to-be-descended-into symbols are removed after the visitor
|
||||
# is called.
|
||||
self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')],
|
||||
visitor.last_children)
|
||||
self.assertEqual([('name1', 'thing1')], children)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
Loading…
Reference in New Issue
Block a user