diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 983aa361e43..2c0a7452692 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,7 +18,10 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site import sys as _sys # API IMPORTS PLACEHOLDER @@ -52,6 +55,41 @@ elif _tf_api_dir not in __path__: from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top _compat.enable_v2_behavior() + +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. +_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) + # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. So python adds these symbols for the @@ -72,4 +110,6 @@ try: del compiler except NameError: pass + + # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index e1996397627..514aba1b596 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -18,7 +18,10 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site import sys as _sys # pylint: disable=g-bad-import-order @@ -69,6 +72,40 @@ if not hasattr(_current_module, '__path__'): elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. +_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) + # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. So python adds these symbols for the