Add hooks to allow lazyLoad TensorFlow public API.
- Updates existing DeprecationWrapper with the ability to import modules only when they are referred. - Updates how TensorFlow generates public API. Wraps all generated TensorFlow __init__.py modules with this enhanced wrapper. To enable lazy-loading in the future, toggle _LAZY_LOADING flag in create_python_api.py. Once lazy loading is enabled, the wrapper will have the following behaviors: - dir() will always return module?s attributes. - __all__ will always return all public APIs. - __dict__ will be populated as attributes are being referred. - After wrapper instance is created, to add more attributes, use setattr(import does not explicitly call setattr) to make sure dir, __all__, __dict__ are updated. - import * will work as expected. Built and tested with pip package. PiperOrigin-RevId: 257240535
This commit is contained in:
parent
b34ed5bd1e
commit
7ece5ce95f
@ -41,12 +41,15 @@ from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
# WRAPPER_PLACEHOLDER
|
||||
|
||||
# Make sure directory containing top level submodules is in
|
||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||
# We're using bitwise, but there's nothing special about that.
|
||||
_API_MODULE = bitwise # pylint: disable=undefined-variable
|
||||
_current_module = _sys.modules[__name__]
|
||||
_API_MODULE = sys.modules[__name__].bitwise # pylint: disable=undefined-variable
|
||||
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
||||
_current_module = _sys.modules[__name__]
|
||||
|
||||
if not hasattr(_current_module, '__path__'):
|
||||
__path__ = [_tf_api_dir]
|
||||
elif _tf_api_dir not in __path__:
|
||||
@ -57,6 +60,7 @@ try:
|
||||
from tensorboard.summary._tf import summary
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||
setattr(_current_module, "summary", summary)
|
||||
except ImportError:
|
||||
_logging.warning(
|
||||
"Limited tf.summary API due to missing TensorBoard installation.")
|
||||
@ -65,6 +69,7 @@ try:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -72,6 +77,7 @@ try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -122,25 +128,17 @@ if _running_from_pip_package():
|
||||
# pylint: disable=undefined-variable
|
||||
try:
|
||||
del python
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('python')
|
||||
del core
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('core')
|
||||
except NameError:
|
||||
# Don't fail if these modules are not available.
|
||||
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
|
||||
# does not have 'python', 'core' directories. Then, it will be copied
|
||||
# to tensorflow/ which does have these two directories.
|
||||
pass
|
||||
# Similarly for compiler. Do it separately to make sure we do this even if the
|
||||
# others don't exist.
|
||||
try:
|
||||
del core
|
||||
except NameError:
|
||||
pass
|
||||
try:
|
||||
del compiler
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('compiler')
|
||||
except NameError:
|
||||
pass
|
||||
# pylint: enable=undefined-variable
|
||||
|
||||
# Add module aliases
|
||||
if hasattr(_current_module, 'keras'):
|
||||
@ -148,6 +146,10 @@ if hasattr(_current_module, 'keras'):
|
||||
metrics = keras.metrics
|
||||
optimizers = keras.optimizers
|
||||
initializers = keras.initializers
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
|
||||
compat.v2.compat.v1 = compat.v1
|
||||
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
|
||||
# pylint: enable=undefined-variable
|
||||
|
@ -30,10 +30,12 @@ from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
# WRAPPER_PLACEHOLDER
|
||||
|
||||
# Make sure directory containing top level submodules is in
|
||||
# the __path__ so that "from tensorflow.foo import bar" works.
|
||||
# We're using bitwise, but there's nothing special about that.
|
||||
_API_MODULE = bitwise # pylint: disable=undefined-variable
|
||||
_API_MODULE = _sys.modules[__name__].bitwise # pylint: disable=undefined-variable
|
||||
_current_module = _sys.modules[__name__]
|
||||
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
|
||||
if not hasattr(_current_module, '__path__'):
|
||||
@ -46,6 +48,7 @@ try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -53,6 +56,7 @@ try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -77,9 +81,8 @@ if '__all__' in vars():
|
||||
|
||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||
# The 'app' module will be imported as part of the placeholder section above.
|
||||
app.flags = flags # pylint: disable=undefined-variable
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].append('flags')
|
||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||
setattr(_current_module, "flags", flags)
|
||||
|
||||
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
|
||||
# running under pip.
|
||||
@ -122,25 +125,16 @@ if _running_from_pip_package():
|
||||
# pylint: disable=undefined-variable
|
||||
try:
|
||||
del python
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('python')
|
||||
del core
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('core')
|
||||
except NameError:
|
||||
# Don't fail if these modules are not available.
|
||||
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
|
||||
# does not have 'python', 'core' directories. Then, it will be copied
|
||||
# to tensorflow/ which does have these two directories.
|
||||
pass
|
||||
# Similarly for compiler. Do it separately to make sure we do this even if the
|
||||
# others don't exist.
|
||||
try:
|
||||
del core
|
||||
except NameError:
|
||||
pass
|
||||
try:
|
||||
del compiler
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('compiler')
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
compat.v2.compat.v1 = compat.v1
|
||||
_current_module.compat.v2.compat.v1 = _current_module.compat.v1
|
||||
# pylint: enable=undefined-variable
|
||||
|
@ -28,12 +28,16 @@ from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
# WRAPPER_PLACEHOLDER
|
||||
|
||||
# Hook external TensorFlow modules.
|
||||
_current_module = _sys.modules[__name__]
|
||||
try:
|
||||
from tensorboard.summary._tf import summary
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
|
||||
# Make sure we get the correct summary module with lazy loading
|
||||
setattr(_current_module, "summary", summary)
|
||||
except ImportError:
|
||||
_logging.warning(
|
||||
"Limited tf.compat.v2.summary API due to missing TensorBoard "
|
||||
@ -43,6 +47,7 @@ try:
|
||||
from tensorflow_estimator.python.estimator.api._v2 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -50,6 +55,7 @@ try:
|
||||
from tensorflow.python.keras.api._v2 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -61,11 +67,15 @@ except ImportError:
|
||||
#
|
||||
# This make this one symbol available directly.
|
||||
from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top
|
||||
setattr(_current_module, "enable_v2_behavior", enable_v2_behavior)
|
||||
|
||||
# Add module aliases
|
||||
_current_module = _sys.modules[__name__]
|
||||
if hasattr(_current_module, 'keras'):
|
||||
losses = keras.losses
|
||||
metrics = keras.metrics
|
||||
optimizers = keras.optimizers
|
||||
initializers = keras.initializers
|
||||
setattr(_current_module, "losses", losses)
|
||||
setattr(_current_module, "metrics", metrics)
|
||||
setattr(_current_module, "optimizers", optimizers)
|
||||
setattr(_current_module, "initializers", initializers)
|
||||
|
@ -27,12 +27,15 @@ from tensorflow.python.tools import module_util as _module_util
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
# WRAPPER_PLACEHOLDER
|
||||
|
||||
# Hook external TensorFlow modules.
|
||||
_current_module = _sys.modules[__name__]
|
||||
try:
|
||||
from tensorflow_estimator.python.estimator.api._v1 import estimator
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(estimator)] + _current_module.__path__)
|
||||
setattr(_current_module, "estimator", estimator)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -40,9 +43,11 @@ try:
|
||||
from tensorflow.python.keras.api._v1 import keras
|
||||
_current_module.__path__ = (
|
||||
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
|
||||
setattr(_current_module, "keras", keras)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||
app.flags = flags # pylint: disable=undefined-variable
|
||||
_current_module.app.flags = flags # pylint: disable=undefined-variable
|
||||
setattr(_current_module, "flags", flags)
|
||||
|
@ -4545,9 +4545,9 @@ tf_py_test(
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "deprecation_wrapper_test",
|
||||
name = "module_wrapper_test",
|
||||
size = "small",
|
||||
srcs = ["util/deprecation_wrapper_test.py"],
|
||||
srcs = ["util/module_wrapper_test.py"],
|
||||
additional_deps = [
|
||||
":client_testlib",
|
||||
":util",
|
||||
|
@ -48,15 +48,48 @@ _GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
|
||||
|
||||
from __future__ import print_function as _print_function
|
||||
|
||||
import sys as _sys
|
||||
|
||||
"""
|
||||
_GENERATED_FILE_FOOTER = '\n\ndel _print_function\n'
|
||||
_DEPRECATION_FOOTER = """
|
||||
import sys as _sys
|
||||
from tensorflow.python.util import deprecation_wrapper as _deprecation_wrapper
|
||||
from tensorflow.python.util import module_wrapper as _module_wrapper
|
||||
|
||||
if not isinstance(_sys.modules[__name__], _deprecation_wrapper.DeprecationWrapper):
|
||||
_sys.modules[__name__] = _deprecation_wrapper.DeprecationWrapper(
|
||||
_sys.modules[__name__], "%s")
|
||||
if not isinstance(_sys.modules[__name__], _module_wrapper.TFModuleWrapper):
|
||||
_sys.modules[__name__] = _module_wrapper.TFModuleWrapper(
|
||||
_sys.modules[__name__], "%s", public_apis=_PUBLIC_APIS, deprecation=%s,
|
||||
has_lite=%s)
|
||||
"""
|
||||
_MODULE_TEXT_TEMPLATE = """
|
||||
# Inform pytype that this module is dynamically populated (b/111239204).
|
||||
_LAZY_LOADING = False
|
||||
|
||||
_PUBLIC_APIS = {
|
||||
%s
|
||||
}
|
||||
|
||||
if _LAZY_LOADING:
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
else:
|
||||
import importlib as _importlib
|
||||
for symbol, symbol_loc_info in _PUBLIC_APIS.items():
|
||||
if symbol_loc_info[0]:
|
||||
attr = getattr(_importlib.import_module(symbol_loc_info[0]), symbol_loc_info[1])
|
||||
else:
|
||||
attr = _importlib.import_module(symbol_loc_info[1])
|
||||
setattr(_sys.modules[__name__], symbol, attr)
|
||||
try:
|
||||
del symbol
|
||||
except NameError:
|
||||
pass
|
||||
try:
|
||||
del symbol_loc_info
|
||||
except NameError:
|
||||
pass
|
||||
try:
|
||||
del attr
|
||||
except NameError:
|
||||
pass
|
||||
"""
|
||||
|
||||
|
||||
@ -76,17 +109,7 @@ def format_import(source_module_name, source_name, dest_name):
|
||||
Returns:
|
||||
An import statement string.
|
||||
"""
|
||||
if source_module_name:
|
||||
if source_name == dest_name:
|
||||
return 'from %s import %s' % (source_module_name, source_name)
|
||||
else:
|
||||
return 'from %s import %s as %s' % (
|
||||
source_module_name, source_name, dest_name)
|
||||
else:
|
||||
if source_name == dest_name:
|
||||
return 'import %s' % source_name
|
||||
else:
|
||||
return 'import %s as %s' % (source_name, dest_name)
|
||||
return " '%s': ('%s', '%s')," % (dest_name, source_module_name, source_name)
|
||||
|
||||
|
||||
def get_canonical_import(import_set):
|
||||
@ -129,7 +152,6 @@ class _ModuleInitCodeBuilder(object):
|
||||
lambda: collections.defaultdict(set))
|
||||
self._dest_import_to_id = collections.defaultdict(int)
|
||||
# Names that start with underscore in the root module.
|
||||
self._underscore_names_in_root = []
|
||||
self._api_version = api_version
|
||||
|
||||
def _check_already_imported(self, symbol_id, api_name):
|
||||
@ -166,9 +188,6 @@ class _ModuleInitCodeBuilder(object):
|
||||
symbol_id = -1 if not symbol else id(symbol)
|
||||
self._check_already_imported(symbol_id, full_api_name)
|
||||
|
||||
if not dest_module_name and dest_name.startswith('_'):
|
||||
self._underscore_names_in_root.append(dest_name)
|
||||
|
||||
# The same symbol can be available in multiple modules.
|
||||
# We store all possible ways of importing this symbol and later pick just
|
||||
# one.
|
||||
@ -197,11 +216,13 @@ class _ModuleInitCodeBuilder(object):
|
||||
submodule = module_split[submodule_index-1]
|
||||
parent_module += '.' + submodule if parent_module else submodule
|
||||
import_from = self._output_package
|
||||
if submodule_index > 0:
|
||||
import_from += '.' + '.'.join(module_split[:submodule_index])
|
||||
import_from += '.' + '.'.join(module_split[:submodule_index + 1])
|
||||
self.add_import(
|
||||
None, import_from, module_split[submodule_index],
|
||||
parent_module, module_split[submodule_index])
|
||||
symbol=None,
|
||||
source_module_name='',
|
||||
source_name=import_from,
|
||||
dest_module_name=parent_module,
|
||||
dest_name=module_split[submodule_index])
|
||||
|
||||
def build(self):
|
||||
"""Get a map from destination module to __init__.py code for that module.
|
||||
@ -221,26 +242,20 @@ class _ModuleInitCodeBuilder(object):
|
||||
get_canonical_import(imports)
|
||||
for _, imports in dest_name_to_imports.items()
|
||||
]
|
||||
module_text_map[dest_module] = '\n'.join(sorted(imports_list))
|
||||
module_text_map[dest_module] = _MODULE_TEXT_TEMPLATE % '\n'.join(
|
||||
sorted(imports_list))
|
||||
|
||||
# Expose exported symbols with underscores in root module
|
||||
# since we import from it using * import.
|
||||
underscore_names_str = ', '.join(
|
||||
'\'%s\'' % name for name in self._underscore_names_in_root)
|
||||
# We will always generate a root __init__.py file to let us handle *
|
||||
# imports consistently. Be sure to have a root __init__.py file listed in
|
||||
# the script outputs.
|
||||
module_text_map[''] = module_text_map.get('', '') + '''
|
||||
_names_with_underscore = [%s]
|
||||
__all__ = [_s for _s in dir() if not _s.startswith('_')]
|
||||
__all__.extend([_s for _s in _names_with_underscore])
|
||||
''' % underscore_names_str
|
||||
|
||||
if self._api_version == 1: # Add 1.* deprecations.
|
||||
for dest_module, _ in self._module_imports.items():
|
||||
for dest_module, _ in self._module_imports.items():
|
||||
deprecation = 'False'
|
||||
has_lite = 'False'
|
||||
if self._api_version == 1: # Add 1.* deprecations.
|
||||
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
|
||||
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
|
||||
dest_module)
|
||||
deprecation = 'True'
|
||||
# Workaround to make sure not load lite from lite/__init__.py
|
||||
if not dest_module and 'lite' in self._module_imports:
|
||||
has_lite = 'True'
|
||||
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
|
||||
dest_module, deprecation, has_lite)
|
||||
|
||||
return module_text_map, footer_text_map
|
||||
|
||||
@ -519,7 +534,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
||||
_GENERATED_FILE_HEADER % get_module_docstring(
|
||||
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER)
|
||||
if module in deprecation_footer_map:
|
||||
contents += deprecation_footer_map[module]
|
||||
if '# WRAPPER_PLACEHOLDER' in contents:
|
||||
contents = contents.replace('# WRAPPER_PLACEHOLDER',
|
||||
deprecation_footer_map[module])
|
||||
else:
|
||||
contents += deprecation_footer_map[module]
|
||||
with open(module_name_to_file_path[module], 'w') as fp:
|
||||
fp.write(contents)
|
||||
|
||||
|
@ -67,15 +67,16 @@ class CreatePythonApiTest(test.TestCase):
|
||||
output_package='tensorflow',
|
||||
api_name='tensorflow',
|
||||
api_version=1)
|
||||
expected_import = (
|
||||
'from tensorflow.python.test_module '
|
||||
'import test_op as test_op1')
|
||||
expected_import = ('\'test_op1\': '
|
||||
'(\'tensorflow.python.test_module\','
|
||||
' \'test_op\')')
|
||||
self.assertTrue(
|
||||
expected_import in str(imports),
|
||||
msg='%s not in %s' % (expected_import, str(imports)))
|
||||
|
||||
expected_import = ('from tensorflow.python.test_module '
|
||||
'import test_op')
|
||||
expected_import = ('\'test_op\': '
|
||||
'(\'tensorflow.python.test_module\','
|
||||
' \'test_op\')')
|
||||
self.assertTrue(
|
||||
expected_import in str(imports),
|
||||
msg='%s not in %s' % (expected_import, str(imports)))
|
||||
@ -89,8 +90,10 @@ class CreatePythonApiTest(test.TestCase):
|
||||
output_package='tensorflow',
|
||||
api_name='tensorflow',
|
||||
api_version=2)
|
||||
expected_import = ('from tensorflow.python.test_module '
|
||||
'import TestClass')
|
||||
expected_import = (
|
||||
'\'NewTestClass\':'
|
||||
' (\'tensorflow.python.test_module\','
|
||||
' \'TestClass\')')
|
||||
self.assertTrue(
|
||||
'TestClass' in str(imports),
|
||||
msg='%s not in %s' % (expected_import, str(imports)))
|
||||
@ -101,8 +104,9 @@ class CreatePythonApiTest(test.TestCase):
|
||||
output_package='tensorflow',
|
||||
api_name='tensorflow',
|
||||
api_version=1)
|
||||
expected = ('from tensorflow.python.test_module '
|
||||
'import _TEST_CONSTANT')
|
||||
expected = ('\'_TEST_CONSTANT\':'
|
||||
' (\'tensorflow.python.test_module\','
|
||||
' \'_TEST_CONSTANT\')')
|
||||
self.assertTrue(expected in str(imports),
|
||||
msg='%s not in %s' % (expected, str(imports)))
|
||||
|
||||
|
@ -12,138 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Provides wrapper for TensorFlow modules to support deprecation messages.
|
||||
"""Compatibility wrapper for TensorFlow modules to support deprecation messages.
|
||||
|
||||
TODO(annarev): potentially merge with LazyLoader.
|
||||
Please use module_wrapper instead.
|
||||
TODO(yifeif): remove once no longer referred by estimator
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import types
|
||||
from tensorflow.python.util import module_wrapper
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util import tf_stack
|
||||
from tensorflow.tools.compatibility import all_renames_v2
|
||||
|
||||
|
||||
_PER_MODULE_WARNING_LIMIT = 1
|
||||
|
||||
|
||||
def get_rename_v2(name):
|
||||
if name not in all_renames_v2.symbol_renames:
|
||||
return None
|
||||
return all_renames_v2.symbol_renames[name]
|
||||
|
||||
|
||||
def _call_location():
|
||||
# We want to get stack frame 2 frames up from current frame,
|
||||
# i.e. above _getattr__ and _call_location calls.
|
||||
stack = tf_stack.extract_stack_file_and_line(max_length=3)
|
||||
if not stack: # should never happen as we're in a function
|
||||
return 'UNKNOWN'
|
||||
frame = stack[0]
|
||||
return '{}:{}'.format(frame.file, frame.line)
|
||||
|
||||
|
||||
def contains_deprecation_decorator(decorators):
|
||||
return any(
|
||||
d.decorator_name == 'deprecated' for d in decorators)
|
||||
|
||||
|
||||
def has_deprecation_decorator(symbol):
|
||||
"""Checks if given object has a deprecation decorator.
|
||||
|
||||
We check if deprecation decorator is in decorators as well as
|
||||
whether symbol is a class whose __init__ method has a deprecation
|
||||
decorator.
|
||||
Args:
|
||||
symbol: Python object.
|
||||
|
||||
Returns:
|
||||
True if symbol has deprecation decorator.
|
||||
"""
|
||||
decorators, symbol = tf_decorator.unwrap(symbol)
|
||||
if contains_deprecation_decorator(decorators):
|
||||
return True
|
||||
if tf_inspect.isfunction(symbol):
|
||||
return False
|
||||
if not tf_inspect.isclass(symbol):
|
||||
return False
|
||||
if not hasattr(symbol, '__init__'):
|
||||
return False
|
||||
init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
|
||||
return contains_deprecation_decorator(init_decorators)
|
||||
|
||||
|
||||
class DeprecationWrapper(types.ModuleType):
|
||||
"""Wrapper for TensorFlow modules to support deprecation messages."""
|
||||
|
||||
def __init__(self, wrapped, module_name): # pylint: disable=super-on-old-class
|
||||
super(DeprecationWrapper, self).__init__(wrapped.__name__)
|
||||
self.__dict__.update(wrapped.__dict__)
|
||||
# Prefix all local attributes with _dw_ so that we can
|
||||
# handle them differently in attribute access methods.
|
||||
self._dw_wrapped_module = wrapped
|
||||
self._dw_module_name = module_name
|
||||
# names we already checked for deprecation
|
||||
self._dw_deprecated_checked = set()
|
||||
self._dw_warning_count = 0
|
||||
|
||||
def __getattribute__(self, name): # pylint: disable=super-on-old-class
|
||||
attr = super(DeprecationWrapper, self).__getattribute__(name)
|
||||
if name.startswith('__') or name.startswith('_dw_'):
|
||||
return attr
|
||||
|
||||
if (self._dw_warning_count < _PER_MODULE_WARNING_LIMIT and
|
||||
name not in self._dw_deprecated_checked):
|
||||
|
||||
self._dw_deprecated_checked.add(name)
|
||||
|
||||
if self._dw_module_name:
|
||||
full_name = 'tf.%s.%s' % (self._dw_module_name, name)
|
||||
else:
|
||||
full_name = 'tf.%s' % name
|
||||
rename = get_rename_v2(full_name)
|
||||
if rename and not has_deprecation_decorator(attr):
|
||||
call_location = _call_location()
|
||||
# skip locations in Python source
|
||||
if not call_location.startswith('<'):
|
||||
logging.warning(
|
||||
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
||||
_call_location(), full_name, rename)
|
||||
self._dw_warning_count += 1
|
||||
return attr
|
||||
|
||||
def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
|
||||
if arg.startswith('_dw_'):
|
||||
super(DeprecationWrapper, self).__setattr__(arg, val)
|
||||
else:
|
||||
setattr(self._dw_wrapped_module, arg, val)
|
||||
self.__dict__[arg] = val
|
||||
|
||||
def __dir__(self):
|
||||
return dir(self._dw_wrapped_module)
|
||||
|
||||
def __delattr__(self, name): # pylint: disable=super-on-old-class
|
||||
if name.startswith('_dw_'):
|
||||
super(DeprecationWrapper, self).__delattr__(name)
|
||||
else:
|
||||
delattr(self._dw_wrapped_module, name)
|
||||
|
||||
def __repr__(self):
|
||||
return self._dw_wrapped_module.__repr__()
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__name__
|
||||
|
||||
def __setstate__(self, d):
|
||||
# pylint: disable=protected-access
|
||||
self.__init__(
|
||||
sys.modules[d]._dw_wrapped_module,
|
||||
sys.modules[d]._dw_module_name)
|
||||
# pylint: enable=protected-access
|
||||
# For backward compatibility for other pip packages that use this class.
|
||||
DeprecationWrapper = module_wrapper.TFModuleWrapper
|
||||
|
@ -1,75 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.util.deprecation_wrapper."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import types
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation_wrapper
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.tools.compatibility import all_renames_v2
|
||||
|
||||
deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
||||
|
||||
|
||||
class MockModule(types.ModuleType):
|
||||
pass
|
||||
|
||||
|
||||
class DeprecationWrapperTest(test.TestCase):
|
||||
|
||||
def testWrapperIsAModule(self):
|
||||
module = MockModule('test')
|
||||
wrapped_module = deprecation_wrapper.DeprecationWrapper(
|
||||
module, 'test')
|
||||
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
||||
|
||||
@test.mock.patch.object(logging, 'warning', autospec=True)
|
||||
def testDeprecationWarnings(self, mock_warning):
|
||||
module = MockModule('test')
|
||||
module.foo = 1
|
||||
module.bar = 2
|
||||
module.baz = 3
|
||||
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
|
||||
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
|
||||
|
||||
wrapped_module = deprecation_wrapper.DeprecationWrapper(
|
||||
module, 'test')
|
||||
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
||||
|
||||
self.assertEqual(0, mock_warning.call_count)
|
||||
bar = wrapped_module.bar
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
foo = wrapped_module.foo
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
baz = wrapped_module.baz
|
||||
self.assertEqual(2, mock_warning.call_count)
|
||||
baz = wrapped_module.baz
|
||||
self.assertEqual(2, mock_warning.call_count)
|
||||
|
||||
# Check that values stayed the same
|
||||
self.assertEqual(module.foo, foo)
|
||||
self.assertEqual(module.bar, bar)
|
||||
self.assertEqual(module.baz, baz)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
205
tensorflow/python/util/module_wrapper.py
Normal file
205
tensorflow/python/util/module_wrapper.py
Normal file
@ -0,0 +1,205 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Provides wrapper for TensorFlow modules."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util import tf_stack
|
||||
from tensorflow.tools.compatibility import all_renames_v2
|
||||
|
||||
|
||||
_PER_MODULE_WARNING_LIMIT = 1
|
||||
|
||||
|
||||
def get_rename_v2(name):
|
||||
if name not in all_renames_v2.symbol_renames:
|
||||
return None
|
||||
return all_renames_v2.symbol_renames[name]
|
||||
|
||||
|
||||
def _call_location():
|
||||
# We want to get stack frame 2 frames up from current frame,
|
||||
# i.e. above _getattr__ and _call_location calls.
|
||||
stack = tf_stack.extract_stack_file_and_line(max_length=3)
|
||||
if not stack: # should never happen as we're in a function
|
||||
return 'UNKNOWN'
|
||||
frame = stack[0]
|
||||
return '{}:{}'.format(frame.file, frame.line)
|
||||
|
||||
|
||||
def contains_deprecation_decorator(decorators):
|
||||
return any(
|
||||
d.decorator_name == 'deprecated' for d in decorators)
|
||||
|
||||
|
||||
def has_deprecation_decorator(symbol):
|
||||
"""Checks if given object has a deprecation decorator.
|
||||
|
||||
We check if deprecation decorator is in decorators as well as
|
||||
whether symbol is a class whose __init__ method has a deprecation
|
||||
decorator.
|
||||
Args:
|
||||
symbol: Python object.
|
||||
|
||||
Returns:
|
||||
True if symbol has deprecation decorator.
|
||||
"""
|
||||
decorators, symbol = tf_decorator.unwrap(symbol)
|
||||
if contains_deprecation_decorator(decorators):
|
||||
return True
|
||||
if tf_inspect.isfunction(symbol):
|
||||
return False
|
||||
if not tf_inspect.isclass(symbol):
|
||||
return False
|
||||
if not hasattr(symbol, '__init__'):
|
||||
return False
|
||||
init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
|
||||
return contains_deprecation_decorator(init_decorators)
|
||||
|
||||
|
||||
class TFModuleWrapper(types.ModuleType):
|
||||
"""Wrapper for TF modules to support deprecation messages and lazyloading."""
|
||||
|
||||
def __init__( # pylint: disable=super-on-old-class
|
||||
self,
|
||||
wrapped,
|
||||
module_name,
|
||||
public_apis=None,
|
||||
deprecation=True,
|
||||
has_lite=False): # pylint: enable=super-on-old-class
|
||||
super(TFModuleWrapper, self).__init__(wrapped.__name__)
|
||||
self.__dict__.update(wrapped.__dict__)
|
||||
# Prefix all local attributes with _tfmw_ so that we can
|
||||
# handle them differently in attribute access methods.
|
||||
self._tfmw_wrapped_module = wrapped
|
||||
self._tfmw_module_name = module_name
|
||||
self._tfmw_public_apis = public_apis
|
||||
self._tfmw_print_deprecation_warnings = deprecation
|
||||
self._tfmw_has_lite = has_lite
|
||||
# Set __all__ so that import * work for lazy loaded modules
|
||||
if self._tfmw_public_apis:
|
||||
self._tfmw_wrapped_module.__all__ = list(self._tfmw_public_apis.keys())
|
||||
self.__all__ = list(self._tfmw_public_apis.keys())
|
||||
# names we already checked for deprecation
|
||||
self._tfmw_deprecated_checked = set()
|
||||
self._tfmw_warning_count = 0
|
||||
|
||||
def _tfmw_add_deprecation_warning(self, name, attr):
|
||||
"""Print deprecation warning for attr with given name if necessary."""
|
||||
if (self._tfmw_warning_count < _PER_MODULE_WARNING_LIMIT and
|
||||
name not in self._tfmw_deprecated_checked):
|
||||
|
||||
self._tfmw_deprecated_checked.add(name)
|
||||
|
||||
if self._tfmw_module_name:
|
||||
full_name = 'tf.%s.%s' % (self._tfmw_module_name, name)
|
||||
else:
|
||||
full_name = 'tf.%s' % name
|
||||
rename = get_rename_v2(full_name)
|
||||
if rename and not has_deprecation_decorator(attr):
|
||||
call_location = _call_location()
|
||||
# skip locations in Python source
|
||||
if not call_location.startswith('<'):
|
||||
logging.warning(
|
||||
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
||||
_call_location(), full_name, rename)
|
||||
self._tfmw_warning_count += 1
|
||||
|
||||
def _tfmw_import_module(self, name):
|
||||
symbol_loc_info = self._tfmw_public_apis[name]
|
||||
if symbol_loc_info[0]:
|
||||
module = importlib.import_module(symbol_loc_info[0])
|
||||
attr = getattr(module, symbol_loc_info[1])
|
||||
else:
|
||||
attr = importlib.import_module(symbol_loc_info[1])
|
||||
setattr(self._tfmw_wrapped_module, name, attr)
|
||||
self.__dict__[name] = attr
|
||||
return attr
|
||||
|
||||
def __getattribute__(self, name): # pylint: disable=super-on-old-class
|
||||
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
|
||||
if name == 'lite':
|
||||
if self._tfmw_has_lite:
|
||||
attr = self._tfmw_import_module(name)
|
||||
setattr(self._tfmw_wrapped_module, 'lite', attr)
|
||||
return attr
|
||||
|
||||
attr = super(TFModuleWrapper, self).__getattribute__(name)
|
||||
if name.startswith('__') or name.startswith('_tfmw_'):
|
||||
return attr
|
||||
|
||||
if self._tfmw_print_deprecation_warnings:
|
||||
self._tfmw_add_deprecation_warning(name, attr)
|
||||
return attr
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
attr = getattr(self._tfmw_wrapped_module, name)
|
||||
except AttributeError as e:
|
||||
if not self._tfmw_public_apis:
|
||||
raise e
|
||||
if name not in self._tfmw_public_apis:
|
||||
raise e
|
||||
attr = self._tfmw_import_module(name)
|
||||
|
||||
if self._tfmw_print_deprecation_warnings:
|
||||
self._tfmw_add_deprecation_warning(name, attr)
|
||||
return attr
|
||||
|
||||
def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
|
||||
if not arg.startswith('_tfmw_'):
|
||||
setattr(self._tfmw_wrapped_module, arg, val)
|
||||
self.__dict__[arg] = val
|
||||
if arg not in self.__all__ and arg != '__all__':
|
||||
self.__all__.append(arg)
|
||||
super(TFModuleWrapper, self).__setattr__(arg, val)
|
||||
|
||||
def __dir__(self):
|
||||
if self._tfmw_public_apis:
|
||||
return list(
|
||||
set(self._tfmw_public_apis.keys()).union(
|
||||
set([
|
||||
attr for attr in dir(self._tfmw_wrapped_module)
|
||||
if not attr.startswith('_')
|
||||
])))
|
||||
else:
|
||||
return dir(self._tfmw_wrapped_module)
|
||||
|
||||
def __delattr__(self, name): # pylint: disable=super-on-old-class
|
||||
if name.startswith('_tfmw_'):
|
||||
super(TFModuleWrapper, self).__delattr__(name)
|
||||
else:
|
||||
delattr(self._tfmw_wrapped_module, name)
|
||||
|
||||
def __repr__(self):
|
||||
return self._tfmw_wrapped_module.__repr__()
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__name__
|
||||
|
||||
def __setstate__(self, d):
|
||||
# pylint: disable=protected-access
|
||||
self.__init__(sys.modules[d]._tfmw_wrapped_module,
|
||||
sys.modules[d]._tfmw_module_name)
|
||||
# pylint: enable=protected-access
|
136
tensorflow/python/util/module_wrapper_test.py
Normal file
136
tensorflow/python/util/module_wrapper_test.py
Normal file
@ -0,0 +1,136 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.util.module_wrapper."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import types
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import module_wrapper
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.tools.compatibility import all_renames_v2
|
||||
|
||||
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
||||
|
||||
|
||||
class MockModule(types.ModuleType):
|
||||
pass
|
||||
|
||||
|
||||
class DeprecationWrapperTest(test.TestCase):
|
||||
|
||||
def testWrapperIsAModule(self):
|
||||
module = MockModule('test')
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
|
||||
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
||||
|
||||
@test.mock.patch.object(logging, 'warning', autospec=True)
|
||||
def testDeprecationWarnings(self, mock_warning):
|
||||
module = MockModule('test')
|
||||
module.foo = 1
|
||||
module.bar = 2
|
||||
module.baz = 3
|
||||
all_renames_v2.symbol_renames['tf.test.bar'] = 'tf.bar2'
|
||||
all_renames_v2.symbol_renames['tf.test.baz'] = 'tf.compat.v1.baz'
|
||||
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
|
||||
self.assertTrue(tf_inspect.ismodule(wrapped_module))
|
||||
|
||||
self.assertEqual(0, mock_warning.call_count)
|
||||
bar = wrapped_module.bar
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
foo = wrapped_module.foo
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
baz = wrapped_module.baz # pylint: disable=unused-variable
|
||||
self.assertEqual(2, mock_warning.call_count)
|
||||
baz = wrapped_module.baz
|
||||
self.assertEqual(2, mock_warning.call_count)
|
||||
|
||||
# Check that values stayed the same
|
||||
self.assertEqual(module.foo, foo)
|
||||
self.assertEqual(module.bar, bar)
|
||||
|
||||
|
||||
class LazyLoadingWrapperTest(test.TestCase):
|
||||
|
||||
def testLazyLoad(self):
|
||||
module = MockModule('test')
|
||||
apis = {'cmd': ('', 'cmd'), 'ABCMeta': ('abc', 'ABCMeta')}
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||
module, 'test', public_apis=apis, deprecation=False)
|
||||
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||
from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member
|
||||
self.assertEqual(wrapped_module.cmd, _cmd)
|
||||
self.assertEqual(wrapped_module.ABCMeta, _ABCMeta)
|
||||
|
||||
def testLazyLoadLocalOverride(self):
|
||||
# Test that we can override and add fields to the wrapped module.
|
||||
module = MockModule('test')
|
||||
apis = {'cmd': ('', 'cmd')}
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||
module, 'test', public_apis=apis, deprecation=False)
|
||||
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||
self.assertEqual(wrapped_module.cmd, _cmd)
|
||||
setattr(wrapped_module, 'cmd', 1)
|
||||
setattr(wrapped_module, 'cgi', 2)
|
||||
self.assertEqual(wrapped_module.cmd, 1) # override
|
||||
self.assertEqual(wrapped_module.cgi, 2) # add
|
||||
|
||||
def testLazyLoadDict(self):
|
||||
# Test that we can override and add fields to the wrapped module.
|
||||
module = MockModule('test')
|
||||
apis = {'cmd': ('', 'cmd')}
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||
module, 'test', public_apis=apis, deprecation=False)
|
||||
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||
# At first cmd key does not exist in __dict__
|
||||
self.assertNotIn('cmd', wrapped_module.__dict__)
|
||||
# After it is referred (lazyloaded), it gets added to __dict__
|
||||
wrapped_module.cmd # pylint: disable=pointless-statement
|
||||
self.assertEqual(wrapped_module.__dict__['cmd'], _cmd)
|
||||
# When we call setattr, it also gets added to __dict__
|
||||
setattr(wrapped_module, 'cmd2', _cmd)
|
||||
self.assertEqual(wrapped_module.__dict__['cmd2'], _cmd)
|
||||
|
||||
def testLazyLoadWildcardImport(self):
|
||||
# Test that public APIs are in __all__.
|
||||
module = MockModule('test')
|
||||
module._should_not_be_public = 5
|
||||
apis = {'cmd': ('', 'cmd')}
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||
module, 'test', public_apis=apis, deprecation=False)
|
||||
setattr(wrapped_module, 'hello', 1)
|
||||
self.assertIn('hello', wrapped_module.__all__)
|
||||
self.assertIn('cmd', wrapped_module.__all__)
|
||||
self.assertNotIn('_should_not_be_public', wrapped_module.__all__)
|
||||
|
||||
def testLazyLoadCorrectLiteModule(self):
|
||||
# If set, always load lite module from public API list.
|
||||
module = MockModule('test')
|
||||
apis = {'lite': ('', 'cmd')}
|
||||
module.lite = 5
|
||||
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(
|
||||
module, 'test', public_apis=apis, deprecation=False, has_lite=True)
|
||||
self.assertEqual(wrapped_module.lite, _cmd)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -23,9 +23,9 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation_wrapper
|
||||
from tensorflow.python.util import module_wrapper
|
||||
|
||||
deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
||||
module_wrapper._PER_MODULE_WARNING_LIMIT = 5
|
||||
|
||||
|
||||
class DeprecationTest(test.TestCase):
|
||||
@ -38,9 +38,8 @@ class DeprecationTest(test.TestCase):
|
||||
|
||||
tf.tables_initializer()
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][1],
|
||||
"deprecation_test.py:")
|
||||
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||
"module_wrapper.py:")
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][2], r"tables_initializer")
|
||||
self.assertRegexpMatches(
|
||||
@ -60,9 +59,8 @@ class DeprecationTest(test.TestCase):
|
||||
|
||||
tf.ragged.RaggedTensorValue(value, row_splits)
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][1],
|
||||
"deprecation_test.py:")
|
||||
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||
"module_wrapper.py:")
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][2], r"ragged.RaggedTensorValue")
|
||||
self.assertRegexpMatches(
|
||||
@ -84,9 +82,8 @@ class DeprecationTest(test.TestCase):
|
||||
|
||||
tf.sparse_mask(array, mask_indices)
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][1],
|
||||
"deprecation_test.py:")
|
||||
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||
"module_wrapper.py:")
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][2], r"sparse_mask")
|
||||
self.assertRegexpMatches(
|
||||
@ -103,9 +100,8 @@ class DeprecationTest(test.TestCase):
|
||||
|
||||
tf.VarLenFeature(tf.dtypes.int32)
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][1],
|
||||
"deprecation_test.py:")
|
||||
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||
"module_wrapper.py:")
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][2], r"VarLenFeature")
|
||||
self.assertRegexpMatches(
|
||||
@ -122,9 +118,8 @@ class DeprecationTest(test.TestCase):
|
||||
|
||||
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pylint: disable=pointless-statement
|
||||
self.assertEqual(1, mock_warning.call_count)
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][1],
|
||||
"deprecation_test.py:")
|
||||
self.assertRegexpMatches(mock_warning.call_args[0][1],
|
||||
"module_wrapper.py:")
|
||||
self.assertRegexpMatches(
|
||||
mock_warning.call_args[0][2],
|
||||
r"saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY")
|
||||
|
@ -38,6 +38,11 @@ class ModuleTest(test.TestCase):
|
||||
|
||||
def testDict(self):
|
||||
# Check that a few modules are in __dict__.
|
||||
# pylint: disable=pointless-statement
|
||||
tf.nn
|
||||
tf.keras
|
||||
tf.image
|
||||
# pylint: enable=pointless-statement
|
||||
self.assertIn('nn', tf.__dict__)
|
||||
self.assertIn('keras', tf.__dict__)
|
||||
self.assertIn('image', tf.__dict__)
|
||||
|
@ -178,15 +178,11 @@ function prepare_src() {
|
||||
#
|
||||
# import tensorflow as tf
|
||||
#
|
||||
# which is not ok. We are removing the deprecation stuff by using sed and
|
||||
# deleting the pattern that the wrapper uses (all lines between a line ending
|
||||
# with _deprecation_wrapper -- the import line -- and a line containing
|
||||
# _sys.modules[__name__] as the argument of a function -- the last line in
|
||||
# the deprecation autogenerated pattern)
|
||||
# which is not ok. We disable deprecation by using sed to toggle the flag
|
||||
# TODO(mihaimaruseac): When we move the API to root, remove this hack
|
||||
# Note: Can't do in place sed that works on all OS, so use a temp file instead
|
||||
sed \
|
||||
"/_deprecation_wrapper$/,/_sys.modules[__name__],/ d" \
|
||||
"s/deprecation=True/deprecation=False/g" \
|
||||
"${TMPDIR}/tensorflow_core/__init__.py" > "${TMPDIR}/tensorflow_core/__init__.out"
|
||||
mv "${TMPDIR}/tensorflow_core/__init__.out" "${TMPDIR}/tensorflow_core/__init__.py"
|
||||
}
|
||||
|
@ -98,28 +98,6 @@ for _m in _top_level_modules:
|
||||
# We still need all the names that are toplevel on tensorflow_core
|
||||
from tensorflow_core import *
|
||||
|
||||
# We also need to bring in keras if available in tensorflow_core
|
||||
# Above import * doesn't import it as __all__ is updated before keras is hooked
|
||||
try:
|
||||
from tensorflow_core import keras
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
# Similarly for estimator, but only if this file is not read via a
|
||||
# import tensorflow_estimator (same reasoning as above when forwarding estimator
|
||||
# separatedly from the rest of the top level modules)
|
||||
if not _root_estimator:
|
||||
try:
|
||||
from tensorflow_core import estimator
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
# And again for tensorboard (comes as summary)
|
||||
try:
|
||||
from tensorflow_core import summary
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
# In V1 API we need to print deprecation messages
|
||||
from tensorflow.python.util import deprecation_wrapper as _deprecation
|
||||
if not isinstance(_sys.modules[__name__], _deprecation.DeprecationWrapper):
|
||||
|
@ -97,32 +97,4 @@ for _m in _top_level_modules:
|
||||
# We still need all the names that are toplevel on tensorflow_core
|
||||
from tensorflow_core import *
|
||||
|
||||
# We also need to bring in keras if available in tensorflow_core
|
||||
# Above import * doesn't import it as __all__ is updated before keras is hooked
|
||||
try:
|
||||
from tensorflow_core import keras
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
# Similarly for estimator, but only if this file is not read via a
|
||||
# import tensorflow_estimator (same reasoning as above when forwarding estimator
|
||||
# separatedly from the rest of the top level modules)
|
||||
if not _root_estimator:
|
||||
try:
|
||||
from tensorflow_core import estimator
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
# And again for tensorboard (comes as summary)
|
||||
try:
|
||||
from tensorflow_core import summary
|
||||
except ImportError as e:
|
||||
pass
|
||||
|
||||
# Also import module aliases
|
||||
try:
|
||||
from tensorflow_core import losses, metrics, initializers, optimizers
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)
|
||||
|
Loading…
Reference in New Issue
Block a user