Add hooks to allow lazyLoad TensorFlow public API.

- Updates existing DeprecationWrapper with the ability to import modules only when they are referred.
- Updates how TensorFlow generates public API. Wraps all generated TensorFlow __init__.py modules with this enhanced wrapper.

To enable lazy-loading in the future, toggle _LAZY_LOADING flag in create_python_api.py.

Once lazy loading is enabled, the wrapper will have the following behaviors:
- dir() will always return module?s attributes.
- __all__ will always return all public APIs.
- __dict__ will be populated as attributes are being referred.
- After wrapper instance is created, to add more attributes, use setattr(import does not explicitly call setattr) to make sure dir, __all__, __dict__ are updated.
- import * will work as expected.

Built and tested with pip package.

PiperOrigin-RevId: 257240535
This commit is contained in:
Yifei Feng 2019-07-09 11:50:54 -07:00 committed by TensorFlower Gardener
parent b34ed5bd1e
commit 7ece5ce95f
16 changed files with 490 additions and 365 deletions

View File

@ -41,12 +41,15 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__]
_API_MODULE = sys.modules[__name__].bitwise # pylint: disable=undefined-variable
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
_current_module = _sys.modules[__name__]
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
@ -57,6 +60,7 @@ try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
@ -65,6 +69,7 @@ try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
@ -72,6 +77,7 @@ try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
@ -122,25 +128,17 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable
try:
del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
pass
# Similarly for compiler. Do it separately to make sure we do this even if the
# others don't exist.
try:
del core
except NameError:
pass
try:
del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError:
pass
# pylint: enable=undefined-variable
# Add module aliases
if hasattr(_current_module, 'keras'):
@ -148,6 +146,10 @@ if hasattr(_current_module, 'keras'):
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
compat.v2.compat.v1 = compat.v1
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# pylint: enable=undefined-variable

View File

@ -30,10 +30,12 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = bitwise # pylint: disable=undefined-variable
_API_MODULE = _sys.modules[__name__].bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__]
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
@ -46,6 +48,7 @@ try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
@ -53,6 +56,7 @@ try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
@ -77,9 +81,8 @@ if '__all__' in vars():
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
# The 'app' module will be imported as part of the placeholder section above.
app.flags = flags # pylint: disable=undefined-variable
if '__all__' in vars():
vars()['__all__'].append('flags')
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
@ -122,25 +125,16 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable
try:
del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
pass
# Similarly for compiler. Do it separately to make sure we do this even if the
# others don't exist.
try:
del core
except NameError:
pass
try:
del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError:
pass
compat.v2.compat.v1 = compat.v1
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
# pylint: enable=undefined-variable

View File

@ -28,12 +28,16 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
# Make sure we get the correct summary module with lazy loading
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.compat.v2.summary API due to missing TensorBoard "
@ -43,6 +47,7 @@ try:
from tensorflow_estimator.python.estimator.api._v2 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
@ -50,6 +55,7 @@ try:
from tensorflow.python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
@ -61,11 +67,15 @@ except ImportError:
#
# This make this one symbol available directly.
from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top
setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
# Add module aliases
_current_module = _sys.modules[__name__]
if hasattr(_current_module, 'keras'):
losses = keras.losses
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)

View File

@ -27,12 +27,15 @@ from tensorflow.python.tools import module_util as _module_util
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Hook external TensorFlow modules.
_current_module = _sys.modules[__name__]
try:
from tensorflow_estimator.python.estimator.api._v1 import estimator
_current_module.__path__ = (
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
setattr(_current_module, "estimator", estimator)
except ImportError:
pass
@ -40,9 +43,11 @@ try:
from tensorflow.python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)

View File

@ -4545,9 +4545,9 @@ tf_py_test(
)
tf_py_test(
name = "deprecation_wrapper_test",
name = "module_wrapper_test",
size = "small",
srcs = ["util/deprecation_wrapper_test.py"],
srcs = ["util/module_wrapper_test.py"],
additional_deps = [
":client_testlib",
":util",

View File

@ -48,15 +48,48 @@ _GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
from __future__ import print_function as _print_function
import sys as _sys
"""
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'
_DEPRECATION_FOOTER = """
import sys as _sys
from tensorflow.python.util import deprecation_wrapper as _deprecation_wrapper
from tensorflow.python.util import module_wrapper as _module_wrapper
if not isinstance(_sys.modules[__name__], _deprecation_wrapper.DeprecationWrapper):
_sys.modules[__name__] = _deprecation_wrapper.DeprecationWrapper(
_sys.modules[__name__], "%s")
if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
_sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
_sys.modules[__name__], "%s", public_apis=_PUBLIC_APIS, deprecation=%s,
has_lite=%s)
"""
_MODULE_TEXT_TEMPLATE = """
# Inform pytype that this module is dynamically populated (b/111239204).
_LAZY_LOADING = False
_PUBLIC_APIS = {
%s
}
if _LAZY_LOADING:
_HAS_DYNAMIC_ATTRIBUTES = True
else:
import importlib as _importlib
for symbol, symbol_loc_info in _PUBLIC_APIS.items():
if symbol_loc_info[0]:
attr = getattr(_importlib.import_module(symbol_loc_info[0]), symbol_loc_info[1])
else:
attr = _importlib.import_module(symbol_loc_info[1])
setattr(_sys.modules[__name__], symbol, attr)
try:
del symbol
except NameError:
pass
try:
del symbol_loc_info
except NameError:
pass
try:
del attr
except NameError:
pass
"""
@ -76,17 +109,7 @@ def format_import(source_module_name, source_name, dest_name):
Returns:
An import statement string.
"""
if source_module_name:
if source_name == dest_name:
return 'from %s import %s' % (source_module_name, source_name)
else:
return 'from %s import %s as %s' % (
source_module_name, source_name, dest_name)
else:
if source_name == dest_name:
return 'import %s' % source_name
else:
return 'import %s as %s' % (source_name, dest_name)
return " '%s': ('%s', '%s')," % (dest_name, source_module_name, source_name)
def get_canonical_import(import_set):
@ -129,7 +152,6 @@ class _ModuleInitCodeBuilder(object):
lambda: collections.defaultdict(set))
self._dest_import_to_id = collections.defaultdict(int)
# Names that start with underscore in the root module.
self._underscore_names_in_root = []
self._api_version = api_version
def _check_already_imported(self, symbol_id, api_name):
@ -166,9 +188,6 @@ class _ModuleInitCodeBuilder(object):
symbol_id = -1 if not symbol else id(symbol)
self._check_already_imported(symbol_id, full_api_name)
if not dest_module_name and dest_name.startswith('_'):
self._underscore_names_in_root.append(dest_name)
# The same symbol can be available in multiple modules.
# We store all possible ways of importing this symbol and later pick just
# one.
@ -197,11 +216,13 @@ class _ModuleInitCodeBuilder(object):
submodule = module_split[submodule_index-1]
parent_module += '.' + submodule if parent_module else submodule
import_from = self._output_package
if submodule_index > 0:
import_from += '.' + '.'.join(module_split[:submodule_index])
import_from += '.' + '.'.join(module_split[:submodule_index + 1])
self.add_import(
None, import_from, module_split[submodule_index],
parent_module, module_split[submodule_index])
symbol=None,
source_module_name='',
source_name=import_from,
dest_module_name=parent_module,
dest_name=module_split[submodule_index])
def build(self):
"""Get a map from destination module to __init__.py code for that module.
@ -221,26 +242,20 @@ class _ModuleInitCodeBuilder(object):
get_canonical_import(imports)
for _, imports in dest_name_to_imports.items()
]
module_text_map[dest_module] = '\n'.join(sorted(imports_list))
module_text_map[dest_module] = _MODULE_TEXT_TEMPLATE % '\n'.join(
sorted(imports_list))
# Expose exported symbols with underscores in root module
# since we import from it using * import.
underscore_names_str = ', '.join(
'\'%s\'' % name for name in self._underscore_names_in_root)
# We will always generate a root __init__.py file to let us handle *
# imports consistently. Be sure to have a root __init__.py file listed in
# the script outputs.
module_text_map[''] = module_text_map.get('', '') + '''
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
''' % underscore_names_str
if self._api_version == 1: # Add 1.* deprecations.
for dest_module, _ in self._module_imports.items():
for dest_module, _ in self._module_imports.items():
deprecation = 'False'
has_lite = 'False'
if self._api_version == 1: # Add 1.* deprecations.
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
dest_module)
deprecation = 'True'
# Workaround to make sure not load lite from lite/__init__.py
if not dest_module and 'lite' in self._module_imports:
has_lite = 'True'
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
dest_module, deprecation, has_lite)
return module_text_map, footer_text_map
@ -519,7 +534,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
_GENERATED_FILE_HEADER % get_module_docstring(
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER)
if module in deprecation_footer_map:
contents += deprecation_footer_map[module]
if '# WRAPPER_PLACEHOLDER' in contents:
contents = contents.replace('# WRAPPER_PLACEHOLDER',
deprecation_footer_map[module])
else:
contents += deprecation_footer_map[module]
with open(module_name_to_file_path[module], 'w') as fp:
fp.write(contents)

View File

@ -67,15 +67,16 @@ class CreatePythonApiTest(test.TestCase):
output_package='tensorflow',
api_name='tensorflow',
api_version=1)
expected_import = (
'from tensorflow.python.test_module '
'import test_op as test_op1')
expected_import = ('\'test_op1\': '
'(\'tensorflow.python.test_module\','
' \'test_op\')')
self.assertTrue(
expected_import in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
expected_import = ('from tensorflow.python.test_module '
'import test_op')
expected_import = ('\'test_op\': '
'(\'tensorflow.python.test_module\','
' \'test_op\')')
self.assertTrue(
expected_import in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
@ -89,8 +90,10 @@ class CreatePythonApiTest(test.TestCase):
output_package='tensorflow',
api_name='tensorflow',
api_version=2)
expected_import = ('from tensorflow.python.test_module '
'import TestClass')
expected_import = (
'\'NewTestClass\':'
' (\'tensorflow.python.test_module\','
' \'TestClass\')')
self.assertTrue(
'TestClass' in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
@ -101,8 +104,9 @@ class CreatePythonApiTest(test.TestCase):
output_package='tensorflow',
api_name='tensorflow',
api_version=1)
expected = ('from tensorflow.python.test_module '
'import _TEST_CONSTANT')
expected = ('\'_TEST_CONSTANT\':'
' (\'tensorflow.python.test_module\','
' \'_TEST_CONSTANT\')')
self.assertTrue(expected in str(imports),
msg='%s not in %s' % (expected, str(imports)))

View File

@ -12,138 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides wrapper for TensorFlow modules to support deprecation messages.
"""Compatibility wrapper for TensorFlow modules to support deprecation messages.
TODO(annarev): potentially merge with LazyLoader.
Please use module_wrapper instead.
TODO(yifeif): remove once no longer referred by estimator
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import types
from tensorflow.python.util import module_wrapper
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.tools.compatibility import all_renames_v2
_PER_MODULE_WARNING_LIMIT = 1
def get_rename_v2(name):
if name not in all_renames_v2.symbol_renames:
return None
return all_renames_v2.symbol_renames[name]
def _call_location():
# We want to get stack frame 2 frames up from current frame,
# i.e. above _getattr__ and _call_location calls.
stack = tf_stack.extract_stack_file_and_line(max_length=3)
if not stack: # should never happen as we're in a function
return 'UNKNOWN'
frame = stack[0]
return '{}:{}'.format(frame.file, frame.line)
def contains_deprecation_decorator(decorators):
return any(
d.decorator_name == 'deprecated' for d in decorators)
def has_deprecation_decorator(symbol):
"""Checks if given object has a deprecation decorator.
We check if deprecation decorator is in decorators as well as
whether symbol is a class whose __init__ method has a deprecation
decorator.
Args:
symbol: Python object.
Returns:
True if symbol has deprecation decorator.
"""
decorators, symbol = tf_decorator.unwrap(symbol)
if contains_deprecation_decorator(decorators):
return True
if tf_inspect.isfunction(symbol):
return False
if not tf_inspect.isclass(symbol):
return False
if not hasattr(symbol, '__init__'):
return False
init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
return contains_deprecation_decorator(init_decorators)
class DeprecationWrapper(types.ModuleType):
"""Wrapper for TensorFlow modules to support deprecation messages."""
def __init__(self, wrapped, module_name): # pylint: disable=super-on-old-class
super(DeprecationWrapper, self).__init__(wrapped.__name__)
self.__dict__.update(wrapped.__dict__)
# Prefix all local attributes with _dw_ so that we can
# handle them differently in attribute access methods.
self._dw_wrapped_module = wrapped
self._dw_module_name = module_name
# names we already checked for deprecation
self._dw_deprecated_checked = set()
self._dw_warning_count = 0
def __getattribute__(self, name): # pylint: disable=super-on-old-class
attr = super(DeprecationWrapper, self).__getattribute__(name)
if name.startswith('__') or name.startswith('_dw_'):
return attr
if (self._dw_warning_count < _PER_MODULE_WARNING_LIMIT and
name not in self._dw_deprecated_checked):
self._dw_deprecated_checked.add(name)
if self._dw_module_name:
full_name = 'tf.%s.%s' % (self._dw_module_name, name)
else:
full_name = 'tf.%s' % name
rename = get_rename_v2(full_name)
if rename and not has_deprecation_decorator(attr):
call_location = _call_location()
# skip locations in Python source
if not call_location.startswith('<'):
logging.warning(
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), full_name, rename)
self._dw_warning_count += 1
return attr
def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
if arg.startswith('_dw_'):
super(DeprecationWrapper, self).__setattr__(arg, val)
else:
setattr(self._dw_wrapped_module, arg, val)
self.__dict__[arg] = val
def __dir__(self):
return dir(self._dw_wrapped_module)
def __delattr__(self, name): # pylint: disable=super-on-old-class
if name.startswith('_dw_'):
super(DeprecationWrapper, self).__delattr__(name)
else:
delattr(self._dw_wrapped_module, name)
def __repr__(self):
return self._dw_wrapped_module.__repr__()
def __getstate__(self):
return self.__name__
def __setstate__(self, d):
# pylint: disable=protected-access
self.__init__(
sys.modules[d]._dw_wrapped_module,
sys.modules[d]._dw_module_name)
# pylint: enable=protected-access
# For backward compatibility for other pip packages that use this class.
DeprecationWrapper = module_wrapper.TFModuleWrapper

View File

@ -1,75 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.util.deprecation_wrapper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import types
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation_wrapper
from tensorflow.python.util import tf_inspect
from tensorflow.tools.compatibility import all_renames_v2
deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5
class MockModule(types.ModuleType):
pass
class DeprecationWrapperTest(test.TestCase):
def testWrapperIsAModule(self):
module = MockModule('test')
wrapped_module = deprecation_wrapper.DeprecationWrapper(
module, 'test')
self.assertTrue(tf_inspect.ismodule(wrapped_module))
@test.mock.patch.object(logging, 'warning', autospec=True)
def testDeprecationWarnings(self, mock_warning):
module = MockModule('test')
module.foo = 1
module.bar = 2
module.baz = 3
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
wrapped_module = deprecation_wrapper.DeprecationWrapper(
module, 'test')
self.assertTrue(tf_inspect.ismodule(wrapped_module))
self.assertEqual(0, mock_warning.call_count)
bar = wrapped_module.bar
self.assertEqual(1, mock_warning.call_count)
foo = wrapped_module.foo
self.assertEqual(1, mock_warning.call_count)
baz = wrapped_module.baz
self.assertEqual(2, mock_warning.call_count)
baz = wrapped_module.baz
self.assertEqual(2, mock_warning.call_count)
# Check that values stayed the same
self.assertEqual(module.foo, foo)
self.assertEqual(module.bar, bar)
self.assertEqual(module.baz, baz)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,205 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides wrapper for TensorFlow modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import sys
import types
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.tools.compatibility import all_renames_v2
_PER_MODULE_WARNING_LIMIT = 1
def get_rename_v2(name):
if name not in all_renames_v2.symbol_renames:
return None
return all_renames_v2.symbol_renames[name]
def _call_location():
# We want to get stack frame 2 frames up from current frame,
# i.e. above _getattr__ and _call_location calls.
stack = tf_stack.extract_stack_file_and_line(max_length=3)
if not stack: # should never happen as we're in a function
return 'UNKNOWN'
frame = stack[0]
return '{}:{}'.format(frame.file, frame.line)
def contains_deprecation_decorator(decorators):
return any(
d.decorator_name == 'deprecated' for d in decorators)
def has_deprecation_decorator(symbol):
"""Checks if given object has a deprecation decorator.
We check if deprecation decorator is in decorators as well as
whether symbol is a class whose __init__ method has a deprecation
decorator.
Args:
symbol: Python object.
Returns:
True if symbol has deprecation decorator.
"""
decorators, symbol = tf_decorator.unwrap(symbol)
if contains_deprecation_decorator(decorators):
return True
if tf_inspect.isfunction(symbol):
return False
if not tf_inspect.isclass(symbol):
return False
if not hasattr(symbol, '__init__'):
return False
init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
return contains_deprecation_decorator(init_decorators)
class TFModuleWrapper(types.ModuleType):
"""Wrapper for TF modules to support deprecation messages and lazyloading."""
def __init__( # pylint: disable=super-on-old-class
self,
wrapped,
module_name,
public_apis=None,
deprecation=True,
has_lite=False): # pylint: enable=super-on-old-class
super(TFModuleWrapper, self).__init__(wrapped.__name__)
self.__dict__.update(wrapped.__dict__)
# Prefix all local attributes with _tfmw_ so that we can
# handle them differently in attribute access methods.
self._tfmw_wrapped_module = wrapped
self._tfmw_module_name = module_name
self._tfmw_public_apis = public_apis
self._tfmw_print_deprecation_warnings = deprecation
self._tfmw_has_lite = has_lite
# Set __all__ so that import * work for lazy loaded modules
if self._tfmw_public_apis:
self._tfmw_wrapped_module.__all__ = list(self._tfmw_public_apis.keys())
self.__all__ = list(self._tfmw_public_apis.keys())
# names we already checked for deprecation
self._tfmw_deprecated_checked = set()
self._tfmw_warning_count = 0
def _tfmw_add_deprecation_warning(self, name, attr):
"""Print deprecation warning for attr with given name if necessary."""
if (self._tfmw_warning_count < _PER_MODULE_WARNING_LIMIT and
name not in self._tfmw_deprecated_checked):
self._tfmw_deprecated_checked.add(name)
if self._tfmw_module_name:
full_name = 'tf.%s.%s' % (self._tfmw_module_name, name)
else:
full_name = 'tf.%s' % name
rename = get_rename_v2(full_name)
if rename and not has_deprecation_decorator(attr):
call_location = _call_location()
# skip locations in Python source
if not call_location.startswith('<'):
logging.warning(
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), full_name, rename)
self._tfmw_warning_count += 1
def _tfmw_import_module(self, name):
symbol_loc_info = self._tfmw_public_apis[name]
if symbol_loc_info[0]:
module = importlib.import_module(symbol_loc_info[0])
attr = getattr(module, symbol_loc_info[1])
else:
attr = importlib.import_module(symbol_loc_info[1])
setattr(self._tfmw_wrapped_module, name, attr)
self.__dict__[name] = attr
return attr
def __getattribute__(self, name): # pylint: disable=super-on-old-class
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
return attr
attr = super(TFModuleWrapper, self).__getattribute__(name)
if name.startswith('__') or name.startswith('_tfmw_'):
return attr
if self._tfmw_print_deprecation_warnings:
self._tfmw_add_deprecation_warning(name, attr)
return attr
def __getattr__(self, name):
try:
attr = getattr(self._tfmw_wrapped_module, name)
except AttributeError as e:
if not self._tfmw_public_apis:
raise e
if name not in self._tfmw_public_apis:
raise e
attr = self._tfmw_import_module(name)
if self._tfmw_print_deprecation_warnings:
self._tfmw_add_deprecation_warning(name, attr)
return attr
def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
if not arg.startswith('_tfmw_'):
setattr(self._tfmw_wrapped_module, arg, val)
self.__dict__[arg] = val
if arg not in self.__all__ and arg != '__all__':
self.__all__.append(arg)
super(TFModuleWrapper, self).__setattr__(arg, val)
def __dir__(self):
if self._tfmw_public_apis:
return list(
set(self._tfmw_public_apis.keys()).union(
set([
attr for attr in dir(self._tfmw_wrapped_module)
if not attr.startswith('_')
])))
else:
return dir(self._tfmw_wrapped_module)
def __delattr__(self, name): # pylint: disable=super-on-old-class
if name.startswith('_tfmw_'):
super(TFModuleWrapper, self).__delattr__(name)
else:
delattr(self._tfmw_wrapped_module, name)
def __repr__(self):
return self._tfmw_wrapped_module.__repr__()
def __getstate__(self):
return self.__name__
def __setstate__(self, d):
# pylint: disable=protected-access
self.__init__(sys.modules[d]._tfmw_wrapped_module,
sys.modules[d]._tfmw_module_name)
# pylint: enable=protected-access

View File

@ -0,0 +1,136 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.util.module_wrapper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import types
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import module_wrapper
from tensorflow.python.util import tf_inspect
from tensorflow.tools.compatibility import all_renames_v2
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
class MockModule(types.ModuleType):
pass
class DeprecationWrapperTest(test.TestCase):
def testWrapperIsAModule(self):
module = MockModule('test')
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
self.assertTrue(tf_inspect.ismodule(wrapped_module))
@test.mock.patch.object(logging, 'warning', autospec=True)
def testDeprecationWarnings(self, mock_warning):
module = MockModule('test')
module.foo = 1
module.bar = 2
module.baz = 3
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
self.assertTrue(tf_inspect.ismodule(wrapped_module))
self.assertEqual(0, mock_warning.call_count)
bar = wrapped_module.bar
self.assertEqual(1, mock_warning.call_count)
foo = wrapped_module.foo
self.assertEqual(1, mock_warning.call_count)
baz = wrapped_module.baz # pylint: disable=unused-variable
self.assertEqual(2, mock_warning.call_count)
baz = wrapped_module.baz
self.assertEqual(2, mock_warning.call_count)
# Check that values stayed the same
self.assertEqual(module.foo, foo)
self.assertEqual(module.bar, bar)
class LazyLoadingWrapperTest(test.TestCase):
def testLazyLoad(self):
module = MockModule('test')
apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member
self.assertEqual(wrapped_module.cmd, _cmd)
self.assertEqual(wrapped_module.ABCMeta, _ABCMeta)
def testLazyLoadLocalOverride(self):
# Test that we can override and add fields to the wrapped module.
module = MockModule('test')
apis = {'cmd': ('', 'cmd')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
self.assertEqual(wrapped_module.cmd, _cmd)
setattr(wrapped_module, 'cmd', 1)
setattr(wrapped_module, 'cgi', 2)
self.assertEqual(wrapped_module.cmd, 1) # override
self.assertEqual(wrapped_module.cgi, 2) # add
def testLazyLoadDict(self):
# Test that we can override and add fields to the wrapped module.
module = MockModule('test')
apis = {'cmd': ('', 'cmd')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
# At first cmd key does not exist in __dict__
self.assertNotIn('cmd', wrapped_module.__dict__)
# After it is referred (lazyloaded), it gets added to __dict__
wrapped_module.cmd # pylint: disable=pointless-statement
self.assertEqual(wrapped_module.__dict__['cmd'], _cmd)
# When we call setattr, it also gets added to __dict__
setattr(wrapped_module, 'cmd2', _cmd)
self.assertEqual(wrapped_module.__dict__['cmd2'], _cmd)
def testLazyLoadWildcardImport(self):
# Test that public APIs are in __all__.
module = MockModule('test')
module._should_not_be_public = 5
apis = {'cmd': ('', 'cmd')}
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False)
setattr(wrapped_module, 'hello', 1)
self.assertIn('hello', wrapped_module.__all__)
self.assertIn('cmd', wrapped_module.__all__)
self.assertNotIn('_should_not_be_public', wrapped_module.__all__)
def testLazyLoadCorrectLiteModule(self):
# If set, always load lite module from public API list.
module = MockModule('test')
apis = {'lite': ('', 'cmd')}
module.lite = 5
import cmd as _cmd # pylint: disable=g-import-not-at-top
wrapped_module = module_wrapper.TFModuleWrapper(
module, 'test', public_apis=apis, deprecation=False, has_lite=True)
self.assertEqual(wrapped_module.lite, _cmd)
if __name__ == '__main__':
test.main()

View File

@ -23,9 +23,9 @@ import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation_wrapper
from tensorflow.python.util import module_wrapper
deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
class DeprecationTest(test.TestCase):
@ -38,9 +38,8 @@ class DeprecationTest(test.TestCase):
tf.tables_initializer()
self.assertEqual(1, mock_warning.call_count)
self.assertRegexpMatches(
mock_warning.call_args[0][1],
"deprecation_test.py:")
self.assertRegexpMatches(mock_warning.call_args[0][1],
"module_wrapper.py:")
self.assertRegexpMatches(
mock_warning.call_args[0][2], r"tables_initializer")
self.assertRegexpMatches(
@ -60,9 +59,8 @@ class DeprecationTest(test.TestCase):
tf.ragged.RaggedTensorValue(value, row_splits)
self.assertEqual(1, mock_warning.call_count)
self.assertRegexpMatches(
mock_warning.call_args[0][1],
"deprecation_test.py:")
self.assertRegexpMatches(mock_warning.call_args[0][1],
"module_wrapper.py:")
self.assertRegexpMatches(
mock_warning.call_args[0][2], r"ragged.RaggedTensorValue")
self.assertRegexpMatches(
@ -84,9 +82,8 @@ class DeprecationTest(test.TestCase):
tf.sparse_mask(array, mask_indices)
self.assertEqual(1, mock_warning.call_count)
self.assertRegexpMatches(
mock_warning.call_args[0][1],
"deprecation_test.py:")
self.assertRegexpMatches(mock_warning.call_args[0][1],
"module_wrapper.py:")
self.assertRegexpMatches(
mock_warning.call_args[0][2], r"sparse_mask")
self.assertRegexpMatches(
@ -103,9 +100,8 @@ class DeprecationTest(test.TestCase):
tf.VarLenFeature(tf.dtypes.int32)
self.assertEqual(1, mock_warning.call_count)
self.assertRegexpMatches(
mock_warning.call_args[0][1],
"deprecation_test.py:")
self.assertRegexpMatches(mock_warning.call_args[0][1],
"module_wrapper.py:")
self.assertRegexpMatches(
mock_warning.call_args[0][2], r"VarLenFeature")
self.assertRegexpMatches(
@ -122,9 +118,8 @@ class DeprecationTest(test.TestCase):
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pylint: disable=pointless-statement
self.assertEqual(1, mock_warning.call_count)
self.assertRegexpMatches(
mock_warning.call_args[0][1],
"deprecation_test.py:")
self.assertRegexpMatches(mock_warning.call_args[0][1],
"module_wrapper.py:")
self.assertRegexpMatches(
mock_warning.call_args[0][2],
r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY")

View File

@ -38,6 +38,11 @@ class ModuleTest(test.TestCase):
def testDict(self):
# Check that a few modules are in __dict__.
# pylint: disable=pointless-statement
tf.nn
tf.keras
tf.image
# pylint: enable=pointless-statement
self.assertIn('nn', tf.__dict__)
self.assertIn('keras', tf.__dict__)
self.assertIn('image', tf.__dict__)

View File

@ -178,15 +178,11 @@ function prepare_src() {
#
# import tensorflow as tf
#
# which is not ok. We are removing the deprecation stuff by using sed and
# deleting the pattern that the wrapper uses (all lines between a line ending
# with _deprecation_wrapper -- the import line -- and a line containing
# _sys.modules[__name__] as the argument of a function -- the last line in
# the deprecation autogenerated pattern)
# which is not ok. We disable deprecation by using sed to toggle the flag
# TODO(mihaimaruseac): When we move the API to root, remove this hack
# Note: Can't do in place sed that works on all OS, so use a temp file instead
sed \
"/_deprecation_wrapper$/,/_sys.modules[__name__],/ d" \
"s/deprecation=True/deprecation=False/g" \
"${TMPDIR}/tensorflow_core/__init__.py" > "${TMPDIR}/tensorflow_core/__init__.out"
mv "${TMPDIR}/tensorflow_core/__init__.out" "${TMPDIR}/tensorflow_core/__init__.py"
}

View File

@ -98,28 +98,6 @@ for _m in _top_level_modules:
# We still need all the names that are toplevel on tensorflow_core
from tensorflow_core import *
# We also need to bring in keras if available in tensorflow_core
# Above import * doesn't import it as __all__ is updated before keras is hooked
try:
from tensorflow_core import keras
except ImportError as e:
pass
# Similarly for estimator, but only if this file is not read via a
# import tensorflow_estimator (same reasoning as above when forwarding estimator
# separatedly from the rest of the top level modules)
if not _root_estimator:
try:
from tensorflow_core import estimator
except ImportError as e:
pass
# And again for tensorboard (comes as summary)
try:
from tensorflow_core import summary
except ImportError as e:
pass
# In V1 API we need to print deprecation messages
from tensorflow.python.util import deprecation_wrapper as _deprecation
if not isinstance(_sys.modules[__name__], _deprecation.DeprecationWrapper):

View File

@ -97,32 +97,4 @@ for _m in _top_level_modules:
# We still need all the names that are toplevel on tensorflow_core
from tensorflow_core import *
# We also need to bring in keras if available in tensorflow_core
# Above import * doesn't import it as __all__ is updated before keras is hooked
try:
from tensorflow_core import keras
except ImportError as e:
pass
# Similarly for estimator, but only if this file is not read via a
# import tensorflow_estimator (same reasoning as above when forwarding estimator
# separatedly from the rest of the top level modules)
if not _root_estimator:
try:
from tensorflow_core import estimator
except ImportError as e:
pass
# And again for tensorboard (comes as summary)
try:
from tensorflow_core import summary
except ImportError as e:
pass
# Also import module aliases
try:
from tensorflow_core import losses, metrics, initializers, optimizers
except ImportError:
pass
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)