Use importlib.import_module instead of __import__ when unpickling.
`__import__` only returns the top level package, so `__import__('tensorflow.compat.v1')` will return `tensorflow` not `tensorflow.compat.v1`. PiperOrigin-RevId: 301345019 Change-Id: If2b3b1e34717bc387e2a79a3a63960da79ca3d64
This commit is contained in:
parent
5815a8d2cd
commit
f17712cfcb
@ -236,4 +236,4 @@ class TFModuleWrapper(types.ModuleType):
|
||||
return self._tfmw_wrapped_module.__repr__()
|
||||
|
||||
def __reduce__(self):
|
||||
return __import__, (self.__name__,)
|
||||
return importlib.import_module, (self.__name__,)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import types
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
@ -132,5 +133,15 @@ class LazyLoadingWrapperTest(test.TestCase):
|
||||
self.assertEqual(wrapped_module.lite, _cmd)
|
||||
|
||||
|
||||
class PickleTest(test.TestCase):
|
||||
|
||||
def testPickleSubmodule(self):
|
||||
name = PickleTest.__module__ # The current module is a submodule.
|
||||
module = module_wrapper.TFModuleWrapper(MockModule(name), name)
|
||||
restored = pickle.loads(pickle.dumps(module))
|
||||
self.assertEqual(restored.__name__, name)
|
||||
self.assertIsNotNone(restored.PickleTest)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user