Change traverse_test.test_module to traverse a constructed dummy module rather than testcase itself.

PiperOrigin-RevId: 197010681
This commit is contained in:
A. Unique TensorFlower 2018-05-17 10:15:45 -07:00 committed by TensorFlower Gardener
parent f4162b7eaf
commit 151e35680b
4 changed files with 82 additions and 10 deletions

View File

@ -40,7 +40,24 @@ py_test(
srcs = ["traverse_test.py"],
srcs_version = "PY2AND3",
deps = [
":test_module1",
":test_module2",
":traverse",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "test_module1",
srcs = ["test_module1.py"],
srcs_version = "PY2AND3",
deps = [
":test_module2",
],
)
py_library(
name = "test_module2",
srcs = ["test_module2.py"],
srcs_version = "PY2AND3",
)

View File

@ -0,0 +1,31 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A module target for TraverseTest.test_module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.tools.common import test_module2
class ModuleClass1(object):
def __init__(self):
self._m2 = test_module2.ModuleClass2()
def __model_class1_method__(self):
pass

View File

@ -0,0 +1,29 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A module target for TraverseTest.test_module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class ModuleClass2(object):
def __init__(self):
pass
def __model_class1_method__(self):
pass

View File

@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from tensorflow.python.platform import googletest
from tensorflow.tools.common import test_module1
from tensorflow.tools.common import test_module2
from tensorflow.tools.common import traverse
@ -30,10 +30,6 @@ class TestVisitor(object):
self.call_log = []
def __call__(self, path, parent, children):
# Do not traverse googletest, it's very deep.
for item in list(children):
if item[1] is googletest:
children.remove(item)
self.call_log += [(path, parent, children)]
@ -51,13 +47,12 @@ class TraverseTest(googletest.TestCase):
def test_module(self):
visitor = TestVisitor()
traverse.traverse(sys.modules[__name__], visitor)
traverse.traverse(test_module1, visitor)
called = [parent for _, parent, _ in visitor.call_log]
self.assertIn(TestVisitor, called)
self.assertIn(TraverseTest, called)
self.assertIn(traverse, called)
self.assertIn(test_module1.ModuleClass1, called)
self.assertIn(test_module2.ModuleClass2, called)
def test_class(self):
visitor = TestVisitor()