diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 90ae5655399..f3449f80986 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5828,7 +5828,11 @@ py_library( name = "pywrap_tensorflow", srcs = [ "pywrap_tensorflow.py", - ], + ] + if_static( + ["pywrap_dlopen_global_flags.py"], + # Import will fail, indicating no global dlopen flags + otherwise = [], + ), # b/153585257 srcs_version = "PY2AND3", deps = [":pywrap_tensorflow_internal"], ) diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py index 8981a3c6302..863e2106f08 100644 --- a/tensorflow/python/pywrap_tensorflow.py +++ b/tensorflow/python/pywrap_tensorflow.py @@ -18,20 +18,68 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Python2.7 does not have a ModuleNotFoundError. -try: - ModuleNotFoundError -except NameError: - ModuleNotFoundError = ImportError +import ctypes +import sys +import traceback -# pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable -try: - from tensorflow.python._pywrap_tensorflow_internal import * -# This try catch logic is because there is no bazel equivalent for py_extension. -# Externally in opensource we must enable exceptions to load the shared object -# by exposing the PyInit symbols with pybind. This error will only be -# caught internally or if someone changes the name of the target _pywrap_tensorflow_internal. +from tensorflow.python.platform import self_check -# This logic is used in other internal projects using py_extension. -except ModuleNotFoundError: - pass +# Perform pre-load sanity checks in order to produce a more actionable error. +self_check.preload_check() + +# pylint: disable=wildcard-import,g-import-not-at-top,unused-import,line-too-long + +try: + # This import is expected to fail if there is an explicit shared object + # dependency (with_framework_lib=true), since we do not need RTLD_GLOBAL. + from tensorflow.python import pywrap_dlopen_global_flags + _use_dlopen_global_flags = True +except ImportError: + _use_dlopen_global_flags = False + +# On UNIX-based platforms, pywrap_tensorflow is a python library that +# dynamically loads _pywrap_tensorflow.so. +_can_set_rtld_local = ( + hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')) +if _can_set_rtld_local: + _default_dlopen_flags = sys.getdlopenflags() + +try: + if _use_dlopen_global_flags: + pywrap_dlopen_global_flags.set_dlopen_flags() + elif _can_set_rtld_local: + # Ensure RTLD_LOCAL behavior for platforms where it isn't the default + # (macOS). On Linux RTLD_LOCAL is 0, so this does nothing (and would not + # override an RTLD_GLOBAL in _default_dlopen_flags). + sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL) + + # Python2.7 does not have a ModuleNotFoundError. + try: + ModuleNotFoundError + except NameError: + ModuleNotFoundError = ImportError + + # pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable + try: + from tensorflow.python._pywrap_tensorflow_internal import * + # This try catch logic is because there is no bazel equivalent for py_extension. + # Externally in opensource we must enable exceptions to load the shared object + # by exposing the PyInit symbols with pybind. This error will only be + # caught internally or if someone changes the name of the target _pywrap_tensorflow_internal. + + # This logic is used in other internal projects using py_extension. + except ModuleNotFoundError: + pass + + if _use_dlopen_global_flags: + pywrap_dlopen_global_flags.reset_dlopen_flags() + elif _can_set_rtld_local: + sys.setdlopenflags(_default_dlopen_flags) +except ImportError: + msg = """%s\n\nFailed to load the native TensorFlow runtime.\n +See https://www.tensorflow.org/install/errors\n +for some common reasons and solutions. Include the entire stack trace +above this error message when asking for help.""" % traceback.format_exc() + raise ImportError(msg) + +# pylint: enable=wildcard-import,g-import-not-at-top,unused-import,line-too-long