Add deprecation aliases for classes as well. Also, while at it, fixed lint warning about DeprecatedNamesAlreadySet error name.

PiperOrigin-RevId: 242218510
This commit is contained in:
Anna R 2019-04-05 16:48:20 -07:00 committed by TensorFlower Gardener
parent 124ef01100
commit 39c48fcd78
3 changed files with 82 additions and 34 deletions

View File

@ -85,6 +85,14 @@ def format_import(source_module_name, source_name, dest_name):
return 'import %s as %s' % (source_name, dest_name)
def is_supported_type_for_deprecation(symbol):
# Exclude Exception subclasses since users should be able to
# "except" the same type of exception that was "raised" (i.e. we
# shouldn't wrap it with deprecation alias).
return tf_inspect.isfunction(symbol) or (
tf_inspect.isclass(symbol) and not issubclass(symbol, Exception))
class _ModuleInitCodeBuilder(object):
"""Builds a map from module name to imports included in that module."""
@ -330,10 +338,9 @@ def add_imports_for_symbol(
dest_module = _join_modules(output_module_prefix, dest_module)
# Add deprecated alias if only some of the endpoints are deprecated
# and symbol is not under compat.v*.
# TODO(annarev): handle deprecated class endpoints as well.
if (export not in exports_v2 and canonical_endpoint and
not dest_module.startswith(_COMPAT_MODULE_PREFIX) and
tf_inspect.isfunction(symbol)):
is_supported_type_for_deprecation(symbol)):
module_code_builder.add_deprecated_endpoint(
id(symbol), dest_module, source_module_name, source_name, dest_name,
canonical_endpoint)

View File

@ -38,7 +38,7 @@ _PRINT_DEPRECATION_WARNINGS = True
_PRINTED_WARNING = {}
class DeprecatedNamesAlreadySet(Exception):
class DeprecatedNamesAlreadySetError(Exception):
"""Raised when setting deprecated names multiple times for the same symbol."""
pass
@ -97,6 +97,11 @@ def _validate_deprecation_args(date, instructions):
raise ValueError('Don\'t deprecate things without conversion instructions!')
def istuplesubclass(f):
"""Returns True if the argument inherits from tuple."""
return tuple in tf_inspect.getmro(f)
def _call_location(outer=False):
"""Returns call location given level up from current call."""
stack = tf_stack.extract_stack_file_and_line(max_length=4)
@ -186,6 +191,41 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
use and has a modified docstring.
"""
if tf_inspect.isclass(func_or_class):
# In most cases, we want to override __init__ to print deprecation warning.
# However, if a class inherits from immutable type, then we need to
# override __new__ instead. Unfortunately, there doesn't seem to be an
# easy way to check if a type is immutable.
if istuplesubclass(func_or_class):
# Make a new class with __new__ wrapped in a warning.
class _NewClass(func_or_class): # pylint: disable=missing-docstring
__doc__ = decorator_utils.add_notice_to_docstring(
func_or_class.__doc__, 'Please use %s instead.' % name,
'DEPRECATED CLASS',
'(deprecated)', ['THIS CLASS IS DEPRECATED. '
'It will be removed in a future version. '])
__name__ = func_or_class.__name__
__module__ = _call_location(outer=True)
@_wrap_decorator(func_or_class.__new__)
def __new__(cls, *args, **kwargs):
_NewClass.__new__.__doc__ = func_or_class.__new__.__doc__
if _PRINT_DEPRECATION_WARNINGS:
# We're making the alias as we speak. The original may have other
# aliases, so we cannot use it to check for whether it's already
# been warned about.
if _NewClass.__new__ not in _PRINTED_WARNING:
if warn_once:
_PRINTED_WARNING[_NewClass.__new__] = True
logging.warning(
'From %s: The name %s is deprecated. '
'Please use %s instead.\n',
_call_location(), deprecated_name, name)
return super(_NewClass, cls).__new__(cls, *args, **kwargs)
return _NewClass
else:
# Make a new class with __init__ wrapped in a warning.
class _NewClass(func_or_class): # pylint: disable=missing-docstring
@ -208,15 +248,16 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
if _PRINT_DEPRECATION_WARNINGS:
# We're making the alias as we speak. The original may have other
# aliases, so we cannot use it to check for whether it's already been
# warned about.
# aliases, so we cannot use it to check for whether it's already
# been warned about.
if _NewClass.__init__ not in _PRINTED_WARNING:
if warn_once:
_PRINTED_WARNING[_NewClass.__init__] = True
logging.warning(
'From %s: The name %s is deprecated. Please use %s instead.\n',
'From %s: The name %s is deprecated. '
'Please use %s instead.\n',
_call_location(), deprecated_name, name)
super(_NewClass, self).__init__(*args, **kwargs)
super(_NewClass, self).__init__(*args, **kwargs) # pylint: disable=bad-super-call
return _NewClass
else:
@ -261,7 +302,7 @@ def deprecated_endpoints(*args):
def deprecated_wrapper(func):
# pylint: disable=protected-access
if '_tf_deprecated_api_names' in func.__dict__:
raise DeprecatedNamesAlreadySet(
raise DeprecatedNamesAlreadySetError(
'Cannot set deprecated names for %s to %s. '
'Deprecated names are already set to %s.' % (
func.__name__, str(args), str(func._tf_deprecated_api_names)))

View File

@ -976,7 +976,7 @@ class DeprecatedEndpointsTest(test.TestCase):
self.assertEqual(("foo1", "foo2"), foo._tf_deprecated_api_names)
def testCannotSetDeprecatedEndpointsTwice(self):
with self.assertRaises(deprecation.DeprecatedNamesAlreadySet):
with self.assertRaises(deprecation.DeprecatedNamesAlreadySetError):
@deprecation.deprecated_endpoints("foo1")
@deprecation.deprecated_endpoints("foo2")
def foo(): # pylint: disable=unused-variable