From 0f01ee7a34d8e17ec43181eab7b46026bf9bae3c Mon Sep 17 00:00:00 2001 From: Amit Patankar <amitpatankar@google.com> Date: Thu, 9 Apr 2020 10:34:32 -0700 Subject: [PATCH] Re-add dynamic loading flags and logic in Python. PiperOrigin-RevId: 305712320 Change-Id: Iaa25ab0397e91ffa83cba805e81a6c9e315549d8 --- tensorflow/python/BUILD | 6 +- tensorflow/python/pywrap_tensorflow.py | 78 +++++++++++++++++++++----- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 90ae5655399..f3449f80986 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5828,7 +5828,11 @@ py_library( name = "pywrap_tensorflow", srcs = [ "pywrap_tensorflow.py", - ], + ] + if_static( + ["pywrap_dlopen_global_flags.py"], + # Import will fail, indicating no global dlopen flags + otherwise = [], + ), # b/153585257 srcs_version = "PY2AND3", deps = [":pywrap_tensorflow_internal"], ) diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py index 8981a3c6302..863e2106f08 100644 --- a/tensorflow/python/pywrap_tensorflow.py +++ b/tensorflow/python/pywrap_tensorflow.py @@ -18,20 +18,68 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Python2.7 does not have a ModuleNotFoundError. -try: - ModuleNotFoundError -except NameError: - ModuleNotFoundError = ImportError +import ctypes +import sys +import traceback -# pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable -try: - from tensorflow.python._pywrap_tensorflow_internal import * -# This try catch logic is because there is no bazel equivalent for py_extension. -# Externally in opensource we must enable exceptions to load the shared object -# by exposing the PyInit symbols with pybind. This error will only be -# caught internally or if someone changes the name of the target _pywrap_tensorflow_internal. +from tensorflow.python.platform import self_check -# This logic is used in other internal projects using py_extension. -except ModuleNotFoundError: - pass +# Perform pre-load sanity checks in order to produce a more actionable error. +self_check.preload_check() + +# pylint: disable=wildcard-import,g-import-not-at-top,unused-import,line-too-long + +try: + # This import is expected to fail if there is an explicit shared object + # dependency (with_framework_lib=true), since we do not need RTLD_GLOBAL. + from tensorflow.python import pywrap_dlopen_global_flags + _use_dlopen_global_flags = True +except ImportError: + _use_dlopen_global_flags = False + +# On UNIX-based platforms, pywrap_tensorflow is a python library that +# dynamically loads _pywrap_tensorflow.so. +_can_set_rtld_local = ( + hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')) +if _can_set_rtld_local: + _default_dlopen_flags = sys.getdlopenflags() + +try: + if _use_dlopen_global_flags: + pywrap_dlopen_global_flags.set_dlopen_flags() + elif _can_set_rtld_local: + # Ensure RTLD_LOCAL behavior for platforms where it isn't the default + # (macOS). On Linux RTLD_LOCAL is 0, so this does nothing (and would not + # override an RTLD_GLOBAL in _default_dlopen_flags). + sys.setdlopenflags(_default_dlopen_flags | ctypes.RTLD_LOCAL) + + # Python2.7 does not have a ModuleNotFoundError. + try: + ModuleNotFoundError + except NameError: + ModuleNotFoundError = ImportError + + # pylint: disable=wildcard-import,g-import-not-at-top,line-too-long,undefined-variable + try: + from tensorflow.python._pywrap_tensorflow_internal import * + # This try catch logic is because there is no bazel equivalent for py_extension. + # Externally in opensource we must enable exceptions to load the shared object + # by exposing the PyInit symbols with pybind. This error will only be + # caught internally or if someone changes the name of the target _pywrap_tensorflow_internal. + + # This logic is used in other internal projects using py_extension. + except ModuleNotFoundError: + pass + + if _use_dlopen_global_flags: + pywrap_dlopen_global_flags.reset_dlopen_flags() + elif _can_set_rtld_local: + sys.setdlopenflags(_default_dlopen_flags) +except ImportError: + msg = """%s\n\nFailed to load the native TensorFlow runtime.\n +See https://www.tensorflow.org/install/errors\n +for some common reasons and solutions. Include the entire stack trace +above this error message when asking for help.""" % traceback.format_exc() + raise ImportError(msg) + +# pylint: enable=wildcard-import,g-import-not-at-top,unused-import,line-too-long