STT-tensorflow/tensorflow/python/util/module_wrapper_test.py
Tom Hennigan f17712cfcb Use importlib.import_module instead of __import__ when unpickling.
`__import__` only returns the top level package, so
`__import__('tensorflow.compat.v1')` will return `tensorflow` not
`tensorflow.compat.v1`.

PiperOrigin-RevId: 301345019
Change-Id: If2b3b1e34717bc387e2a79a3a63960da79ca3d64
2020-03-17 04:10:37 -07:00

148 lines
5.6 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.util.module_wrapper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import pickle
import types
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import module_wrapper
from tensorflow.python.util import tf_inspect
from tensorflow.tools.compatibility import all_renames_v2
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
class MockModule(types.ModuleType):
pass
class DeprecationWrapperTest(test.TestCase):
def testWrapperIsAModule(self):
module = MockModule('test')
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
self.assertTrue(tf_inspect.ismodule(wrapped_module))
@test.mock.patch.object(logging, 'warning', autospec=True)
def testDeprecationWarnings(self, mock_warning):
module = MockModule('test')
module.foo = 1
module.bar = 2
module.baz = 3
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
self.assertTrue(tf_inspect.ismodule(wrapped_module))
self.assertEqual(0, mock_warning.call_count)
bar = wrapped_module.bar
self.assertEqual(1, mock_warning.call_count)
foo = wrapped_module.foo
self.assertEqual(1, mock_warning.call_count)
baz = wrapped_module.baz # pylint: disable=unused-variable
self.assertEqual(2, mock_warning.call_count)
baz = wrapped_module.baz
self.assertEqual(2, mock_warning.call_count)
# Check that values stayed the same
self.assertEqual(module.foo, foo)
self.assertEqual(module.bar, bar)
class LazyLoadingWrapperTest(test.TestCase):
def testLazyLoad(self):
module = MockModule('test')
apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member
self.assertEqual(wrapped_module.cmd, _cmd)
self.assertEqual(wrapped_module.ABCMeta, _ABCMeta)
def testLazyLoadLocalOverride(self):
# Test that we can override and add fields to the wrapped module.
module = MockModule('test')
apis = {'cmd': ('', 'cmd')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
self.assertEqual(wrapped_module.cmd, _cmd)
setattr(wrapped_module, 'cmd', 1)
setattr(wrapped_module, 'cgi', 2)
self.assertEqual(wrapped_module.cmd, 1) # override
self.assertEqual(wrapped_module.cgi, 2) # add
def testLazyLoadDict(self):
# Test that we can override and add fields to the wrapped module.
module = MockModule('test')
apis = {'cmd': ('', 'cmd')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
# At first cmd key does not exist in __dict__
self.assertNotIn('cmd', wrapped_module.__dict__)
# After it is referred (lazyloaded), it gets added to __dict__
wrapped_module.cmd # pylint: disable=pointless-statement
self.assertEqual(wrapped_module.__dict__['cmd'], _cmd)
# When we call setattr, it also gets added to __dict__
setattr(wrapped_module, 'cmd2', _cmd)
self.assertEqual(wrapped_module.__dict__['cmd2'], _cmd)
def testLazyLoadWildcardImport(self):
# Test that public APIs are in __all__.
module = MockModule('test')
module._should_not_be_public = 5
apis = {'cmd': ('', 'cmd')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
setattr(wrapped_module, 'hello', 1)
self.assertIn('hello', wrapped_module.__all__)
self.assertIn('cmd', wrapped_module.__all__)
self.assertNotIn('_should_not_be_public', wrapped_module.__all__)
def testLazyLoadCorrectLiteModule(self):
# If set, always load lite module from public API list.
module = MockModule('test')
apis = {'lite': ('', 'cmd')}
module.lite = 5
import cmd as _cmd # pylint: disable=g-import-not-at-top
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False, has_lite=True)
self.assertEqual(wrapped_module.lite, _cmd)
class PickleTest(test.TestCase):
def testPickleSubmodule(self):
name = PickleTest.__module__ # The current module is a submodule.
module = module_wrapper.TFModuleWrapper(MockModule(name), name)
restored = pickle.loads(pickle.dumps(module))
self.assertEqual(restored.__name__, name)
self.assertIsNotNone(restored.PickleTest)
if __name__ == '__main__':
test.main()