Add hooks to allow lazyLoad TensorFlow public API.
- Updates existing DeprecationWrapper with the ability to import modules only when they are referred. - Updates how TensorFlow generates public API. Wraps all generated TensorFlow __init__.py modules with this enhanced wrapper. To enable lazy-loading in the future, toggle _LAZY_LOADING flag in create_python_api.py. Once lazy loading is enabled, the wrapper will have the following behaviors: - dir() will always return module?s attributes. - __all__ will always return all public APIs. - __dict__ will be populated as attributes are being referred. - After wrapper instance is created, to add more attributes, use setattr(import does not explicitly call setattr) to make sure dir, __all__, __dict__ are updated. - import * will work as expected. Built and tested with pip package. PiperOrigin-RevId: 257240535
This commit is contained in:
parent
b34ed5bd1e
commit
7ece5ce95f
@ -41,12 +41,15 @@ from tensorflow.python.tools import module_util as _module_util
|
|||||||
|
|
||||||
# API IMPORTS PLACEHOLDER
|
# API IMPORTS PLACEHOLDER
|
||||||
|
|
||||||
|
# WRAPPER_PLACEHOLDER
|
||||||
|
|
||||||
# Make sure directory containing top level submodules is in
|
# Make sure directory containing top level submodules is in
|
||||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||||
# We're using bitwise, but there's nothing special about that.
|
# We're using bitwise, but there's nothing special about that.
|
||||||
_API_MODULE = bitwise # pylint: disable=undefined-variable
|
_API_MODULE = sys.modules[__name__].bitwise # pylint: disable=undefined-variable
|
||||||
_current_module = _sys.modules[__name__]
|
|
||||||
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
||||||
|
_current_module = _sys.modules[__name__]
|
||||||
|
|
||||||
if not hasattr(_current_module, '__path__'):
|
if not hasattr(_current_module, '__path__'):
|
||||||
__path__ = [_tf_api_dir]
|
__path__ = [_tf_api_dir]
|
||||||
elif _tf_api_dir not in __path__:
|
elif _tf_api_dir not in __path__:
|
||||||
@ -57,6 +60,7 @@ try:
|
|||||||
from tensorboard.summary._tf import summary
|
from tensorboard.summary._tf import summary
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "summary", summary)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_logging.warning(
|
_logging.warning(
|
||||||
"Limited tf.summary API due to missing TensorBoard installation.")
|
"Limited tf.summary API due to missing TensorBoard installation.")
|
||||||
@ -65,6 +69,7 @@ try:
|
|||||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "estimator", estimator)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -72,6 +77,7 @@ try:
|
|||||||
from tensorflow.python.keras.api._v2 import keras
|
from tensorflow.python.keras.api._v2 import keras
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "keras", keras)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -122,25 +128,17 @@ if _running_from_pip_package():
|
|||||||
# pylint: disable=undefined-variable
|
# pylint: disable=undefined-variable
|
||||||
try:
|
try:
|
||||||
del python
|
del python
|
||||||
if '__all__' in vars():
|
|
||||||
vars()['__all__'].remove('python')
|
|
||||||
del core
|
|
||||||
if '__all__' in vars():
|
|
||||||
vars()['__all__'].remove('core')
|
|
||||||
except NameError:
|
except NameError:
|
||||||
# Don't fail if these modules are not available.
|
|
||||||
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
|
|
||||||
# does not have 'python', 'core' directories. Then, it will be copied
|
|
||||||
# to tensorflow/ which does have these two directories.
|
|
||||||
pass
|
pass
|
||||||
# Similarly for compiler. Do it separately to make sure we do this even if the
|
try:
|
||||||
# others don't exist.
|
del core
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
del compiler
|
del compiler
|
||||||
if '__all__' in vars():
|
|
||||||
vars()['__all__'].remove('compiler')
|
|
||||||
except NameError:
|
except NameError:
|
||||||
pass
|
pass
|
||||||
|
# pylint: enable=undefined-variable
|
||||||
|
|
||||||
# Add module aliases
|
# Add module aliases
|
||||||
if hasattr(_current_module, 'keras'):
|
if hasattr(_current_module, 'keras'):
|
||||||
@ -148,6 +146,10 @@ if hasattr(_current_module, 'keras'):
|
|||||||
metrics = keras.metrics
|
metrics = keras.metrics
|
||||||
optimizers = keras.optimizers
|
optimizers = keras.optimizers
|
||||||
initializers = keras.initializers
|
initializers = keras.initializers
|
||||||
|
setattr(_current_module, "losses", losses)
|
||||||
|
setattr(_current_module, "metrics", metrics)
|
||||||
|
setattr(_current_module, "optimizers", optimizers)
|
||||||
|
setattr(_current_module, "initializers", initializers)
|
||||||
|
|
||||||
compat.v2.compat.v1 = compat.v1
|
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
|
||||||
# pylint: enable=undefined-variable
|
# pylint: enable=undefined-variable
|
||||||
|
@ -30,10 +30,12 @@ from tensorflow.python.tools import module_util as _module_util
|
|||||||
|
|
||||||
# API IMPORTS PLACEHOLDER
|
# API IMPORTS PLACEHOLDER
|
||||||
|
|
||||||
|
# WRAPPER_PLACEHOLDER
|
||||||
|
|
||||||
# Make sure directory containing top level submodules is in
|
# Make sure directory containing top level submodules is in
|
||||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||||
# We're using bitwise, but there's nothing special about that.
|
# We're using bitwise, but there's nothing special about that.
|
||||||
_API_MODULE = bitwise # pylint: disable=undefined-variable
|
_API_MODULE = _sys.modules[__name__].bitwise # pylint: disable=undefined-variable
|
||||||
_current_module = _sys.modules[__name__]
|
_current_module = _sys.modules[__name__]
|
||||||
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
||||||
if not hasattr(_current_module, '__path__'):
|
if not hasattr(_current_module, '__path__'):
|
||||||
@ -46,6 +48,7 @@ try:
|
|||||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "estimator", estimator)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -53,6 +56,7 @@ try:
|
|||||||
from tensorflow.python.keras.api._v1 import keras
|
from tensorflow.python.keras.api._v1 import keras
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "keras", keras)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -77,9 +81,8 @@ if '__all__' in vars():
|
|||||||
|
|
||||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||||
# The 'app' module will be imported as part of the placeholder section above.
|
# The 'app' module will be imported as part of the placeholder section above.
|
||||||
app.flags = flags # pylint: disable=undefined-variable
|
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||||
if '__all__' in vars():
|
setattr(_current_module, "flags", flags)
|
||||||
vars()['__all__'].append('flags')
|
|
||||||
|
|
||||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||||
# running under pip.
|
# running under pip.
|
||||||
@ -122,25 +125,16 @@ if _running_from_pip_package():
|
|||||||
# pylint: disable=undefined-variable
|
# pylint: disable=undefined-variable
|
||||||
try:
|
try:
|
||||||
del python
|
del python
|
||||||
if '__all__' in vars():
|
|
||||||
vars()['__all__'].remove('python')
|
|
||||||
del core
|
|
||||||
if '__all__' in vars():
|
|
||||||
vars()['__all__'].remove('core')
|
|
||||||
except NameError:
|
except NameError:
|
||||||
# Don't fail if these modules are not available.
|
|
||||||
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
|
|
||||||
# does not have 'python', 'core' directories. Then, it will be copied
|
|
||||||
# to tensorflow/ which does have these two directories.
|
|
||||||
pass
|
pass
|
||||||
# Similarly for compiler. Do it separately to make sure we do this even if the
|
try:
|
||||||
# others don't exist.
|
del core
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
del compiler
|
del compiler
|
||||||
if '__all__' in vars():
|
|
||||||
vars()['__all__'].remove('compiler')
|
|
||||||
except NameError:
|
except NameError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
compat.v2.compat.v1 = compat.v1
|
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
|
||||||
# pylint: enable=undefined-variable
|
# pylint: enable=undefined-variable
|
||||||
|
@ -28,12 +28,16 @@ from tensorflow.python.tools import module_util as _module_util
|
|||||||
|
|
||||||
# API IMPORTS PLACEHOLDER
|
# API IMPORTS PLACEHOLDER
|
||||||
|
|
||||||
|
# WRAPPER_PLACEHOLDER
|
||||||
|
|
||||||
# Hook external TensorFlow modules.
|
# Hook external TensorFlow modules.
|
||||||
_current_module = _sys.modules[__name__]
|
_current_module = _sys.modules[__name__]
|
||||||
try:
|
try:
|
||||||
from tensorboard.summary._tf import summary
|
from tensorboard.summary._tf import summary
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||||
|
# Make sure we get the correct summary module with lazy loading
|
||||||
|
setattr(_current_module, "summary", summary)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_logging.warning(
|
_logging.warning(
|
||||||
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
||||||
@ -43,6 +47,7 @@ try:
|
|||||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "estimator", estimator)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -50,6 +55,7 @@ try:
|
|||||||
from tensorflow.python.keras.api._v2 import keras
|
from tensorflow.python.keras.api._v2 import keras
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "keras", keras)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -61,11 +67,15 @@ except ImportError:
|
|||||||
#
|
#
|
||||||
# This make this one symbol available directly.
|
# This make this one symbol available directly.
|
||||||
from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top
|
from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top
|
||||||
|
setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
|
||||||
|
|
||||||
# Add module aliases
|
# Add module aliases
|
||||||
_current_module = _sys.modules[__name__]
|
|
||||||
if hasattr(_current_module, 'keras'):
|
if hasattr(_current_module, 'keras'):
|
||||||
losses = keras.losses
|
losses = keras.losses
|
||||||
metrics = keras.metrics
|
metrics = keras.metrics
|
||||||
optimizers = keras.optimizers
|
optimizers = keras.optimizers
|
||||||
initializers = keras.initializers
|
initializers = keras.initializers
|
||||||
|
setattr(_current_module, "losses", losses)
|
||||||
|
setattr(_current_module, "metrics", metrics)
|
||||||
|
setattr(_current_module, "optimizers", optimizers)
|
||||||
|
setattr(_current_module, "initializers", initializers)
|
||||||
|
@ -27,12 +27,15 @@ from tensorflow.python.tools import module_util as _module_util
|
|||||||
|
|
||||||
# API IMPORTS PLACEHOLDER
|
# API IMPORTS PLACEHOLDER
|
||||||
|
|
||||||
|
# WRAPPER_PLACEHOLDER
|
||||||
|
|
||||||
# Hook external TensorFlow modules.
|
# Hook external TensorFlow modules.
|
||||||
_current_module = _sys.modules[__name__]
|
_current_module = _sys.modules[__name__]
|
||||||
try:
|
try:
|
||||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "estimator", estimator)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -40,9 +43,11 @@ try:
|
|||||||
from tensorflow.python.keras.api._v1 import keras
|
from tensorflow.python.keras.api._v1 import keras
|
||||||
_current_module.__path__ = (
|
_current_module.__path__ = (
|
||||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||||
|
setattr(_current_module, "keras", keras)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||||
app.flags = flags # pylint: disable=undefined-variable
|
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||||
|
setattr(_current_module, "flags", flags)
|
||||||
|
@ -4545,9 +4545,9 @@ tf_py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "deprecation_wrapper_test",
|
name = "module_wrapper_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["util/deprecation_wrapper_test.py"],
|
srcs = ["util/module_wrapper_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":client_testlib",
|
":client_testlib",
|
||||||
":util",
|
":util",
|
||||||
|
@ -48,15 +48,48 @@ _GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
|
|||||||
|
|
||||||
from __future__ import print_function as _print_function
|
from __future__ import print_function as _print_function
|
||||||
|
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'
|
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'
|
||||||
_DEPRECATION_FOOTER = """
|
_DEPRECATION_FOOTER = """
|
||||||
import sys as _sys
|
from tensorflow.python.util import module_wrapper as _module_wrapper
|
||||||
from tensorflow.python.util import deprecation_wrapper as _deprecation_wrapper
|
|
||||||
|
|
||||||
if not isinstance(_sys.modules[__name__], _deprecation_wrapper.DeprecationWrapper):
|
if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
|
||||||
_sys.modules[__name__] = _deprecation_wrapper.DeprecationWrapper(
|
_sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
|
||||||
_sys.modules[__name__], "%s")
|
_sys.modules[__name__], "%s", public_apis=_PUBLIC_APIS, deprecation=%s,
|
||||||
|
has_lite=%s)
|
||||||
|
"""
|
||||||
|
_MODULE_TEXT_TEMPLATE = """
|
||||||
|
# Inform pytype that this module is dynamically populated (b/111239204).
|
||||||
|
_LAZY_LOADING = False
|
||||||
|
|
||||||
|
_PUBLIC_APIS = {
|
||||||
|
%s
|
||||||
|
}
|
||||||
|
|
||||||
|
if _LAZY_LOADING:
|
||||||
|
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||||
|
else:
|
||||||
|
import importlib as _importlib
|
||||||
|
for symbol, symbol_loc_info in _PUBLIC_APIS.items():
|
||||||
|
if symbol_loc_info[0]:
|
||||||
|
attr = getattr(_importlib.import_module(symbol_loc_info[0]), symbol_loc_info[1])
|
||||||
|
else:
|
||||||
|
attr = _importlib.import_module(symbol_loc_info[1])
|
||||||
|
setattr(_sys.modules[__name__], symbol, attr)
|
||||||
|
try:
|
||||||
|
del symbol
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
del symbol_loc_info
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
del attr
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -76,17 +109,7 @@ def format_import(source_module_name, source_name, dest_name):
|
|||||||
Returns:
|
Returns:
|
||||||
An import statement string.
|
An import statement string.
|
||||||
"""
|
"""
|
||||||
if source_module_name:
|
return " '%s': ('%s', '%s')," % (dest_name, source_module_name, source_name)
|
||||||
if source_name == dest_name:
|
|
||||||
return 'from %s import %s' % (source_module_name, source_name)
|
|
||||||
else:
|
|
||||||
return 'from %s import %s as %s' % (
|
|
||||||
source_module_name, source_name, dest_name)
|
|
||||||
else:
|
|
||||||
if source_name == dest_name:
|
|
||||||
return 'import %s' % source_name
|
|
||||||
else:
|
|
||||||
return 'import %s as %s' % (source_name, dest_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_canonical_import(import_set):
|
def get_canonical_import(import_set):
|
||||||
@ -129,7 +152,6 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
lambda: collections.defaultdict(set))
|
lambda: collections.defaultdict(set))
|
||||||
self._dest_import_to_id = collections.defaultdict(int)
|
self._dest_import_to_id = collections.defaultdict(int)
|
||||||
# Names that start with underscore in the root module.
|
# Names that start with underscore in the root module.
|
||||||
self._underscore_names_in_root = []
|
|
||||||
self._api_version = api_version
|
self._api_version = api_version
|
||||||
|
|
||||||
def _check_already_imported(self, symbol_id, api_name):
|
def _check_already_imported(self, symbol_id, api_name):
|
||||||
@ -166,9 +188,6 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
symbol_id = -1 if not symbol else id(symbol)
|
symbol_id = -1 if not symbol else id(symbol)
|
||||||
self._check_already_imported(symbol_id, full_api_name)
|
self._check_already_imported(symbol_id, full_api_name)
|
||||||
|
|
||||||
if not dest_module_name and dest_name.startswith('_'):
|
|
||||||
self._underscore_names_in_root.append(dest_name)
|
|
||||||
|
|
||||||
# The same symbol can be available in multiple modules.
|
# The same symbol can be available in multiple modules.
|
||||||
# We store all possible ways of importing this symbol and later pick just
|
# We store all possible ways of importing this symbol and later pick just
|
||||||
# one.
|
# one.
|
||||||
@ -197,11 +216,13 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
submodule = module_split[submodule_index-1]
|
submodule = module_split[submodule_index-1]
|
||||||
parent_module += '.' + submodule if parent_module else submodule
|
parent_module += '.' + submodule if parent_module else submodule
|
||||||
import_from = self._output_package
|
import_from = self._output_package
|
||||||
if submodule_index > 0:
|
import_from += '.' + '.'.join(module_split[:submodule_index + 1])
|
||||||
import_from += '.' + '.'.join(module_split[:submodule_index])
|
|
||||||
self.add_import(
|
self.add_import(
|
||||||
None, import_from, module_split[submodule_index],
|
symbol=None,
|
||||||
parent_module, module_split[submodule_index])
|
source_module_name='',
|
||||||
|
source_name=import_from,
|
||||||
|
dest_module_name=parent_module,
|
||||||
|
dest_name=module_split[submodule_index])
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
"""Get a map from destination module to __init__.py code for that module.
|
"""Get a map from destination module to __init__.py code for that module.
|
||||||
@ -221,26 +242,20 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
get_canonical_import(imports)
|
get_canonical_import(imports)
|
||||||
for _, imports in dest_name_to_imports.items()
|
for _, imports in dest_name_to_imports.items()
|
||||||
]
|
]
|
||||||
module_text_map[dest_module] = '\n'.join(sorted(imports_list))
|
module_text_map[dest_module] = _MODULE_TEXT_TEMPLATE % '\n'.join(
|
||||||
|
sorted(imports_list))
|
||||||
|
|
||||||
# Expose exported symbols with underscores in root module
|
|
||||||
# since we import from it using * import.
|
|
||||||
underscore_names_str = ', '.join(
|
|
||||||
'\'%s\'' % name for name in self._underscore_names_in_root)
|
|
||||||
# We will always generate a root __init__.py file to let us handle *
|
|
||||||
# imports consistently. Be sure to have a root __init__.py file listed in
|
|
||||||
# the script outputs.
|
|
||||||
module_text_map[''] = module_text_map.get('', '') + '''
|
|
||||||
_names_with_underscore = [%s]
|
|
||||||
__all__ = [_s for _s in dir() if not _s.startswith('_')]
|
|
||||||
__all__.extend([_s for _s in _names_with_underscore])
|
|
||||||
''' % underscore_names_str
|
|
||||||
|
|
||||||
if self._api_version == 1: # Add 1.* deprecations.
|
|
||||||
for dest_module, _ in self._module_imports.items():
|
for dest_module, _ in self._module_imports.items():
|
||||||
|
deprecation = 'False'
|
||||||
|
has_lite = 'False'
|
||||||
|
if self._api_version == 1: # Add 1.* deprecations.
|
||||||
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
|
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
|
||||||
|
deprecation = 'True'
|
||||||
|
# Workaround to make sure not load lite from lite/__init__.py
|
||||||
|
if not dest_module and 'lite' in self._module_imports:
|
||||||
|
has_lite = 'True'
|
||||||
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
|
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
|
||||||
dest_module)
|
dest_module, deprecation, has_lite)
|
||||||
|
|
||||||
return module_text_map, footer_text_map
|
return module_text_map, footer_text_map
|
||||||
|
|
||||||
@ -519,6 +534,10 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
|||||||
_GENERATED_FILE_HEADER % get_module_docstring(
|
_GENERATED_FILE_HEADER % get_module_docstring(
|
||||||
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER)
|
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER)
|
||||||
if module in deprecation_footer_map:
|
if module in deprecation_footer_map:
|
||||||
|
if '# WRAPPER_PLACEHOLDER' in contents:
|
||||||
|
contents = contents.replace('# WRAPPER_PLACEHOLDER',
|
||||||
|
deprecation_footer_map[module])
|
||||||
|
else:
|
||||||
contents += deprecation_footer_map[module]
|
contents += deprecation_footer_map[module]
|
||||||
with open(module_name_to_file_path[module], 'w') as fp:
|
with open(module_name_to_file_path[module], 'w') as fp:
|
||||||
fp.write(contents)
|
fp.write(contents)
|
||||||
|
@ -67,15 +67,16 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
api_version=1)
|
api_version=1)
|
||||||
expected_import = (
|
expected_import = ('\'test_op1\': '
|
||||||
'from tensorflow.python.test_module '
|
'(\'tensorflow.python.test_module\','
|
||||||
'import test_op as test_op1')
|
' \'test_op\')')
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
expected_import in str(imports),
|
expected_import in str(imports),
|
||||||
msg='%s not in %s' % (expected_import, str(imports)))
|
msg='%s not in %s' % (expected_import, str(imports)))
|
||||||
|
|
||||||
expected_import = ('from tensorflow.python.test_module '
|
expected_import = ('\'test_op\': '
|
||||||
'import test_op')
|
'(\'tensorflow.python.test_module\','
|
||||||
|
' \'test_op\')')
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
expected_import in str(imports),
|
expected_import in str(imports),
|
||||||
msg='%s not in %s' % (expected_import, str(imports)))
|
msg='%s not in %s' % (expected_import, str(imports)))
|
||||||
@ -89,8 +90,10 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
api_version=2)
|
api_version=2)
|
||||||
expected_import = ('from tensorflow.python.test_module '
|
expected_import = (
|
||||||
'import TestClass')
|
'\'NewTestClass\':'
|
||||||
|
' (\'tensorflow.python.test_module\','
|
||||||
|
' \'TestClass\')')
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
'TestClass' in str(imports),
|
'TestClass' in str(imports),
|
||||||
msg='%s not in %s' % (expected_import, str(imports)))
|
msg='%s not in %s' % (expected_import, str(imports)))
|
||||||
@ -101,8 +104,9 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
api_version=1)
|
api_version=1)
|
||||||
expected = ('from tensorflow.python.test_module '
|
expected = ('\'_TEST_CONSTANT\':'
|
||||||
'import _TEST_CONSTANT')
|
' (\'tensorflow.python.test_module\','
|
||||||
|
' \'_TEST_CONSTANT\')')
|
||||||
self.assertTrue(expected in str(imports),
|
self.assertTrue(expected in str(imports),
|
||||||
msg='%s not in %s' % (expected, str(imports)))
|
msg='%s not in %s' % (expected, str(imports)))
|
||||||
|
|
||||||
|
@ -12,138 +12,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Provides wrapper for TensorFlow modules to support deprecation messages.
|
"""Compatibility wrapper for TensorFlow modules to support deprecation messages.
|
||||||
|
|
||||||
TODO(annarev): potentially merge with LazyLoader.
|
Please use module_wrapper instead.
|
||||||
|
TODO(yifeif): remove once no longer referred by estimator
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import sys
|
from tensorflow.python.util import module_wrapper
|
||||||
import types
|
|
||||||
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
# For backward compatibility for other pip packages that use this class.
|
||||||
from tensorflow.python.util import tf_decorator
|
DeprecationWrapper = module_wrapper.TFModuleWrapper
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
from tensorflow.python.util import tf_stack
|
|
||||||
from tensorflow.tools.compatibility import all_renames_v2
|
|
||||||
|
|
||||||
|
|
||||||
_PER_MODULE_WARNING_LIMIT = 1
|
|
||||||
|
|
||||||
|
|
||||||
def get_rename_v2(name):
|
|
||||||
if name not in all_renames_v2.symbol_renames:
|
|
||||||
return None
|
|
||||||
return all_renames_v2.symbol_renames[name]
|
|
||||||
|
|
||||||
|
|
||||||
def _call_location():
|
|
||||||
# We want to get stack frame 2 frames up from current frame,
|
|
||||||
# i.e. above _getattr__ and _call_location calls.
|
|
||||||
stack = tf_stack.extract_stack_file_and_line(max_length=3)
|
|
||||||
if not stack: # should never happen as we're in a function
|
|
||||||
return 'UNKNOWN'
|
|
||||||
frame = stack[0]
|
|
||||||
return '{}:{}'.format(frame.file, frame.line)
|
|
||||||
|
|
||||||
|
|
||||||
def contains_deprecation_decorator(decorators):
|
|
||||||
return any(
|
|
||||||
d.decorator_name == 'deprecated' for d in decorators)
|
|
||||||
|
|
||||||
|
|
||||||
def has_deprecation_decorator(symbol):
|
|
||||||
"""Checks if given object has a deprecation decorator.
|
|
||||||
|
|
||||||
We check if deprecation decorator is in decorators as well as
|
|
||||||
whether symbol is a class whose __init__ method has a deprecation
|
|
||||||
decorator.
|
|
||||||
Args:
|
|
||||||
symbol: Python object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if symbol has deprecation decorator.
|
|
||||||
"""
|
|
||||||
decorators, symbol = tf_decorator.unwrap(symbol)
|
|
||||||
if contains_deprecation_decorator(decorators):
|
|
||||||
return True
|
|
||||||
if tf_inspect.isfunction(symbol):
|
|
||||||
return False
|
|
||||||
if not tf_inspect.isclass(symbol):
|
|
||||||
return False
|
|
||||||
if not hasattr(symbol, '__init__'):
|
|
||||||
return False
|
|
||||||
init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
|
|
||||||
return contains_deprecation_decorator(init_decorators)
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecationWrapper(types.ModuleType):
|
|
||||||
"""Wrapper for TensorFlow modules to support deprecation messages."""
|
|
||||||
|
|
||||||
def __init__(self, wrapped, module_name): # pylint: disable=super-on-old-class
|
|
||||||
super(DeprecationWrapper, self).__init__(wrapped.__name__)
|
|
||||||
self.__dict__.update(wrapped.__dict__)
|
|
||||||
# Prefix all local attributes with _dw_ so that we can
|
|
||||||
# handle them differently in attribute access methods.
|
|
||||||
self._dw_wrapped_module = wrapped
|
|
||||||
self._dw_module_name = module_name
|
|
||||||
# names we already checked for deprecation
|
|
||||||
self._dw_deprecated_checked = set()
|
|
||||||
self._dw_warning_count = 0
|
|
||||||
|
|
||||||
def __getattribute__(self, name): # pylint: disable=super-on-old-class
|
|
||||||
attr = super(DeprecationWrapper, self).__getattribute__(name)
|
|
||||||
if name.startswith('__') or name.startswith('_dw_'):
|
|
||||||
return attr
|
|
||||||
|
|
||||||
if (self._dw_warning_count < _PER_MODULE_WARNING_LIMIT and
|
|
||||||
name not in self._dw_deprecated_checked):
|
|
||||||
|
|
||||||
self._dw_deprecated_checked.add(name)
|
|
||||||
|
|
||||||
if self._dw_module_name:
|
|
||||||
full_name = 'tf.%s.%s' % (self._dw_module_name, name)
|
|
||||||
else:
|
|
||||||
full_name = 'tf.%s' % name
|
|
||||||
rename = get_rename_v2(full_name)
|
|
||||||
if rename and not has_deprecation_decorator(attr):
|
|
||||||
call_location = _call_location()
|
|
||||||
# skip locations in Python source
|
|
||||||
if not call_location.startswith('<'):
|
|
||||||
logging.warning(
|
|
||||||
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
|
||||||
_call_location(), full_name, rename)
|
|
||||||
self._dw_warning_count += 1
|
|
||||||
return attr
|
|
||||||
|
|
||||||
def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
|
|
||||||
if arg.startswith('_dw_'):
|
|
||||||
super(DeprecationWrapper, self).__setattr__(arg, val)
|
|
||||||
else:
|
|
||||||
setattr(self._dw_wrapped_module, arg, val)
|
|
||||||
self.__dict__[arg] = val
|
|
||||||
|
|
||||||
def __dir__(self):
|
|
||||||
return dir(self._dw_wrapped_module)
|
|
||||||
|
|
||||||
def __delattr__(self, name): # pylint: disable=super-on-old-class
|
|
||||||
if name.startswith('_dw_'):
|
|
||||||
super(DeprecationWrapper, self).__delattr__(name)
|
|
||||||
else:
|
|
||||||
delattr(self._dw_wrapped_module, name)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self._dw_wrapped_module.__repr__()
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
return self.__name__
|
|
||||||
|
|
||||||
def __setstate__(self, d):
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
self.__init__(
|
|
||||||
sys.modules[d]._dw_wrapped_module,
|
|
||||||
sys.modules[d]._dw_module_name)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
@ -1,75 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Tests for tensorflow.python.util.deprecation_wrapper."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import types
|
|
||||||
|
|
||||||
from tensorflow.python.platform import test
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
|
||||||
from tensorflow.python.util import deprecation_wrapper
|
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
from tensorflow.tools.compatibility import all_renames_v2
|
|
||||||
|
|
||||||
deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
|
||||||
|
|
||||||
|
|
||||||
class MockModule(types.ModuleType):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecationWrapperTest(test.TestCase):
|
|
||||||
|
|
||||||
def testWrapperIsAModule(self):
|
|
||||||
module = MockModule('test')
|
|
||||||
wrapped_module = deprecation_wrapper.DeprecationWrapper(
|
|
||||||
module, 'test')
|
|
||||||
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
|
||||||
|
|
||||||
@test.mock.patch.object(logging, 'warning', autospec=True)
|
|
||||||
def testDeprecationWarnings(self, mock_warning):
|
|
||||||
module = MockModule('test')
|
|
||||||
module.foo = 1
|
|
||||||
module.bar = 2
|
|
||||||
module.baz = 3
|
|
||||||
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
|
|
||||||
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
|
|
||||||
|
|
||||||
wrapped_module = deprecation_wrapper.DeprecationWrapper(
|
|
||||||
module, 'test')
|
|
||||||
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
|
||||||
|
|
||||||
self.assertEqual(0, mock_warning.call_count)
|
|
||||||
bar = wrapped_module.bar
|
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
|
||||||
foo = wrapped_module.foo
|
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
|
||||||
baz = wrapped_module.baz
|
|
||||||
self.assertEqual(2, mock_warning.call_count)
|
|
||||||
baz = wrapped_module.baz
|
|
||||||
self.assertEqual(2, mock_warning.call_count)
|
|
||||||
|
|
||||||
# Check that values stayed the same
|
|
||||||
self.assertEqual(module.foo, foo)
|
|
||||||
self.assertEqual(module.bar, bar)
|
|
||||||
self.assertEqual(module.baz, baz)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test.main()
|
|
205
tensorflow/python/util/module_wrapper.py
Normal file
205
tensorflow/python/util/module_wrapper.py
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Provides wrapper for TensorFlow modules."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.util import tf_decorator
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
from tensorflow.python.util import tf_stack
|
||||||
|
from tensorflow.tools.compatibility import all_renames_v2
|
||||||
|
|
||||||
|
|
||||||
|
_PER_MODULE_WARNING_LIMIT = 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_rename_v2(name):
|
||||||
|
if name not in all_renames_v2.symbol_renames:
|
||||||
|
return None
|
||||||
|
return all_renames_v2.symbol_renames[name]
|
||||||
|
|
||||||
|
|
||||||
|
def _call_location():
|
||||||
|
# We want to get stack frame 2 frames up from current frame,
|
||||||
|
# i.e. above _getattr__ and _call_location calls.
|
||||||
|
stack = tf_stack.extract_stack_file_and_line(max_length=3)
|
||||||
|
if not stack: # should never happen as we're in a function
|
||||||
|
return 'UNKNOWN'
|
||||||
|
frame = stack[0]
|
||||||
|
return '{}:{}'.format(frame.file, frame.line)
|
||||||
|
|
||||||
|
|
||||||
|
def contains_deprecation_decorator(decorators):
|
||||||
|
return any(
|
||||||
|
d.decorator_name == 'deprecated' for d in decorators)
|
||||||
|
|
||||||
|
|
||||||
|
def has_deprecation_decorator(symbol):
|
||||||
|
"""Checks if given object has a deprecation decorator.
|
||||||
|
|
||||||
|
We check if deprecation decorator is in decorators as well as
|
||||||
|
whether symbol is a class whose __init__ method has a deprecation
|
||||||
|
decorator.
|
||||||
|
Args:
|
||||||
|
symbol: Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if symbol has deprecation decorator.
|
||||||
|
"""
|
||||||
|
decorators, symbol = tf_decorator.unwrap(symbol)
|
||||||
|
if contains_deprecation_decorator(decorators):
|
||||||
|
return True
|
||||||
|
if tf_inspect.isfunction(symbol):
|
||||||
|
return False
|
||||||
|
if not tf_inspect.isclass(symbol):
|
||||||
|
return False
|
||||||
|
if not hasattr(symbol, '__init__'):
|
||||||
|
return False
|
||||||
|
init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
|
||||||
|
return contains_deprecation_decorator(init_decorators)
|
||||||
|
|
||||||
|
|
||||||
|
class TFModuleWrapper(types.ModuleType):
|
||||||
|
"""Wrapper for TF modules to support deprecation messages and lazyloading."""
|
||||||
|
|
||||||
|
def __init__( # pylint: disable=super-on-old-class
|
||||||
|
self,
|
||||||
|
wrapped,
|
||||||
|
module_name,
|
||||||
|
public_apis=None,
|
||||||
|
deprecation=True,
|
||||||
|
has_lite=False): # pylint: enable=super-on-old-class
|
||||||
|
super(TFModuleWrapper, self).__init__(wrapped.__name__)
|
||||||
|
self.__dict__.update(wrapped.__dict__)
|
||||||
|
# Prefix all local attributes with _tfmw_ so that we can
|
||||||
|
# handle them differently in attribute access methods.
|
||||||
|
self._tfmw_wrapped_module = wrapped
|
||||||
|
self._tfmw_module_name = module_name
|
||||||
|
self._tfmw_public_apis = public_apis
|
||||||
|
self._tfmw_print_deprecation_warnings = deprecation
|
||||||
|
self._tfmw_has_lite = has_lite
|
||||||
|
# Set __all__ so that import * work for lazy loaded modules
|
||||||
|
if self._tfmw_public_apis:
|
||||||
|
self._tfmw_wrapped_module.__all__ = list(self._tfmw_public_apis.keys())
|
||||||
|
self.__all__ = list(self._tfmw_public_apis.keys())
|
||||||
|
# names we already checked for deprecation
|
||||||
|
self._tfmw_deprecated_checked = set()
|
||||||
|
self._tfmw_warning_count = 0
|
||||||
|
|
||||||
|
def _tfmw_add_deprecation_warning(self, name, attr):
|
||||||
|
"""Print deprecation warning for attr with given name if necessary."""
|
||||||
|
if (self._tfmw_warning_count < _PER_MODULE_WARNING_LIMIT and
|
||||||
|
name not in self._tfmw_deprecated_checked):
|
||||||
|
|
||||||
|
self._tfmw_deprecated_checked.add(name)
|
||||||
|
|
||||||
|
if self._tfmw_module_name:
|
||||||
|
full_name = 'tf.%s.%s' % (self._tfmw_module_name, name)
|
||||||
|
else:
|
||||||
|
full_name = 'tf.%s' % name
|
||||||
|
rename = get_rename_v2(full_name)
|
||||||
|
if rename and not has_deprecation_decorator(attr):
|
||||||
|
call_location = _call_location()
|
||||||
|
# skip locations in Python source
|
||||||
|
if not call_location.startswith('<'):
|
||||||
|
logging.warning(
|
||||||
|
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
||||||
|
_call_location(), full_name, rename)
|
||||||
|
self._tfmw_warning_count += 1
|
||||||
|
|
||||||
|
def _tfmw_import_module(self, name):
|
||||||
|
symbol_loc_info = self._tfmw_public_apis[name]
|
||||||
|
if symbol_loc_info[0]:
|
||||||
|
module = importlib.import_module(symbol_loc_info[0])
|
||||||
|
attr = getattr(module, symbol_loc_info[1])
|
||||||
|
else:
|
||||||
|
attr = importlib.import_module(symbol_loc_info[1])
|
||||||
|
setattr(self._tfmw_wrapped_module, name, attr)
|
||||||
|
self.__dict__[name] = attr
|
||||||
|
return attr
|
||||||
|
|
||||||
|
def __getattribute__(self, name): # pylint: disable=super-on-old-class
|
||||||
|
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
|
||||||
|
if name == 'lite':
|
||||||
|
if self._tfmw_has_lite:
|
||||||
|
attr = self._tfmw_import_module(name)
|
||||||
|
setattr(self._tfmw_wrapped_module, 'lite', attr)
|
||||||
|
return attr
|
||||||
|
|
||||||
|
attr = super(TFModuleWrapper, self).__getattribute__(name)
|
||||||
|
if name.startswith('__') or name.startswith('_tfmw_'):
|
||||||
|
return attr
|
||||||
|
|
||||||
|
if self._tfmw_print_deprecation_warnings:
|
||||||
|
self._tfmw_add_deprecation_warning(name, attr)
|
||||||
|
return attr
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
try:
|
||||||
|
attr = getattr(self._tfmw_wrapped_module, name)
|
||||||
|
except AttributeError as e:
|
||||||
|
if not self._tfmw_public_apis:
|
||||||
|
raise e
|
||||||
|
if name not in self._tfmw_public_apis:
|
||||||
|
raise e
|
||||||
|
attr = self._tfmw_import_module(name)
|
||||||
|
|
||||||
|
if self._tfmw_print_deprecation_warnings:
|
||||||
|
self._tfmw_add_deprecation_warning(name, attr)
|
||||||
|
return attr
|
||||||
|
|
||||||
|
def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
|
||||||
|
if not arg.startswith('_tfmw_'):
|
||||||
|
setattr(self._tfmw_wrapped_module, arg, val)
|
||||||
|
self.__dict__[arg] = val
|
||||||
|
if arg not in self.__all__ and arg != '__all__':
|
||||||
|
self.__all__.append(arg)
|
||||||
|
super(TFModuleWrapper, self).__setattr__(arg, val)
|
||||||
|
|
||||||
|
def __dir__(self):
|
||||||
|
if self._tfmw_public_apis:
|
||||||
|
return list(
|
||||||
|
set(self._tfmw_public_apis.keys()).union(
|
||||||
|
set([
|
||||||
|
attr for attr in dir(self._tfmw_wrapped_module)
|
||||||
|
if not attr.startswith('_')
|
||||||
|
])))
|
||||||
|
else:
|
||||||
|
return dir(self._tfmw_wrapped_module)
|
||||||
|
|
||||||
|
def __delattr__(self, name): # pylint: disable=super-on-old-class
|
||||||
|
if name.startswith('_tfmw_'):
|
||||||
|
super(TFModuleWrapper, self).__delattr__(name)
|
||||||
|
else:
|
||||||
|
delattr(self._tfmw_wrapped_module, name)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self._tfmw_wrapped_module.__repr__()
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self.__name__
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.__init__(sys.modules[d]._tfmw_wrapped_module,
|
||||||
|
sys.modules[d]._tfmw_module_name)
|
||||||
|
# pylint: enable=protected-access
|
136
tensorflow/python/util/module_wrapper_test.py
Normal file
136
tensorflow/python/util/module_wrapper_test.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.python.util.module_wrapper."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import types
|
||||||
|
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.util import module_wrapper
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
from tensorflow.tools.compatibility import all_renames_v2
|
||||||
|
|
||||||
|
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
|
class MockModule(types.ModuleType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecationWrapperTest(test.TestCase):
|
||||||
|
|
||||||
|
def testWrapperIsAModule(self):
|
||||||
|
module = MockModule('test')
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
|
||||||
|
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
||||||
|
|
||||||
|
@test.mock.patch.object(logging, 'warning', autospec=True)
|
||||||
|
def testDeprecationWarnings(self, mock_warning):
|
||||||
|
module = MockModule('test')
|
||||||
|
module.foo = 1
|
||||||
|
module.bar = 2
|
||||||
|
module.baz = 3
|
||||||
|
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
|
||||||
|
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
|
||||||
|
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
|
||||||
|
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
||||||
|
|
||||||
|
self.assertEqual(0, mock_warning.call_count)
|
||||||
|
bar = wrapped_module.bar
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
foo = wrapped_module.foo
|
||||||
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
|
baz = wrapped_module.baz # pylint: disable=unused-variable
|
||||||
|
self.assertEqual(2, mock_warning.call_count)
|
||||||
|
baz = wrapped_module.baz
|
||||||
|
self.assertEqual(2, mock_warning.call_count)
|
||||||
|
|
||||||
|
# Check that values stayed the same
|
||||||
|
self.assertEqual(module.foo, foo)
|
||||||
|
self.assertEqual(module.bar, bar)
|
||||||
|
|
||||||
|
|
||||||
|
class LazyLoadingWrapperTest(test.TestCase):
|
||||||
|
|
||||||
|
def testLazyLoad(self):
|
||||||
|
module = MockModule('test')
|
||||||
|
apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')}
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||||
|
module, 'test', public_apis=apis, deprecation=False)
|
||||||
|
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||||
|
from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member
|
||||||
|
self.assertEqual(wrapped_module.cmd, _cmd)
|
||||||
|
self.assertEqual(wrapped_module.ABCMeta, _ABCMeta)
|
||||||
|
|
||||||
|
def testLazyLoadLocalOverride(self):
|
||||||
|
# Test that we can override and add fields to the wrapped module.
|
||||||
|
module = MockModule('test')
|
||||||
|
apis = {'cmd': ('', 'cmd')}
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||||
|
module, 'test', public_apis=apis, deprecation=False)
|
||||||
|
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||||
|
self.assertEqual(wrapped_module.cmd, _cmd)
|
||||||
|
setattr(wrapped_module, 'cmd', 1)
|
||||||
|
setattr(wrapped_module, 'cgi', 2)
|
||||||
|
self.assertEqual(wrapped_module.cmd, 1) # override
|
||||||
|
self.assertEqual(wrapped_module.cgi, 2) # add
|
||||||
|
|
||||||
|
def testLazyLoadDict(self):
|
||||||
|
# Test that we can override and add fields to the wrapped module.
|
||||||
|
module = MockModule('test')
|
||||||
|
apis = {'cmd': ('', 'cmd')}
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||||
|
module, 'test', public_apis=apis, deprecation=False)
|
||||||
|
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||||
|
# At first cmd key does not exist in __dict__
|
||||||
|
self.assertNotIn('cmd', wrapped_module.__dict__)
|
||||||
|
# After it is referred (lazyloaded), it gets added to __dict__
|
||||||
|
wrapped_module.cmd # pylint: disable=pointless-statement
|
||||||
|
self.assertEqual(wrapped_module.__dict__['cmd'], _cmd)
|
||||||
|
# When we call setattr, it also gets added to __dict__
|
||||||
|
setattr(wrapped_module, 'cmd2', _cmd)
|
||||||
|
self.assertEqual(wrapped_module.__dict__['cmd2'], _cmd)
|
||||||
|
|
||||||
|
def testLazyLoadWildcardImport(self):
|
||||||
|
# Test that public APIs are in __all__.
|
||||||
|
module = MockModule('test')
|
||||||
|
module._should_not_be_public = 5
|
||||||
|
apis = {'cmd': ('', 'cmd')}
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||||
|
module, 'test', public_apis=apis, deprecation=False)
|
||||||
|
setattr(wrapped_module, 'hello', 1)
|
||||||
|
self.assertIn('hello', wrapped_module.__all__)
|
||||||
|
self.assertIn('cmd', wrapped_module.__all__)
|
||||||
|
self.assertNotIn('_should_not_be_public', wrapped_module.__all__)
|
||||||
|
|
||||||
|
def testLazyLoadCorrectLiteModule(self):
|
||||||
|
# If set, always load lite module from public API list.
|
||||||
|
module = MockModule('test')
|
||||||
|
apis = {'lite': ('', 'cmd')}
|
||||||
|
module.lite = 5
|
||||||
|
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||||
|
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||||
|
module, 'test', public_apis=apis, deprecation=False, has_lite=True)
|
||||||
|
self.assertEqual(wrapped_module.lite, _cmd)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -23,9 +23,9 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import deprecation_wrapper
|
from tensorflow.python.util import module_wrapper
|
||||||
|
|
||||||
deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
class DeprecationTest(test.TestCase):
|
class DeprecationTest(test.TestCase):
|
||||||
@ -38,9 +38,8 @@ class DeprecationTest(test.TestCase):
|
|||||||
|
|
||||||
tf.tables_initializer()
|
tf.tables_initializer()
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||||
mock_warning.call_args[0][1],
|
"module_wrapper.py:")
|
||||||
"deprecation_test.py:")
|
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
mock_warning.call_args[0][2], r"tables_initializer")
|
mock_warning.call_args[0][2], r"tables_initializer")
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
@ -60,9 +59,8 @@ class DeprecationTest(test.TestCase):
|
|||||||
|
|
||||||
tf.ragged.RaggedTensorValue(value, row_splits)
|
tf.ragged.RaggedTensorValue(value, row_splits)
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||||
mock_warning.call_args[0][1],
|
"module_wrapper.py:")
|
||||||
"deprecation_test.py:")
|
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
mock_warning.call_args[0][2], r"ragged.RaggedTensorValue")
|
mock_warning.call_args[0][2], r"ragged.RaggedTensorValue")
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
@ -84,9 +82,8 @@ class DeprecationTest(test.TestCase):
|
|||||||
|
|
||||||
tf.sparse_mask(array, mask_indices)
|
tf.sparse_mask(array, mask_indices)
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||||
mock_warning.call_args[0][1],
|
"module_wrapper.py:")
|
||||||
"deprecation_test.py:")
|
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
mock_warning.call_args[0][2], r"sparse_mask")
|
mock_warning.call_args[0][2], r"sparse_mask")
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
@ -103,9 +100,8 @@ class DeprecationTest(test.TestCase):
|
|||||||
|
|
||||||
tf.VarLenFeature(tf.dtypes.int32)
|
tf.VarLenFeature(tf.dtypes.int32)
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||||
mock_warning.call_args[0][1],
|
"module_wrapper.py:")
|
||||||
"deprecation_test.py:")
|
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
mock_warning.call_args[0][2], r"VarLenFeature")
|
mock_warning.call_args[0][2], r"VarLenFeature")
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
@ -122,9 +118,8 @@ class DeprecationTest(test.TestCase):
|
|||||||
|
|
||||||
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pylint: disable=pointless-statement
|
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pylint: disable=pointless-statement
|
||||||
self.assertEqual(1, mock_warning.call_count)
|
self.assertEqual(1, mock_warning.call_count)
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||||
mock_warning.call_args[0][1],
|
"module_wrapper.py:")
|
||||||
"deprecation_test.py:")
|
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
mock_warning.call_args[0][2],
|
mock_warning.call_args[0][2],
|
||||||
r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY")
|
r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY")
|
||||||
|
@ -38,6 +38,11 @@ class ModuleTest(test.TestCase):
|
|||||||
|
|
||||||
def testDict(self):
|
def testDict(self):
|
||||||
# Check that a few modules are in __dict__.
|
# Check that a few modules are in __dict__.
|
||||||
|
# pylint: disable=pointless-statement
|
||||||
|
tf.nn
|
||||||
|
tf.keras
|
||||||
|
tf.image
|
||||||
|
# pylint: enable=pointless-statement
|
||||||
self.assertIn('nn', tf.__dict__)
|
self.assertIn('nn', tf.__dict__)
|
||||||
self.assertIn('keras', tf.__dict__)
|
self.assertIn('keras', tf.__dict__)
|
||||||
self.assertIn('image', tf.__dict__)
|
self.assertIn('image', tf.__dict__)
|
||||||
|
@ -178,15 +178,11 @@ function prepare_src() {
|
|||||||
#
|
#
|
||||||
# import tensorflow as tf
|
# import tensorflow as tf
|
||||||
#
|
#
|
||||||
# which is not ok. We are removing the deprecation stuff by using sed and
|
# which is not ok. We disable deprecation by using sed to toggle the flag
|
||||||
# deleting the pattern that the wrapper uses (all lines between a line ending
|
|
||||||
# with _deprecation_wrapper -- the import line -- and a line containing
|
|
||||||
# _sys.modules[__name__] as the argument of a function -- the last line in
|
|
||||||
# the deprecation autogenerated pattern)
|
|
||||||
# TODO(mihaimaruseac): When we move the API to root, remove this hack
|
# TODO(mihaimaruseac): When we move the API to root, remove this hack
|
||||||
# Note: Can't do in place sed that works on all OS, so use a temp file instead
|
# Note: Can't do in place sed that works on all OS, so use a temp file instead
|
||||||
sed \
|
sed \
|
||||||
"/_deprecation_wrapper$/,/_sys.modules[__name__],/ d" \
|
"s/deprecation=True/deprecation=False/g" \
|
||||||
"${TMPDIR}/tensorflow_core/__init__.py" > "${TMPDIR}/tensorflow_core/__init__.out"
|
"${TMPDIR}/tensorflow_core/__init__.py" > "${TMPDIR}/tensorflow_core/__init__.out"
|
||||||
mv "${TMPDIR}/tensorflow_core/__init__.out" "${TMPDIR}/tensorflow_core/__init__.py"
|
mv "${TMPDIR}/tensorflow_core/__init__.out" "${TMPDIR}/tensorflow_core/__init__.py"
|
||||||
}
|
}
|
||||||
|
@ -98,28 +98,6 @@ for _m in _top_level_modules:
|
|||||||
# We still need all the names that are toplevel on tensorflow_core
|
# We still need all the names that are toplevel on tensorflow_core
|
||||||
from tensorflow_core import *
|
from tensorflow_core import *
|
||||||
|
|
||||||
# We also need to bring in keras if available in tensorflow_core
|
|
||||||
# Above import * doesn't import it as __all__ is updated before keras is hooked
|
|
||||||
try:
|
|
||||||
from tensorflow_core import keras
|
|
||||||
except ImportError as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Similarly for estimator, but only if this file is not read via a
|
|
||||||
# import tensorflow_estimator (same reasoning as above when forwarding estimator
|
|
||||||
# separatedly from the rest of the top level modules)
|
|
||||||
if not _root_estimator:
|
|
||||||
try:
|
|
||||||
from tensorflow_core import estimator
|
|
||||||
except ImportError as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# And again for tensorboard (comes as summary)
|
|
||||||
try:
|
|
||||||
from tensorflow_core import summary
|
|
||||||
except ImportError as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# In V1 API we need to print deprecation messages
|
# In V1 API we need to print deprecation messages
|
||||||
from tensorflow.python.util import deprecation_wrapper as _deprecation
|
from tensorflow.python.util import deprecation_wrapper as _deprecation
|
||||||
if not isinstance(_sys.modules[__name__], _deprecation.DeprecationWrapper):
|
if not isinstance(_sys.modules[__name__], _deprecation.DeprecationWrapper):
|
||||||
|
@ -97,32 +97,4 @@ for _m in _top_level_modules:
|
|||||||
# We still need all the names that are toplevel on tensorflow_core
|
# We still need all the names that are toplevel on tensorflow_core
|
||||||
from tensorflow_core import *
|
from tensorflow_core import *
|
||||||
|
|
||||||
# We also need to bring in keras if available in tensorflow_core
|
|
||||||
# Above import * doesn't import it as __all__ is updated before keras is hooked
|
|
||||||
try:
|
|
||||||
from tensorflow_core import keras
|
|
||||||
except ImportError as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Similarly for estimator, but only if this file is not read via a
|
|
||||||
# import tensorflow_estimator (same reasoning as above when forwarding estimator
|
|
||||||
# separatedly from the rest of the top level modules)
|
|
||||||
if not _root_estimator:
|
|
||||||
try:
|
|
||||||
from tensorflow_core import estimator
|
|
||||||
except ImportError as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# And again for tensorboard (comes as summary)
|
|
||||||
try:
|
|
||||||
from tensorflow_core import summary
|
|
||||||
except ImportError as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Also import module aliases
|
|
||||||
try:
|
|
||||||
from tensorflow_core import losses, metrics, initializers, optimizers
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)
|
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)
|
||||||
|
Loading…
Reference in New Issue
Block a user