diff --git a/tensorflow/python/util/module_wrapper.py b/tensorflow/python/util/module_wrapper.py index dffdd513b4b..c5856eeb13d 100644 --- a/tensorflow/python/util/module_wrapper.py +++ b/tensorflow/python/util/module_wrapper.py @@ -236,4 +236,4 @@ class TFModuleWrapper(types.ModuleType): return self._tfmw_wrapped_module.__repr__() def __reduce__(self): - return __import__, (self.__name__,) + return importlib.import_module, (self.__name__,) diff --git a/tensorflow/python/util/module_wrapper_test.py b/tensorflow/python/util/module_wrapper_test.py index 582e98abdfa..f8a2161bdff 100644 --- a/tensorflow/python/util/module_wrapper_test.py +++ b/tensorflow/python/util/module_wrapper_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import logging +import pickle import types from tensorflow.python.platform import test @@ -132,5 +133,15 @@ class LazyLoadingWrapperTest(test.TestCase): self.assertEqual(wrapped_module.lite, _cmd) +class PickleTest(test.TestCase): + + def testPickleSubmodule(self): + name = PickleTest.__module__ # The current module is a submodule. + module = module_wrapper.TFModuleWrapper(MockModule(name), name) + restored = pickle.loads(pickle.dumps(module)) + self.assertEqual(restored.__name__, name) + self.assertIsNotNone(restored.PickleTest) + + if __name__ == '__main__': test.main()