From 8ac4834beeb7e186d0a1c3794fdc178fa3553d3b Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Tue, 27 Dec 2016 15:26:51 -0800 Subject: [PATCH] Make a traversal tool to visit everything in a given Python module/class. Change: 143061298 --- tensorflow/BUILD | 1 + tensorflow/tools/common/BUILD | 37 ++++++++++ tensorflow/tools/common/traverse.py | 91 ++++++++++++++++++++++++ tensorflow/tools/common/traverse_test.py | 84 ++++++++++++++++++++++ 4 files changed, 213 insertions(+) create mode 100644 tensorflow/tools/common/BUILD create mode 100644 tensorflow/tools/common/traverse.py create mode 100644 tensorflow/tools/common/traverse_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index c324209805c..0c4f3e0647e 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -185,6 +185,7 @@ filegroup( "//tensorflow/tensorboard/lib:all_files", "//tensorflow/tensorboard/lib/python:all_files", "//tensorflow/tensorboard/scripts:all_files", + "//tensorflow/tools/common:all_files", "//tensorflow/tools/dist_test/server:all_files", "//tensorflow/tools/docker:all_files", "//tensorflow/tools/docker/notebooks:all_files", diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD new file mode 100644 index 00000000000..f1d43134b85 --- /dev/null +++ b/tensorflow/tools/common/BUILD @@ -0,0 +1,37 @@ +# Description: +# Common functionality for TensorFlow tooling + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package( + default_visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "traverse", + srcs = ["traverse.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "traverse_test", + srcs = ["traverse_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":traverse", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/tools/common/traverse.py b/tensorflow/tools/common/traverse.py new file mode 100644 index 00000000000..443838d9682 --- /dev/null +++ b/tensorflow/tools/common/traverse.py @@ -0,0 +1,91 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Traversing Python modules and classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import sys + + +__all__ = ['traverse'] + + +def _traverse_internal(root, visit, stack, path): + """Internal helper for traverse.""" + + # Only traverse modules and classes + if not inspect.isclass(root) and not inspect.ismodule(root): + return + + try: + children = inspect.getmembers(root) + except ImportError: + # On some Python installations, some modules do not support enumerating + # members (six in particular), leading to import errors. + children = [] + + new_stack = stack + [root] + visit(path, root, children) + for name, child in children: + # Do not descend into built-in modules + if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names: + continue + + # Break cycles + if any(child is item for item in new_stack): # `in`, but using `is` + continue + + child_path = path + '.' + name if path else name + _traverse_internal(child, visit, new_stack, child_path) + + +def traverse(root, visit): + """Recursively enumerate all members of `root`. + + Similar to the Python library function `os.path.walk`. + + Traverses the tree of Python objects starting with `root`, depth first. + Parent-child relationships in the tree are defined by membership in modules or + classes. The function `visit` is called with arguments + `(path, parent, children)` for each module or class `parent` found in the tree + of python objects starting with `root`. `path` is a string containing the name + with which `parent` is reachable from the current context. For example, if + `root` is a local class called `X` which contains a class `Y`, `visit` will be + called with `('Y', X.Y, children)`). + + If `root` is not a module or class, `visit` is never called. `traverse` + never descends into built-in modules. + + `children`, a list of `(name, object)` pairs are determined by + `inspect.getmembers`. To avoid visiting parts of the tree, `children` can be + modified in place, using `del` or slice assignment. + + Cycles (determined by reference equality, `is`) stop the traversal. A stack of + objects is kept to find cycles. Objects forming cycles may appear in + `children`, but `visit` will not be called with any object as `parent` which + is already in the stack. + + Traversing system modules can take a long time, it is advisable to pass a + `visit` callable which blacklists such modules. + + Args: + root: A python object with which to start the traversal. + visit: A function taking arguments `(path, parent, children)`. Will be + called for each object found in the traversal. + """ + _traverse_internal(root, visit, [], '') diff --git a/tensorflow/tools/common/traverse_test.py b/tensorflow/tools/common/traverse_test.py new file mode 100644 index 00000000000..eb195ec18ef --- /dev/null +++ b/tensorflow/tools/common/traverse_test.py @@ -0,0 +1,84 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Python module traversal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.python.platform import googletest +from tensorflow.tools.common import traverse + + +class TestVisitor(object): + + def __init__(self): + self.call_log = [] + + def __call__(self, path, parent, children): + # Do not traverse googletest, it's very deep. + for item in list(children): + if item[1] is googletest: + children.remove(item) + self.call_log += [(path, parent, children)] + + +class TraverseTest(googletest.TestCase): + + def test_cycle(self): + + class Cyclist(object): + pass + Cyclist.cycle = Cyclist + + visitor = TestVisitor() + traverse.traverse(Cyclist, visitor) + # We simply want to make sure we terminate. + + def test_module(self): + visitor = TestVisitor() + traverse.traverse(sys.modules[__name__], visitor) + + called = [parent for _, parent, _ in visitor.call_log] + + self.assertIn(TestVisitor, called) + self.assertIn(TraverseTest, called) + self.assertIn(traverse, called) + + def test_class(self): + visitor = TestVisitor() + traverse.traverse(TestVisitor, visitor) + self.assertEqual(TestVisitor, + visitor.call_log[0][1]) + # There are a bunch of other members, but make sure that the ones we know + # about are there. + self.assertIn('__init__', [name for name, _ in visitor.call_log[0][2]]) + self.assertIn('__call__', [name for name, _ in visitor.call_log[0][2]]) + + # There are more classes descended into, at least __class__ and + # __class__.__base__, neither of which are interesting to us, and which may + # change as part of Python version etc., so we don't test for them. + + def test_non_class(self): + integer = 5 + visitor = TestVisitor() + traverse.traverse(integer, visitor) + self.assertEqual([], visitor.call_log) + + +if __name__ == '__main__': + googletest.main()