From f17712cfcbd14c19a1f7cb1dbb9be41f2dbba184 Mon Sep 17 00:00:00 2001
From: Tom Hennigan <tomhennigan@google.com>
Date: Tue, 17 Mar 2020 04:07:21 -0700
Subject: [PATCH] Use importlib.import_module instead of __import__ when
 unpickling.

`__import__` only returns the top level package, so
`__import__('tensorflow.compat.v1')` will return `tensorflow` not
`tensorflow.compat.v1`.

PiperOrigin-RevId: 301345019
Change-Id: If2b3b1e34717bc387e2a79a3a63960da79ca3d64
---
 tensorflow/python/util/module_wrapper.py      |  2 +-
 tensorflow/python/util/module_wrapper_test.py | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/util/module_wrapper.py b/tensorflow/python/util/module_wrapper.py
index dffdd513b4b..c5856eeb13d 100644
--- a/tensorflow/python/util/module_wrapper.py
+++ b/tensorflow/python/util/module_wrapper.py
@@ -236,4 +236,4 @@ class TFModuleWrapper(types.ModuleType):
     return self._tfmw_wrapped_module.__repr__()
 
   def __reduce__(self):
-    return __import__, (self.__name__,)
+    return importlib.import_module, (self.__name__,)
diff --git a/tensorflow/python/util/module_wrapper_test.py b/tensorflow/python/util/module_wrapper_test.py
index 582e98abdfa..f8a2161bdff 100644
--- a/tensorflow/python/util/module_wrapper_test.py
+++ b/tensorflow/python/util/module_wrapper_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import logging
+import pickle
 import types
 
 from tensorflow.python.platform import test
@@ -132,5 +133,15 @@ class LazyLoadingWrapperTest(test.TestCase):
     self.assertEqual(wrapped_module.lite, _cmd)
 
 
+class PickleTest(test.TestCase):
+
+  def testPickleSubmodule(self):
+    name = PickleTest.__module__  # The current module is a submodule.
+    module = module_wrapper.TFModuleWrapper(MockModule(name), name)
+    restored = pickle.loads(pickle.dumps(module))
+    self.assertEqual(restored.__name__, name)
+    self.assertIsNotNone(restored.PickleTest)
+
+
 if __name__ == '__main__':
   test.main()