From 0f01ee7a34d8e17ec43181eab7b46026bf9bae3c Mon Sep 17 00:00:00 2001
From: Amit Patankar <amitpatankar@google.com>
Date: Thu, 9 Apr 2020 10:34:32 -0700
Subject: [PATCH] Re-add dynamic loading flags and logic in Python.

PiperOrigin-RevId: 305712320
Change-Id: Iaa25ab0397e91ffa83cba805e81a6c9e315549d8
---
 tensorflow/python/BUILD                |  6 +-
 tensorflow/python/pywrap_tensorflow.py | 78 +++++++++++++++++++++-----
 2 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 90ae5655399..f3449f80986 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5828,7 +5828,11 @@ py_library(
     name = "pywrap_tensorflow",
     srcs = [
         "pywrap_tensorflow.py",
-    ],
+    ] + if_static(
+        ["pywrap_dlopen_global_flags.py"],
+        # Import will fail, indicating no global dlopen flags
+        otherwise = [],
+    ),  # b/153585257
     srcs_version = "PY2AND3",
     deps = [":pywrap_tensorflow_internal"],
 )
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
index 8981a3c6302..863e2106f08 100644
--- a/tensorflow/python/pywrap_tensorflow.py
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -18,20 +18,68 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# Python2.7 does not have a ModuleNotFoundError.
-try:
-  ModuleNotFoundError
-except NameError:
-  ModuleNotFoundError = ImportError
+import ctypes
+import sys
+import traceback
 
-# pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable
-try:
-  from tensorflow.python._pywrap_tensorflow_internal import *
-# This try catch logic is because there is no bazel equivalent for py_extension.
-# Externally in opensource we must enable exceptions to load the shared object
-# by exposing the PyInit symbols with pybind. This error will only be
-# caught internally or if someone changes the name of the target _pywrap_tensorflow_internal.
+from tensorflow.python.platform import self_check
 
-# This logic is used in other internal projects using py_extension.
-except ModuleNotFoundError:
-  pass
+# Perform pre-load sanity checks in order to produce a more actionable error.
+self_check.preload_check()
+
+# pylint: disable=wildcard-import,g-import-not-at-top,unused-import,line-too-long
+
+try:
+  # This import is expected to fail if there is an explicit shared object
+  # dependency (with_framework_lib=true), since we do not need RTLD_GLOBAL.
+  from tensorflow.python import pywrap_dlopen_global_flags
+  _use_dlopen_global_flags = True
+except ImportError:
+  _use_dlopen_global_flags = False
+
+# On UNIX-based platforms, pywrap_tensorflow is a python library that
+# dynamically loads _pywrap_tensorflow.so.
+_can_set_rtld_local = (
+    hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'))
+if _can_set_rtld_local:
+  _default_dlopen_flags = sys.getdlopenflags()
+
+try:
+  if _use_dlopen_global_flags:
+    pywrap_dlopen_global_flags.set_dlopen_flags()
+  elif _can_set_rtld_local:
+    # Ensure RTLD_LOCAL behavior for platforms where it isn't the default
+    # (macOS). On Linux RTLD_LOCAL is 0, so this does nothing (and would not
+    # override an RTLD_GLOBAL in _default_dlopen_flags).
+    sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL)
+
+  # Python2.7 does not have a ModuleNotFoundError.
+  try:
+    ModuleNotFoundError
+  except NameError:
+    ModuleNotFoundError = ImportError
+
+  # pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable
+  try:
+    from tensorflow.python._pywrap_tensorflow_internal import *
+  # This try catch logic is because there is no bazel equivalent for py_extension.
+  # Externally in opensource we must enable exceptions to load the shared object
+  # by exposing the PyInit symbols with pybind. This error will only be
+  # caught internally or if someone changes the name of the target _pywrap_tensorflow_internal.
+
+  # This logic is used in other internal projects using py_extension.
+  except ModuleNotFoundError:
+    pass
+
+  if _use_dlopen_global_flags:
+    pywrap_dlopen_global_flags.reset_dlopen_flags()
+  elif _can_set_rtld_local:
+    sys.setdlopenflags(_default_dlopen_flags)
+except ImportError:
+  msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
+See https://www.tensorflow.org/install/errors\n
+for some common reasons and solutions.  Include the entire stack trace
+above this error message when asking for help.""" % traceback.format_exc()
+  raise ImportError(msg)
+
+# pylint: enable=wildcard-import,g-import-not-at-top,unused-import,line-too-long