Change traverse_test.test_module to traverse a constructed dummy module rather than testcase itself.
PiperOrigin-RevId: 197010681
This commit is contained in:
parent
f4162b7eaf
commit
151e35680b
@ -40,7 +40,24 @@ py_test(
|
||||
srcs = ["traverse_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":test_module1",
|
||||
":test_module2",
|
||||
":traverse",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_module1",
|
||||
srcs = ["test_module1.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":test_module2",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_module2",
|
||||
srcs = ["test_module2.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
31
tensorflow/tools/common/test_module1.py
Normal file
31
tensorflow/tools/common/test_module1.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A module target for TraverseTest.test_module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.tools.common import test_module2
|
||||
|
||||
|
||||
class ModuleClass1(object):
|
||||
|
||||
def __init__(self):
|
||||
self._m2 = test_module2.ModuleClass2()
|
||||
|
||||
def __model_class1_method__(self):
|
||||
pass
|
||||
|
29
tensorflow/tools/common/test_module2.py
Normal file
29
tensorflow/tools/common/test_module2.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A module target for TraverseTest.test_module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class ModuleClass2(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __model_class1_method__(self):
|
||||
pass
|
||||
|
@ -18,9 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.tools.common import test_module1
|
||||
from tensorflow.tools.common import test_module2
|
||||
from tensorflow.tools.common import traverse
|
||||
|
||||
|
||||
@ -30,10 +30,6 @@ class TestVisitor(object):
|
||||
self.call_log = []
|
||||
|
||||
def __call__(self, path, parent, children):
|
||||
# Do not traverse googletest, it's very deep.
|
||||
for item in list(children):
|
||||
if item[1] is googletest:
|
||||
children.remove(item)
|
||||
self.call_log += [(path, parent, children)]
|
||||
|
||||
|
||||
@ -51,13 +47,12 @@ class TraverseTest(googletest.TestCase):
|
||||
|
||||
def test_module(self):
|
||||
visitor = TestVisitor()
|
||||
traverse.traverse(sys.modules[__name__], visitor)
|
||||
traverse.traverse(test_module1, visitor)
|
||||
|
||||
called = [parent for _, parent, _ in visitor.call_log]
|
||||
|
||||
self.assertIn(TestVisitor, called)
|
||||
self.assertIn(TraverseTest, called)
|
||||
self.assertIn(traverse, called)
|
||||
self.assertIn(test_module1.ModuleClass1, called)
|
||||
self.assertIn(test_module2.ModuleClass2, called)
|
||||
|
||||
def test_class(self):
|
||||
visitor = TestVisitor()
|
||||
|
Loading…
x
Reference in New Issue
Block a user