diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index b75ec148ae8..8d5e43b672b 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -41,12 +41,15 @@ from tensorflow.python.tools import module_util as _module_util # API IMPORTS PLACEHOLDER +# WRAPPER_PLACEHOLDER + # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. -_API_MODULE = bitwise # pylint: disable=undefined-variable -_current_module = _sys.modules[__name__] +_API_MODULE = sys.modules[__name__].bitwise # pylint: disable=undefined-variable _tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +_current_module = _sys.modules[__name__] + if not hasattr(_current_module, '__path__'): __path__ = [_tf_api_dir] elif _tf_api_dir not in __path__: @@ -57,6 +60,7 @@ try: from tensorboard.summary._tf import summary _current_module.__path__ = ( [_module_util.get_parent_dir(summary)] + _current_module.__path__) + setattr(_current_module, "summary", summary) except ImportError: _logging.warning( "Limited tf.summary API due to missing TensorBoard installation.") @@ -65,6 +69,7 @@ try: from tensorflow_estimator.python.estimator.api._v2 import estimator _current_module.__path__ = ( [_module_util.get_parent_dir(estimator)] + _current_module.__path__) + setattr(_current_module, "estimator", estimator) except ImportError: pass @@ -72,6 +77,7 @@ try: from tensorflow.python.keras.api._v2 import keras _current_module.__path__ = ( [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) except ImportError: pass @@ -122,25 +128,17 @@ if _running_from_pip_package(): # pylint: disable=undefined-variable try: del python - if '__all__' in vars(): - vars()['__all__'].remove('python') - del core - if '__all__' in vars(): - vars()['__all__'].remove('core') except NameError: - # Don't fail if these modules are not available. - # For e.g. this file will be originally placed under tensorflow/_api/v1 which - # does not have 'python', 'core' directories. Then, it will be copied - # to tensorflow/ which does have these two directories. pass -# Similarly for compiler. Do it separately to make sure we do this even if the -# others don't exist. +try: + del core +except NameError: + pass try: del compiler - if '__all__' in vars(): - vars()['__all__'].remove('compiler') except NameError: pass +# pylint: enable=undefined-variable # Add module aliases if hasattr(_current_module, 'keras'): @@ -148,6 +146,10 @@ if hasattr(_current_module, 'keras'): metrics = keras.metrics optimizers = keras.optimizers initializers = keras.initializers + setattr(_current_module, "losses", losses) + setattr(_current_module, "metrics", metrics) + setattr(_current_module, "optimizers", optimizers) + setattr(_current_module, "initializers", initializers) -compat.v2.compat.v1 = compat.v1 +_current_module.compat.v2.compat.v1 = _current_module.compat.v1 # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 4fa92b07051..6d1c40a2428 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -30,10 +30,12 @@ from tensorflow.python.tools import module_util as _module_util # API IMPORTS PLACEHOLDER +# WRAPPER_PLACEHOLDER + # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. -_API_MODULE = bitwise # pylint: disable=undefined-variable +_API_MODULE = _sys.modules[__name__].bitwise # pylint: disable=undefined-variable _current_module = _sys.modules[__name__] _tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) if not hasattr(_current_module, '__path__'): @@ -46,6 +48,7 @@ try: from tensorflow_estimator.python.estimator.api._v1 import estimator _current_module.__path__ = ( [_module_util.get_parent_dir(estimator)] + _current_module.__path__) + setattr(_current_module, "estimator", estimator) except ImportError: pass @@ -53,6 +56,7 @@ try: from tensorflow.python.keras.api._v1 import keras _current_module.__path__ = ( [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) except ImportError: pass @@ -77,9 +81,8 @@ if '__all__' in vars(): from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top # The 'app' module will be imported as part of the placeholder section above. -app.flags = flags # pylint: disable=undefined-variable -if '__all__' in vars(): - vars()['__all__'].append('flags') +_current_module.app.flags = flags # pylint: disable=undefined-variable +setattr(_current_module, "flags", flags) # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. @@ -122,25 +125,16 @@ if _running_from_pip_package(): # pylint: disable=undefined-variable try: del python - if '__all__' in vars(): - vars()['__all__'].remove('python') - del core - if '__all__' in vars(): - vars()['__all__'].remove('core') except NameError: - # Don't fail if these modules are not available. - # For e.g. this file will be originally placed under tensorflow/_api/v1 which - # does not have 'python', 'core' directories. Then, it will be copied - # to tensorflow/ which does have these two directories. pass -# Similarly for compiler. Do it separately to make sure we do this even if the -# others don't exist. +try: + del core +except NameError: + pass try: del compiler - if '__all__' in vars(): - vars()['__all__'].remove('compiler') except NameError: pass -compat.v2.compat.v1 = compat.v1 +_current_module.compat.v2.compat.v1 = _current_module.compat.v1 # pylint: enable=undefined-variable diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index ad2443a0c32..b830af58832 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -28,12 +28,16 @@ from tensorflow.python.tools import module_util as _module_util # API IMPORTS PLACEHOLDER +# WRAPPER_PLACEHOLDER + # Hook external TensorFlow modules. _current_module = _sys.modules[__name__] try: from tensorboard.summary._tf import summary _current_module.__path__ = ( [_module_util.get_parent_dir(summary)] + _current_module.__path__) + # Make sure we get the correct summary module with lazy loading + setattr(_current_module, "summary", summary) except ImportError: _logging.warning( "Limited tf.compat.v2.summary API due to missing TensorBoard " @@ -43,6 +47,7 @@ try: from tensorflow_estimator.python.estimator.api._v2 import estimator _current_module.__path__ = ( [_module_util.get_parent_dir(estimator)] + _current_module.__path__) + setattr(_current_module, "estimator", estimator) except ImportError: pass @@ -50,6 +55,7 @@ try: from tensorflow.python.keras.api._v2 import keras _current_module.__path__ = ( [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) except ImportError: pass @@ -61,11 +67,15 @@ except ImportError: # # This make this one symbol available directly. from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top +setattr(_current_module, "enable_v2_behavior", enable_v2_behavior) # Add module aliases -_current_module = _sys.modules[__name__] if hasattr(_current_module, 'keras'): losses = keras.losses metrics = keras.metrics optimizers = keras.optimizers initializers = keras.initializers + setattr(_current_module, "losses", losses) + setattr(_current_module, "metrics", metrics) + setattr(_current_module, "optimizers", optimizers) + setattr(_current_module, "initializers", initializers) diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 23c722edef7..48374b766b7 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -27,12 +27,15 @@ from tensorflow.python.tools import module_util as _module_util # API IMPORTS PLACEHOLDER +# WRAPPER_PLACEHOLDER + # Hook external TensorFlow modules. _current_module = _sys.modules[__name__] try: from tensorflow_estimator.python.estimator.api._v1 import estimator _current_module.__path__ = ( [_module_util.get_parent_dir(estimator)] + _current_module.__path__) + setattr(_current_module, "estimator", estimator) except ImportError: pass @@ -40,9 +43,11 @@ try: from tensorflow.python.keras.api._v1 import keras _current_module.__path__ = ( [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) except ImportError: pass from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top -app.flags = flags # pylint: disable=undefined-variable +_current_module.app.flags = flags # pylint: disable=undefined-variable +setattr(_current_module, "flags", flags) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a2ee9e07458..915f3518873 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4545,9 +4545,9 @@ tf_py_test( ) tf_py_test( - name = "deprecation_wrapper_test", + name = "module_wrapper_test", size = "small", - srcs = ["util/deprecation_wrapper_test.py"], + srcs = ["util/module_wrapper_test.py"], additional_deps = [ ":client_testlib", ":util", diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 7dd3f97b79d..aeeec69cec8 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -48,15 +48,48 @@ _GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit. from __future__ import print_function as _print_function +import sys as _sys + """ _GENERATED_FILE_FOOTER = '\n\ndel _print_function\n' _DEPRECATION_FOOTER = """ -import sys as _sys -from tensorflow.python.util import deprecation_wrapper as _deprecation_wrapper +from tensorflow.python.util import module_wrapper as _module_wrapper -if not isinstance(_sys.modules[__name__], _deprecation_wrapper.DeprecationWrapper): - _sys.modules[__name__] = _deprecation_wrapper.DeprecationWrapper( - _sys.modules[__name__], "%s") +if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper): + _sys.modules[__name__] = _module_wrapper.TFModuleWrapper( + _sys.modules[__name__], "%s", public_apis=_PUBLIC_APIS, deprecation=%s, + has_lite=%s) +""" +_MODULE_TEXT_TEMPLATE = """ +# Inform pytype that this module is dynamically populated (b/111239204). +_LAZY_LOADING = False + +_PUBLIC_APIS = { +%s +} + +if _LAZY_LOADING: + _HAS_DYNAMIC_ATTRIBUTES = True +else: + import importlib as _importlib + for symbol, symbol_loc_info in _PUBLIC_APIS.items(): + if symbol_loc_info[0]: + attr = getattr(_importlib.import_module(symbol_loc_info[0]), symbol_loc_info[1]) + else: + attr = _importlib.import_module(symbol_loc_info[1]) + setattr(_sys.modules[__name__], symbol, attr) + try: + del symbol + except NameError: + pass + try: + del symbol_loc_info + except NameError: + pass + try: + del attr + except NameError: + pass """ @@ -76,17 +109,7 @@ def format_import(source_module_name, source_name, dest_name): Returns: An import statement string. """ - if source_module_name: - if source_name == dest_name: - return 'from %s import %s' % (source_module_name, source_name) - else: - return 'from %s import %s as %s' % ( - source_module_name, source_name, dest_name) - else: - if source_name == dest_name: - return 'import %s' % source_name - else: - return 'import %s as %s' % (source_name, dest_name) + return " '%s': ('%s', '%s')," % (dest_name, source_module_name, source_name) def get_canonical_import(import_set): @@ -129,7 +152,6 @@ class _ModuleInitCodeBuilder(object): lambda: collections.defaultdict(set)) self._dest_import_to_id = collections.defaultdict(int) # Names that start with underscore in the root module. - self._underscore_names_in_root = [] self._api_version = api_version def _check_already_imported(self, symbol_id, api_name): @@ -166,9 +188,6 @@ class _ModuleInitCodeBuilder(object): symbol_id = -1 if not symbol else id(symbol) self._check_already_imported(symbol_id, full_api_name) - if not dest_module_name and dest_name.startswith('_'): - self._underscore_names_in_root.append(dest_name) - # The same symbol can be available in multiple modules. # We store all possible ways of importing this symbol and later pick just # one. @@ -197,11 +216,13 @@ class _ModuleInitCodeBuilder(object): submodule = module_split[submodule_index-1] parent_module += '.' + submodule if parent_module else submodule import_from = self._output_package - if submodule_index > 0: - import_from += '.' + '.'.join(module_split[:submodule_index]) + import_from += '.' + '.'.join(module_split[:submodule_index + 1]) self.add_import( - None, import_from, module_split[submodule_index], - parent_module, module_split[submodule_index]) + symbol=None, + source_module_name='', + source_name=import_from, + dest_module_name=parent_module, + dest_name=module_split[submodule_index]) def build(self): """Get a map from destination module to __init__.py code for that module. @@ -221,26 +242,20 @@ class _ModuleInitCodeBuilder(object): get_canonical_import(imports) for _, imports in dest_name_to_imports.items() ] - module_text_map[dest_module] = '\n'.join(sorted(imports_list)) + module_text_map[dest_module] = _MODULE_TEXT_TEMPLATE % '\n'.join( + sorted(imports_list)) - # Expose exported symbols with underscores in root module - # since we import from it using * import. - underscore_names_str = ', '.join( - '\'%s\'' % name for name in self._underscore_names_in_root) - # We will always generate a root __init__.py file to let us handle * - # imports consistently. Be sure to have a root __init__.py file listed in - # the script outputs. - module_text_map[''] = module_text_map.get('', '') + ''' -_names_with_underscore = [%s] -__all__ = [_s for _s in dir() if not _s.startswith('_')] -__all__.extend([_s for _s in _names_with_underscore]) -''' % underscore_names_str - - if self._api_version == 1: # Add 1.* deprecations. - for dest_module, _ in self._module_imports.items(): + for dest_module, _ in self._module_imports.items(): + deprecation = 'False' + has_lite = 'False' + if self._api_version == 1: # Add 1.* deprecations. if not dest_module.startswith(_COMPAT_MODULE_PREFIX): - footer_text_map[dest_module] = _DEPRECATION_FOOTER % ( - dest_module) + deprecation = 'True' + # Workaround to make sure not load lite from lite/__init__.py + if not dest_module and 'lite' in self._module_imports: + has_lite = 'True' + footer_text_map[dest_module] = _DEPRECATION_FOOTER % ( + dest_module, deprecation, has_lite) return module_text_map, footer_text_map @@ -519,7 +534,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir, _GENERATED_FILE_HEADER % get_module_docstring( module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER) if module in deprecation_footer_map: - contents += deprecation_footer_map[module] + if '# WRAPPER_PLACEHOLDER' in contents: + contents = contents.replace('# WRAPPER_PLACEHOLDER', + deprecation_footer_map[module]) + else: + contents += deprecation_footer_map[module] with open(module_name_to_file_path[module], 'w') as fp: fp.write(contents) diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py index 6e0970ec80a..98afd9a241f 100644 --- a/tensorflow/python/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -67,15 +67,16 @@ class CreatePythonApiTest(test.TestCase): output_package='tensorflow', api_name='tensorflow', api_version=1) - expected_import = ( - 'from tensorflow.python.test_module ' - 'import test_op as test_op1') + expected_import = ('\'test_op1\': ' + '(\'tensorflow.python.test_module\',' + ' \'test_op\')') self.assertTrue( expected_import in str(imports), msg='%s not in %s' % (expected_import, str(imports))) - expected_import = ('from tensorflow.python.test_module ' - 'import test_op') + expected_import = ('\'test_op\': ' + '(\'tensorflow.python.test_module\',' + ' \'test_op\')') self.assertTrue( expected_import in str(imports), msg='%s not in %s' % (expected_import, str(imports))) @@ -89,8 +90,10 @@ class CreatePythonApiTest(test.TestCase): output_package='tensorflow', api_name='tensorflow', api_version=2) - expected_import = ('from tensorflow.python.test_module ' - 'import TestClass') + expected_import = ( + '\'NewTestClass\':' + ' (\'tensorflow.python.test_module\',' + ' \'TestClass\')') self.assertTrue( 'TestClass' in str(imports), msg='%s not in %s' % (expected_import, str(imports))) @@ -101,8 +104,9 @@ class CreatePythonApiTest(test.TestCase): output_package='tensorflow', api_name='tensorflow', api_version=1) - expected = ('from tensorflow.python.test_module ' - 'import _TEST_CONSTANT') + expected = ('\'_TEST_CONSTANT\':' + ' (\'tensorflow.python.test_module\',' + ' \'_TEST_CONSTANT\')') self.assertTrue(expected in str(imports), msg='%s not in %s' % (expected, str(imports))) diff --git a/tensorflow/python/util/deprecation_wrapper.py b/tensorflow/python/util/deprecation_wrapper.py index 0bdaf1631da..2e0eee5ea32 100644 --- a/tensorflow/python/util/deprecation_wrapper.py +++ b/tensorflow/python/util/deprecation_wrapper.py @@ -12,138 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Provides wrapper for TensorFlow modules to support deprecation messages. +"""Compatibility wrapper for TensorFlow modules to support deprecation messages. -TODO(annarev): potentially merge with LazyLoader. +Please use module_wrapper instead. +TODO(yifeif): remove once no longer referred by estimator """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys -import types +from tensorflow.python.util import module_wrapper -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_inspect -from tensorflow.python.util import tf_stack -from tensorflow.tools.compatibility import all_renames_v2 - - -_PER_MODULE_WARNING_LIMIT = 1 - - -def get_rename_v2(name): - if name not in all_renames_v2.symbol_renames: - return None - return all_renames_v2.symbol_renames[name] - - -def _call_location(): - # We want to get stack frame 2 frames up from current frame, - # i.e. above _getattr__ and _call_location calls. - stack = tf_stack.extract_stack_file_and_line(max_length=3) - if not stack: # should never happen as we're in a function - return 'UNKNOWN' - frame = stack[0] - return '{}:{}'.format(frame.file, frame.line) - - -def contains_deprecation_decorator(decorators): - return any( - d.decorator_name == 'deprecated' for d in decorators) - - -def has_deprecation_decorator(symbol): - """Checks if given object has a deprecation decorator. - - We check if deprecation decorator is in decorators as well as - whether symbol is a class whose __init__ method has a deprecation - decorator. - Args: - symbol: Python object. - - Returns: - True if symbol has deprecation decorator. - """ - decorators, symbol = tf_decorator.unwrap(symbol) - if contains_deprecation_decorator(decorators): - return True - if tf_inspect.isfunction(symbol): - return False - if not tf_inspect.isclass(symbol): - return False - if not hasattr(symbol, '__init__'): - return False - init_decorators, _ = tf_decorator.unwrap(symbol.__init__) - return contains_deprecation_decorator(init_decorators) - - -class DeprecationWrapper(types.ModuleType): - """Wrapper for TensorFlow modules to support deprecation messages.""" - - def __init__(self, wrapped, module_name): # pylint: disable=super-on-old-class - super(DeprecationWrapper, self).__init__(wrapped.__name__) - self.__dict__.update(wrapped.__dict__) - # Prefix all local attributes with _dw_ so that we can - # handle them differently in attribute access methods. - self._dw_wrapped_module = wrapped - self._dw_module_name = module_name - # names we already checked for deprecation - self._dw_deprecated_checked = set() - self._dw_warning_count = 0 - - def __getattribute__(self, name): # pylint: disable=super-on-old-class - attr = super(DeprecationWrapper, self).__getattribute__(name) - if name.startswith('__') or name.startswith('_dw_'): - return attr - - if (self._dw_warning_count < _PER_MODULE_WARNING_LIMIT and - name not in self._dw_deprecated_checked): - - self._dw_deprecated_checked.add(name) - - if self._dw_module_name: - full_name = 'tf.%s.%s' % (self._dw_module_name, name) - else: - full_name = 'tf.%s' % name - rename = get_rename_v2(full_name) - if rename and not has_deprecation_decorator(attr): - call_location = _call_location() - # skip locations in Python source - if not call_location.startswith('<'): - logging.warning( - 'From %s: The name %s is deprecated. Please use %s instead.\n', - _call_location(), full_name, rename) - self._dw_warning_count += 1 - return attr - - def __setattr__(self, arg, val): # pylint: disable=super-on-old-class - if arg.startswith('_dw_'): - super(DeprecationWrapper, self).__setattr__(arg, val) - else: - setattr(self._dw_wrapped_module, arg, val) - self.__dict__[arg] = val - - def __dir__(self): - return dir(self._dw_wrapped_module) - - def __delattr__(self, name): # pylint: disable=super-on-old-class - if name.startswith('_dw_'): - super(DeprecationWrapper, self).__delattr__(name) - else: - delattr(self._dw_wrapped_module, name) - - def __repr__(self): - return self._dw_wrapped_module.__repr__() - - def __getstate__(self): - return self.__name__ - - def __setstate__(self, d): - # pylint: disable=protected-access - self.__init__( - sys.modules[d]._dw_wrapped_module, - sys.modules[d]._dw_module_name) - # pylint: enable=protected-access +# For backward compatibility for other pip packages that use this class. +DeprecationWrapper = module_wrapper.TFModuleWrapper diff --git a/tensorflow/python/util/deprecation_wrapper_test.py b/tensorflow/python/util/deprecation_wrapper_test.py deleted file mode 100644 index 84ff22c5937..00000000000 --- a/tensorflow/python/util/deprecation_wrapper_test.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow.python.util.deprecation_wrapper.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import logging -import types - -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import deprecation_wrapper -from tensorflow.python.util import tf_inspect -from tensorflow.tools.compatibility import all_renames_v2 - -deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5 - - -class MockModule(types.ModuleType): - pass - - -class DeprecationWrapperTest(test.TestCase): - - def testWrapperIsAModule(self): - module = MockModule('test') - wrapped_module = deprecation_wrapper.DeprecationWrapper( - module, 'test') - self.assertTrue(tf_inspect.ismodule(wrapped_module)) - - @test.mock.patch.object(logging, 'warning', autospec=True) - def testDeprecationWarnings(self, mock_warning): - module = MockModule('test') - module.foo = 1 - module.bar = 2 - module.baz = 3 - all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2' - all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz' - - wrapped_module = deprecation_wrapper.DeprecationWrapper( - module, 'test') - self.assertTrue(tf_inspect.ismodule(wrapped_module)) - - self.assertEqual(0, mock_warning.call_count) - bar = wrapped_module.bar - self.assertEqual(1, mock_warning.call_count) - foo = wrapped_module.foo - self.assertEqual(1, mock_warning.call_count) - baz = wrapped_module.baz - self.assertEqual(2, mock_warning.call_count) - baz = wrapped_module.baz - self.assertEqual(2, mock_warning.call_count) - - # Check that values stayed the same - self.assertEqual(module.foo, foo) - self.assertEqual(module.bar, bar) - self.assertEqual(module.baz, baz) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/util/module_wrapper.py b/tensorflow/python/util/module_wrapper.py new file mode 100644 index 00000000000..aa232d58495 --- /dev/null +++ b/tensorflow/python/util/module_wrapper.py @@ -0,0 +1,205 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Provides wrapper for TensorFlow modules.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys +import types + +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect +from tensorflow.python.util import tf_stack +from tensorflow.tools.compatibility import all_renames_v2 + + +_PER_MODULE_WARNING_LIMIT = 1 + + +def get_rename_v2(name): + if name not in all_renames_v2.symbol_renames: + return None + return all_renames_v2.symbol_renames[name] + + +def _call_location(): + # We want to get stack frame 2 frames up from current frame, + # i.e. above _getattr__ and _call_location calls. + stack = tf_stack.extract_stack_file_and_line(max_length=3) + if not stack: # should never happen as we're in a function + return 'UNKNOWN' + frame = stack[0] + return '{}:{}'.format(frame.file, frame.line) + + +def contains_deprecation_decorator(decorators): + return any( + d.decorator_name == 'deprecated' for d in decorators) + + +def has_deprecation_decorator(symbol): + """Checks if given object has a deprecation decorator. + + We check if deprecation decorator is in decorators as well as + whether symbol is a class whose __init__ method has a deprecation + decorator. + Args: + symbol: Python object. + + Returns: + True if symbol has deprecation decorator. + """ + decorators, symbol = tf_decorator.unwrap(symbol) + if contains_deprecation_decorator(decorators): + return True + if tf_inspect.isfunction(symbol): + return False + if not tf_inspect.isclass(symbol): + return False + if not hasattr(symbol, '__init__'): + return False + init_decorators, _ = tf_decorator.unwrap(symbol.__init__) + return contains_deprecation_decorator(init_decorators) + + +class TFModuleWrapper(types.ModuleType): + """Wrapper for TF modules to support deprecation messages and lazyloading.""" + + def __init__( # pylint: disable=super-on-old-class + self, + wrapped, + module_name, + public_apis=None, + deprecation=True, + has_lite=False): # pylint: enable=super-on-old-class + super(TFModuleWrapper, self).__init__(wrapped.__name__) + self.__dict__.update(wrapped.__dict__) + # Prefix all local attributes with _tfmw_ so that we can + # handle them differently in attribute access methods. + self._tfmw_wrapped_module = wrapped + self._tfmw_module_name = module_name + self._tfmw_public_apis = public_apis + self._tfmw_print_deprecation_warnings = deprecation + self._tfmw_has_lite = has_lite + # Set __all__ so that import * work for lazy loaded modules + if self._tfmw_public_apis: + self._tfmw_wrapped_module.__all__ = list(self._tfmw_public_apis.keys()) + self.__all__ = list(self._tfmw_public_apis.keys()) + # names we already checked for deprecation + self._tfmw_deprecated_checked = set() + self._tfmw_warning_count = 0 + + def _tfmw_add_deprecation_warning(self, name, attr): + """Print deprecation warning for attr with given name if necessary.""" + if (self._tfmw_warning_count < _PER_MODULE_WARNING_LIMIT and + name not in self._tfmw_deprecated_checked): + + self._tfmw_deprecated_checked.add(name) + + if self._tfmw_module_name: + full_name = 'tf.%s.%s' % (self._tfmw_module_name, name) + else: + full_name = 'tf.%s' % name + rename = get_rename_v2(full_name) + if rename and not has_deprecation_decorator(attr): + call_location = _call_location() + # skip locations in Python source + if not call_location.startswith('<'): + logging.warning( + 'From %s: The name %s is deprecated. Please use %s instead.\n', + _call_location(), full_name, rename) + self._tfmw_warning_count += 1 + + def _tfmw_import_module(self, name): + symbol_loc_info = self._tfmw_public_apis[name] + if symbol_loc_info[0]: + module = importlib.import_module(symbol_loc_info[0]) + attr = getattr(module, symbol_loc_info[1]) + else: + attr = importlib.import_module(symbol_loc_info[1]) + setattr(self._tfmw_wrapped_module, name, attr) + self.__dict__[name] = attr + return attr + + def __getattribute__(self, name): # pylint: disable=super-on-old-class + # Workaround to make sure we do not import from tensorflow/lite/__init__.py + if name == 'lite': + if self._tfmw_has_lite: + attr = self._tfmw_import_module(name) + setattr(self._tfmw_wrapped_module, 'lite', attr) + return attr + + attr = super(TFModuleWrapper, self).__getattribute__(name) + if name.startswith('__') or name.startswith('_tfmw_'): + return attr + + if self._tfmw_print_deprecation_warnings: + self._tfmw_add_deprecation_warning(name, attr) + return attr + + def __getattr__(self, name): + try: + attr = getattr(self._tfmw_wrapped_module, name) + except AttributeError as e: + if not self._tfmw_public_apis: + raise e + if name not in self._tfmw_public_apis: + raise e + attr = self._tfmw_import_module(name) + + if self._tfmw_print_deprecation_warnings: + self._tfmw_add_deprecation_warning(name, attr) + return attr + + def __setattr__(self, arg, val): # pylint: disable=super-on-old-class + if not arg.startswith('_tfmw_'): + setattr(self._tfmw_wrapped_module, arg, val) + self.__dict__[arg] = val + if arg not in self.__all__ and arg != '__all__': + self.__all__.append(arg) + super(TFModuleWrapper, self).__setattr__(arg, val) + + def __dir__(self): + if self._tfmw_public_apis: + return list( + set(self._tfmw_public_apis.keys()).union( + set([ + attr for attr in dir(self._tfmw_wrapped_module) + if not attr.startswith('_') + ]))) + else: + return dir(self._tfmw_wrapped_module) + + def __delattr__(self, name): # pylint: disable=super-on-old-class + if name.startswith('_tfmw_'): + super(TFModuleWrapper, self).__delattr__(name) + else: + delattr(self._tfmw_wrapped_module, name) + + def __repr__(self): + return self._tfmw_wrapped_module.__repr__() + + def __getstate__(self): + return self.__name__ + + def __setstate__(self, d): + # pylint: disable=protected-access + self.__init__(sys.modules[d]._tfmw_wrapped_module, + sys.modules[d]._tfmw_module_name) + # pylint: enable=protected-access diff --git a/tensorflow/python/util/module_wrapper_test.py b/tensorflow/python/util/module_wrapper_test.py new file mode 100644 index 00000000000..582e98abdfa --- /dev/null +++ b/tensorflow/python/util/module_wrapper_test.py @@ -0,0 +1,136 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.util.module_wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import types + +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import module_wrapper +from tensorflow.python.util import tf_inspect +from tensorflow.tools.compatibility import all_renames_v2 + +module_wrapper._PER_MODULE_WARNING_LIMIT = 5 + + +class MockModule(types.ModuleType): + pass + + +class DeprecationWrapperTest(test.TestCase): + + def testWrapperIsAModule(self): + module = MockModule('test') + wrapped_module = module_wrapper.TFModuleWrapper(module, 'test') + self.assertTrue(tf_inspect.ismodule(wrapped_module)) + + @test.mock.patch.object(logging, 'warning', autospec=True) + def testDeprecationWarnings(self, mock_warning): + module = MockModule('test') + module.foo = 1 + module.bar = 2 + module.baz = 3 + all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2' + all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz' + + wrapped_module = module_wrapper.TFModuleWrapper(module, 'test') + self.assertTrue(tf_inspect.ismodule(wrapped_module)) + + self.assertEqual(0, mock_warning.call_count) + bar = wrapped_module.bar + self.assertEqual(1, mock_warning.call_count) + foo = wrapped_module.foo + self.assertEqual(1, mock_warning.call_count) + baz = wrapped_module.baz # pylint: disable=unused-variable + self.assertEqual(2, mock_warning.call_count) + baz = wrapped_module.baz + self.assertEqual(2, mock_warning.call_count) + + # Check that values stayed the same + self.assertEqual(module.foo, foo) + self.assertEqual(module.bar, bar) + + +class LazyLoadingWrapperTest(test.TestCase): + + def testLazyLoad(self): + module = MockModule('test') + apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')} + wrapped_module = module_wrapper.TFModuleWrapper( + module, 'test', public_apis=apis, deprecation=False) + import cmd as _cmd # pylint: disable=g-import-not-at-top + from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member + self.assertEqual(wrapped_module.cmd, _cmd) + self.assertEqual(wrapped_module.ABCMeta, _ABCMeta) + + def testLazyLoadLocalOverride(self): + # Test that we can override and add fields to the wrapped module. + module = MockModule('test') + apis = {'cmd': ('', 'cmd')} + wrapped_module = module_wrapper.TFModuleWrapper( + module, 'test', public_apis=apis, deprecation=False) + import cmd as _cmd # pylint: disable=g-import-not-at-top + self.assertEqual(wrapped_module.cmd, _cmd) + setattr(wrapped_module, 'cmd', 1) + setattr(wrapped_module, 'cgi', 2) + self.assertEqual(wrapped_module.cmd, 1) # override + self.assertEqual(wrapped_module.cgi, 2) # add + + def testLazyLoadDict(self): + # Test that we can override and add fields to the wrapped module. + module = MockModule('test') + apis = {'cmd': ('', 'cmd')} + wrapped_module = module_wrapper.TFModuleWrapper( + module, 'test', public_apis=apis, deprecation=False) + import cmd as _cmd # pylint: disable=g-import-not-at-top + # At first cmd key does not exist in __dict__ + self.assertNotIn('cmd', wrapped_module.__dict__) + # After it is referred (lazyloaded), it gets added to __dict__ + wrapped_module.cmd # pylint: disable=pointless-statement + self.assertEqual(wrapped_module.__dict__['cmd'], _cmd) + # When we call setattr, it also gets added to __dict__ + setattr(wrapped_module, 'cmd2', _cmd) + self.assertEqual(wrapped_module.__dict__['cmd2'], _cmd) + + def testLazyLoadWildcardImport(self): + # Test that public APIs are in __all__. + module = MockModule('test') + module._should_not_be_public = 5 + apis = {'cmd': ('', 'cmd')} + wrapped_module = module_wrapper.TFModuleWrapper( + module, 'test', public_apis=apis, deprecation=False) + setattr(wrapped_module, 'hello', 1) + self.assertIn('hello', wrapped_module.__all__) + self.assertIn('cmd', wrapped_module.__all__) + self.assertNotIn('_should_not_be_public', wrapped_module.__all__) + + def testLazyLoadCorrectLiteModule(self): + # If set, always load lite module from public API list. + module = MockModule('test') + apis = {'lite': ('', 'cmd')} + module.lite = 5 + import cmd as _cmd # pylint: disable=g-import-not-at-top + wrapped_module = module_wrapper.TFModuleWrapper( + module, 'test', public_apis=apis, deprecation=False, has_lite=True) + self.assertEqual(wrapped_module.lite, _cmd) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/tools/api/tests/deprecation_test.py b/tensorflow/tools/api/tests/deprecation_test.py index 8f6748f5787..3a5cf0d043e 100644 --- a/tensorflow/tools/api/tests/deprecation_test.py +++ b/tensorflow/tools/api/tests/deprecation_test.py @@ -23,9 +23,9 @@ import tensorflow as tf from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import deprecation_wrapper +from tensorflow.python.util import module_wrapper -deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5 +module_wrapper._PER_MODULE_WARNING_LIMIT = 5 class DeprecationTest(test.TestCase): @@ -38,9 +38,8 @@ class DeprecationTest(test.TestCase): tf.tables_initializer() self.assertEqual(1, mock_warning.call_count) - self.assertRegexpMatches( - mock_warning.call_args[0][1], - "deprecation_test.py:") + self.assertRegexpMatches(mock_warning.call_args[0][1], + "module_wrapper.py:") self.assertRegexpMatches( mock_warning.call_args[0][2], r"tables_initializer") self.assertRegexpMatches( @@ -60,9 +59,8 @@ class DeprecationTest(test.TestCase): tf.ragged.RaggedTensorValue(value, row_splits) self.assertEqual(1, mock_warning.call_count) - self.assertRegexpMatches( - mock_warning.call_args[0][1], - "deprecation_test.py:") + self.assertRegexpMatches(mock_warning.call_args[0][1], + "module_wrapper.py:") self.assertRegexpMatches( mock_warning.call_args[0][2], r"ragged.RaggedTensorValue") self.assertRegexpMatches( @@ -84,9 +82,8 @@ class DeprecationTest(test.TestCase): tf.sparse_mask(array, mask_indices) self.assertEqual(1, mock_warning.call_count) - self.assertRegexpMatches( - mock_warning.call_args[0][1], - "deprecation_test.py:") + self.assertRegexpMatches(mock_warning.call_args[0][1], + "module_wrapper.py:") self.assertRegexpMatches( mock_warning.call_args[0][2], r"sparse_mask") self.assertRegexpMatches( @@ -103,9 +100,8 @@ class DeprecationTest(test.TestCase): tf.VarLenFeature(tf.dtypes.int32) self.assertEqual(1, mock_warning.call_count) - self.assertRegexpMatches( - mock_warning.call_args[0][1], - "deprecation_test.py:") + self.assertRegexpMatches(mock_warning.call_args[0][1], + "module_wrapper.py:") self.assertRegexpMatches( mock_warning.call_args[0][2], r"VarLenFeature") self.assertRegexpMatches( @@ -122,9 +118,8 @@ class DeprecationTest(test.TestCase): tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pylint: disable=pointless-statement self.assertEqual(1, mock_warning.call_count) - self.assertRegexpMatches( - mock_warning.call_args[0][1], - "deprecation_test.py:") + self.assertRegexpMatches(mock_warning.call_args[0][1], + "module_wrapper.py:") self.assertRegexpMatches( mock_warning.call_args[0][2], r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY") diff --git a/tensorflow/tools/api/tests/module_test.py b/tensorflow/tools/api/tests/module_test.py index 787df35ac30..257d558cda7 100644 --- a/tensorflow/tools/api/tests/module_test.py +++ b/tensorflow/tools/api/tests/module_test.py @@ -38,6 +38,11 @@ class ModuleTest(test.TestCase): def testDict(self): # Check that a few modules are in __dict__. + # pylint: disable=pointless-statement + tf.nn + tf.keras + tf.image + # pylint: enable=pointless-statement self.assertIn('nn', tf.__dict__) self.assertIn('keras', tf.__dict__) self.assertIn('image', tf.__dict__) diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 8c19ca010e9..5420769e25d 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -178,15 +178,11 @@ function prepare_src() { # # import tensorflow as tf # - # which is not ok. We are removing the deprecation stuff by using sed and - # deleting the pattern that the wrapper uses (all lines between a line ending - # with _deprecation_wrapper -- the import line -- and a line containing - # _sys.modules[__name__] as the argument of a function -- the last line in - # the deprecation autogenerated pattern) + # which is not ok. We disable deprecation by using sed to toggle the flag # TODO(mihaimaruseac): When we move the API to root, remove this hack # Note: Can't do in place sed that works on all OS, so use a temp file instead sed \ - "/_deprecation_wrapper$/,/_sys.modules[__name__],/ d" \ + "s/deprecation=True/deprecation=False/g" \ "${TMPDIR}/tensorflow_core/__init__.py" > "${TMPDIR}/tensorflow_core/__init__.out" mv "${TMPDIR}/tensorflow_core/__init__.out" "${TMPDIR}/tensorflow_core/__init__.py" } diff --git a/tensorflow/virtual_root_template_v1.__init__.py b/tensorflow/virtual_root_template_v1.__init__.py index bb076759e60..785043a1a3f 100644 --- a/tensorflow/virtual_root_template_v1.__init__.py +++ b/tensorflow/virtual_root_template_v1.__init__.py @@ -98,28 +98,6 @@ for _m in _top_level_modules: # We still need all the names that are toplevel on tensorflow_core from tensorflow_core import * -# We also need to bring in keras if available in tensorflow_core -# Above import * doesn't import it as __all__ is updated before keras is hooked -try: - from tensorflow_core import keras -except ImportError as e: - pass - -# Similarly for estimator, but only if this file is not read via a -# import tensorflow_estimator (same reasoning as above when forwarding estimator -# separatedly from the rest of the top level modules) -if not _root_estimator: - try: - from tensorflow_core import estimator - except ImportError as e: - pass - -# And again for tensorboard (comes as summary) -try: - from tensorflow_core import summary -except ImportError as e: - pass - # In V1 API we need to print deprecation messages from tensorflow.python.util import deprecation_wrapper as _deprecation if not isinstance(_sys.modules[__name__], _deprecation.DeprecationWrapper): diff --git a/tensorflow/virtual_root_template_v2.__init__.py b/tensorflow/virtual_root_template_v2.__init__.py index bd212adf3d2..7d40733be7b 100644 --- a/tensorflow/virtual_root_template_v2.__init__.py +++ b/tensorflow/virtual_root_template_v2.__init__.py @@ -97,32 +97,4 @@ for _m in _top_level_modules: # We still need all the names that are toplevel on tensorflow_core from tensorflow_core import * -# We also need to bring in keras if available in tensorflow_core -# Above import * doesn't import it as __all__ is updated before keras is hooked -try: - from tensorflow_core import keras -except ImportError as e: - pass - -# Similarly for estimator, but only if this file is not read via a -# import tensorflow_estimator (same reasoning as above when forwarding estimator -# separatedly from the rest of the top level modules) -if not _root_estimator: - try: - from tensorflow_core import estimator - except ImportError as e: - pass - -# And again for tensorboard (comes as summary) -try: - from tensorflow_core import summary -except ImportError as e: - pass - -# Also import module aliases -try: - from tensorflow_core import losses, metrics, initializers, optimizers -except ImportError: - pass - # LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)