Fix a bad comment.

Change: 143889993
This commit is contained in:
Martin Wicke 2017-01-08 01:05:21 -08:00 committed by TensorFlower Gardener
parent 367078ab8c
commit 2691c1260f
9 changed files with 193 additions and 26 deletions

View File

@ -482,8 +482,8 @@ static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
// Prepare a NamedTuple type to hold the outputs, if there are multiple
if (num_outs > 1) {
const string tuple_type_prefix =
strings::StrCat("_", op_def.name(), "Output = collections.namedtuple(");
const string tuple_type_prefix = strings::StrCat(
"_", op_def.name(), "Output = _collections.namedtuple(");
const string tuple_type_suffix = strings::StrCat(
"\"", op_def.name(), "\", ", lower_op_name_outputs, ")");
strings::Appendf(
@ -656,18 +656,18 @@ string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
This file is MACHINE GENERATED! Do not edit.
"""
import collections
import collections as _collections
from google.protobuf import text_format
from google.protobuf import text_format as _text_format
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import op_def_library
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
)");
// We'll make a copy of ops that filters out descriptions.
@ -699,7 +699,7 @@ from tensorflow.python.framework import op_def_library
GetPythonOp(op_def, is_hidden, lower_case_name));
if (!require_shapes) {
strings::Appendf(&result, "ops.RegisterShape(\"%s\")(None)\n",
strings::Appendf(&result, "_ops.RegisterShape(\"%s\")(None)\n",
op_def.name().c_str());
}
@ -709,10 +709,10 @@ from tensorflow.python.framework import op_def_library
}
strings::Appendf(&result, R"(def _InitOpDefLibrary():
op_list = op_def_pb2.OpList()
text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
op_def_registry.register_op_list(op_list)
op_def_lib = op_def_library.OpDefLibrary()
op_list = _op_def_pb2.OpList()
_text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
_op_def_registry.register_op_list(op_list)
op_def_lib = _op_def_library.OpDefLibrary()
op_def_lib.add_op_list(op_list)
return op_def_lib

View File

@ -105,6 +105,7 @@ import sys
import numpy as np
import six
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

View File

@ -27,16 +27,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_sdca_ops import *
# pylint: enable=wildcard-import
# pylint: disable=anomalous-backslash-in-string,protected-access
from tensorflow.python.util.all_util import remove_undocumented
ops.NotDifferentiable("SdcaFprint")
ops.NotDifferentiable("SdcaOptimizer")
ops.NotDifferentiable("SdcaShrinkL1")
remove_undocumented(__name__)

View File

@ -64,18 +64,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python.client import device_lib as _device_lib
from tensorflow.python.framework import test_util as _test_util
from tensorflow.python.platform import googletest as _googletest
from tensorflow.python.util.all_util import remove_undocumented
# pylint: disable=unused-import
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.framework.test_util import assert_equal_graph_def
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.ops.gradient_checker import compute_gradient_error
from tensorflow.python.ops.gradient_checker import compute_gradient
# pylint: enable=unused-import
# pylint: enable=unused-import,g-bad-import-order
import sys
if sys.version_info.major == 2:

View File

@ -19,10 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import gen_user_ops
from tensorflow.python.ops.gen_user_ops import *
from tensorflow.python.ops import gen_user_ops as _gen_user_ops
# go/tf-wildcard-import
from tensorflow.python.ops.gen_user_ops import * # pylint: disable=wildcard-import
def my_fact():
"""Example of overriding the generated code for an Op."""
return gen_user_ops._fact()
return _gen_user_ops._fact() # pylint: disable=protected-access

View File

@ -83,8 +83,7 @@ def reveal_undocumented(symbol_name, target_module=None):
def remove_undocumented(module_name, allowed_exception_list=None,
doc_string_modules=None):
"""Removes symbols in a module that are not referenced by a docstring that
contributes to documentation.
"""Removes symbols in a module that are not referenced by a docstring.
Args:
module_name: the name of the module (usually `__name__`).
@ -100,7 +99,7 @@ def remove_undocumented(module_name, allowed_exception_list=None,
"""
current_symbols = set(dir(_sys.modules[module_name]))
should_have = make_all(module_name, doc_string_modules)
should_have += allowed_exception_list
should_have += allowed_exception_list or []
extra_symbols = current_symbols - set(should_have)
target_module = _sys.modules[module_name]
for extra_symbol in extra_symbols:

View File

@ -9,6 +9,22 @@ package(
default_visibility = ["//tensorflow:__subpackages__"],
)
py_library(
name = "public_api",
srcs = ["public_api.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "public_api_test",
srcs = ["public_api_test.py"],
srcs_version = "PY2AND3",
deps = [
":public_api",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "traverse",
srcs = ["traverse.py"],

View File

@ -0,0 +1,78 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visitor restricting traversal to only the public tensorflow API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
class PublicAPIVisitor(object):
"""Visitor to use with `traverse` to visit exactly the public TF API."""
def __init__(self, visitor):
"""Constructor.
`visitor` should be a callable suitable as a visitor for `traverse`. It will
be called only for members of the public TensorFlow API.
Args:
visitor: A visitor to call for the public API.
"""
self._visitor = visitor
# Modules/classes we do not want to descend into if we hit them. Usually,
# sytem modules exposed through platforms for compatibility reasons.
# Each entry maps a module path to a name to ignore in traversal.
_do_not_descend_map = {
# TODO(drpng): This can be removed once sealed off.
'': ['platform', 'pywrap_tensorflow'],
# Some implementations have this internal module that we shouldn't expose.
'flags': ['cpp_flags'],
# Everything below here is legitimate.
'app': 'flags', # It'll stay, but it's not officially part of the API
'test': ['mock'], # Imported for compatibility between py2/3.
}
def _isprivate(self, name):
"""Return whether a name is private."""
return name.startswith('_')
def _do_not_descend(self, path, name):
"""Safely queries if a specific fully qualified name should be excluded."""
return (path in self._do_not_descend_map and
name in self._do_not_descend_map[path])
def __call__(self, path, parent, children):
"""Visitor interface, see `traverse` for details."""
if inspect.ismodule(parent) and len(path.split('.')) > 10:
raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a '
'problem with an accidental public import.' % path)
# Remove things that are not visible.
for name, child in list(children):
if self._isprivate(name):
children.remove((name, child))
self._visitor(path, parent, children)
# Remove things that are visible, but which should not be descended into.
for name, child in list(children):
if self._do_not_descend(path, name):
children.remove((name, child))

View File

@ -0,0 +1,68 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.common.public_api."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import googletest
from tensorflow.tools.common import public_api
class PublicApiTest(googletest.TestCase):
class TestVisitor(object):
def __init__(self):
self.symbols = set()
self.last_parent = None
self.last_children = None
def __call__(self, path, parent, children):
self.symbols.add(path)
self.last_parent = parent
self.last_children = list(children) # Make a copy to preserve state.
def test_call_forward(self):
visitor = self.TestVisitor()
children = [('name1', 'thing1'), ('name2', 'thing2')]
public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
self.assertEqual(set(['test']), visitor.symbols)
self.assertEqual('dummy', visitor.last_parent)
self.assertEqual([('name1', 'thing1'), ('name2', 'thing2')],
visitor.last_children)
def test_private_child_removal(self):
visitor = self.TestVisitor()
children = [('name1', 'thing1'), ('_name2', 'thing2')]
public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
# Make sure the private symbols are removed before the visitor is called.
self.assertEqual([('name1', 'thing1')], visitor.last_children)
self.assertEqual([('name1', 'thing1')], children)
def test_no_descent_child_removal(self):
visitor = self.TestVisitor()
children = [('name1', 'thing1'), ('mock', 'thing2')]
public_api.PublicAPIVisitor(visitor)('test', 'dummy', children)
# Make sure not-to-be-descended-into symbols are removed after the visitor
# is called.
self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')],
visitor.last_children)
self.assertEqual([('name1', 'thing1')], children)
if __name__ == '__main__':
googletest.main()