Use importlib.import_module instead of __import__ when unpickling.

`__import__` only returns the top level package, so
`__import__('tensorflow.compat.v1')` will return `tensorflow` not
`tensorflow.compat.v1`.

PiperOrigin-RevId: 301345019
Change-Id: If2b3b1e34717bc387e2a79a3a63960da79ca3d64
This commit is contained in:
Tom Hennigan 2020-03-17 04:07:21 -07:00 committed by TensorFlower Gardener
parent 5815a8d2cd
commit f17712cfcb
2 changed files with 12 additions and 1 deletions

View File

@ -236,4 +236,4 @@ class TFModuleWrapper(types.ModuleType):
return self._tfmw_wrapped_module.__repr__()
def __reduce__(self):
return __import__, (self.__name__,)
return importlib.import_module, (self.__name__,)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import logging
import pickle
import types
from tensorflow.python.platform import test
@ -132,5 +133,15 @@ class LazyLoadingWrapperTest(test.TestCase):
self.assertEqual(wrapped_module.lite, _cmd)
class PickleTest(test.TestCase):
def testPickleSubmodule(self):
name = PickleTest.__module__ # The current module is a submodule.
module = module_wrapper.TFModuleWrapper(MockModule(name), name)
restored = pickle.loads(pickle.dumps(module))
self.assertEqual(restored.__name__, name)
self.assertIsNotNone(restored.PickleTest)
if __name__ == '__main__':
test.main()