Add capability to load plugins installed via tensorflow-plugins pip directory.
PiperOrigin-RevId: 226091344
This commit is contained in:
parent
51bac90e00
commit
5195204b47
@ -18,7 +18,10 @@ from __future__ import absolute_import as _absolute_import
|
||||
from __future__ import division as _division
|
||||
from __future__ import print_function as _print_function
|
||||
|
||||
import distutils as _distutils
|
||||
import inspect as _inspect
|
||||
import os as _os
|
||||
import site as _site
|
||||
import sys as _sys
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
@ -52,6 +55,41 @@ elif _tf_api_dir not in __path__:
|
||||
from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top
|
||||
_compat.enable_v2_behavior()
|
||||
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
|
||||
# directories.
|
||||
# TODO(gunan): Find a better location for this code snippet.
|
||||
from tensorflow.python.framework import load_library as _ll
|
||||
from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
||||
if 'sysconfig' in dir(_distutils):
|
||||
_site_packages_dirs += [_distutils.sysconfig.get_python_lib()]
|
||||
|
||||
_site_packages_dirs = list(set(_site_packages_dirs))
|
||||
|
||||
# Find the location of this exact file.
|
||||
_current_file_location = _inspect.getfile(_inspect.currentframe())
|
||||
|
||||
def _running_from_pip_package():
|
||||
return any(
|
||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||
|
||||
if _running_from_pip_package():
|
||||
for s in _site_packages_dirs:
|
||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(plugin_dir):
|
||||
_ll.load_library(plugin_dir)
|
||||
|
||||
# These symbols appear because we import the python package which
|
||||
# in turn imports from tensorflow.core and tensorflow.python. They
|
||||
# must come from this module. So python adds these symbols for the
|
||||
@ -72,4 +110,6 @@ try:
|
||||
del compiler
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
|
||||
# pylint: enable=undefined-variable
|
||||
|
@ -18,7 +18,10 @@ from __future__ import absolute_import as _absolute_import
|
||||
from __future__ import division as _division
|
||||
from __future__ import print_function as _print_function
|
||||
|
||||
import distutils as _distutils
|
||||
import inspect as _inspect
|
||||
import os as _os
|
||||
import site as _site
|
||||
import sys as _sys
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
@ -69,6 +72,40 @@ if not hasattr(_current_module, '__path__'):
|
||||
elif _tf_api_dir not in __path__:
|
||||
__path__.append(_tf_api_dir)
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
|
||||
# directories.
|
||||
# TODO(gunan): Find a better location for this code snippet.
|
||||
from tensorflow.python.framework import load_library as _ll
|
||||
from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
||||
if 'sysconfig' in dir(_distutils):
|
||||
_site_packages_dirs += [_distutils.sysconfig.get_python_lib()]
|
||||
|
||||
_site_packages_dirs = list(set(_site_packages_dirs))
|
||||
|
||||
# Find the location of this exact file.
|
||||
_current_file_location = _inspect.getfile(_inspect.currentframe())
|
||||
|
||||
def _running_from_pip_package():
|
||||
return any(
|
||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||
|
||||
if _running_from_pip_package():
|
||||
for s in _site_packages_dirs:
|
||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
||||
if _fi.file_exists(plugin_dir):
|
||||
_ll.load_library(plugin_dir)
|
||||
|
||||
# These symbols appear because we import the python package which
|
||||
# in turn imports from tensorflow.core and tensorflow.python. They
|
||||
# must come from this module. So python adds these symbols for the
|
||||
|
Loading…
Reference in New Issue
Block a user