Support classes in deprecation.

For docs tools: Use a specific tag for deprecation, and just use the decorator detector as a fallback.

PiperOrigin-RevId: 337493352
Change-Id: Iaeace587c0db7d843771719610649664db3e262b
This commit is contained in:
Mark Daoust 2020-10-16 06:17:37 -07:00 committed by TensorFlower Gardener
parent dc843fad65
commit 383fabd5a8
5 changed files with 111 additions and 313 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.tools.docs import doc_controls
# Allow deprecation warnings to be silenced temporarily with a context manager.
@ -305,8 +306,23 @@ def deprecated(date, instructions, warn_once=True):
"""
_validate_deprecation_args(date, instructions)
def deprecated_wrapper(func):
def deprecated_wrapper(func_or_class):
"""Deprecation wrapper."""
if isinstance(func_or_class, type):
# If a class is deprecated, you actually want to wrap the constructor.
cls = func_or_class
if cls.__new__ is object.__new__:
func = cls.__init__
constructor_name = '__init__'
else:
func = cls.__new__
constructor_name = '__new__'
else:
cls = None
constructor_name = None
func = func_or_class
decorator_utils.validate_callable(func, 'deprecated')
@functools.wraps(func)
def new_func(*args, **kwargs): # pylint: disable=missing-docstring
@ -322,10 +338,25 @@ def deprecated(date, instructions, warn_once=True):
'in a future version' if date is None else ('after %s' % date),
instructions)
return func(*args, **kwargs)
return tf_decorator.make_decorator(
doc_controls.set_deprecated(new_func)
new_func = tf_decorator.make_decorator(
func, new_func, 'deprecated',
_add_deprecated_function_notice_to_docstring(func.__doc__, date,
instructions))
if cls is None:
return new_func
else:
# Insert the wrapped function as the constructor
setattr(cls, constructor_name, new_func)
# And update the docstring of the class.
cls.__doc__ = _add_deprecated_function_notice_to_docstring(
cls.__doc__, date, instructions)
return cls
return deprecated_wrapper

View File

@ -19,6 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import enum
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@ -95,6 +98,72 @@ class DeprecationTest(test.TestCase):
_fn()
self.assertEqual(1, mock_warning.call_count)
@test.mock.patch.object(logging, "warning", autospec=True)
def test_deprecated_init_class(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions, warn_once=True)
class MyClass():
"""A test class."""
def __init__(self, a):
pass
MyClass("")
self.assertEqual(1, mock_warning.call_count)
MyClass("")
self.assertEqual(1, mock_warning.call_count)
self.assertIn("IS DEPRECATED", MyClass.__doc__)
@test.mock.patch.object(logging, "warning", autospec=True)
def test_deprecated_new_class(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions, warn_once=True)
class MyStr(str):
def __new__(cls, value):
return str.__new__(cls, value)
MyStr("abc")
self.assertEqual(1, mock_warning.call_count)
MyStr("abc")
self.assertEqual(1, mock_warning.call_count)
self.assertIn("IS DEPRECATED", MyStr.__doc__)
@test.mock.patch.object(logging, "warning", autospec=True)
def test_deprecated_enum(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
@deprecation.deprecated(date, instructions, warn_once=True)
class MyEnum(enum.Enum):
a = 1
b = 2
self.assertIs(MyEnum(1), MyEnum.a)
self.assertEqual(1, mock_warning.call_count)
self.assertIs(MyEnum(2), MyEnum.b)
self.assertEqual(1, mock_warning.call_count)
self.assertIn("IS DEPRECATED", MyEnum.__doc__)
@test.mock.patch.object(logging, "warning", autospec=True)
def test_deprecated_namedtuple(self, mock_warning):
date = "2016-07-04"
instructions = "This is how you update..."
mytuple = deprecation.deprecated(
date, instructions, warn_once=True)(
collections.namedtuple("my_tuple", ["field1", "field2"]))
mytuple(1, 2)
self.assertEqual(1, mock_warning.call_count)
mytuple(3, 4)
self.assertEqual(1, mock_warning.call_count)
self.assertIn("IS DEPRECATED", mytuple.__doc__)
@test.mock.patch.object(logging, "warning", autospec=True)
def test_silence(self, mock_warning):
date = "2016-07-04"

View File

@ -133,18 +133,6 @@ py_library(
visibility = ["//visibility:public"],
)
py_test(
name = "doc_controls_test",
size = "small",
srcs = ["doc_controls_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":doc_controls",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "generate2_test",
size = "medium",

View File

@ -18,6 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
_DEPRECATED = "_tf_docs_deprecated"
def set_deprecated(obj):
"""Explicitly tag an object as deprecated for the doc generator."""
setattr(obj, _DEPRECATED, None)
return obj
_DO_NOT_DOC = "_tf_docs_do_not_document"
@ -241,82 +250,3 @@ def for_subclass_implementers(obj):
do_not_doc_in_subclasses = for_subclass_implementers
def should_skip(obj):
"""Returns true if docs generation should be skipped for this object.
checks for the `do_not_generate_docs` or `do_not_doc_inheritable` decorators.
Args:
obj: The object to document, or skip.
Returns:
True if the object should be skipped
"""
# Unwrap fget if the object is a property
if isinstance(obj, property):
obj = obj.fget
return hasattr(obj, _DO_NOT_DOC) or hasattr(obj, _DO_NOT_DOC_INHERITABLE)
def should_skip_class_attr(cls, name):
"""Returns true if docs should be skipped for this class attribute.
Args:
cls: The class the attribute belongs to.
name: The name of the attribute.
Returns:
True if the attribute should be skipped.
"""
# Get the object with standard lookup, from the nearest
# defining parent.
try:
obj = getattr(cls, name)
except AttributeError:
# Avoid error caused by enum metaclasses in python3
if name in ("name", "value"):
return True
raise
# Unwrap fget if the object is a property
if isinstance(obj, property):
obj = obj.fget
# Skip if the object is decorated with `do_not_generate_docs` or
# `do_not_doc_inheritable`
if should_skip(obj):
return True
# Use __dict__ lookup to get the version defined in *this* class.
obj = cls.__dict__.get(name, None)
if isinstance(obj, property):
obj = obj.fget
if obj is not None:
# If not none, the object is defined in *this* class.
# Do not skip if decorated with `for_subclass_implementers`.
if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS):
return False
# for each parent class
for parent in cls.__mro__[1:]:
obj = getattr(parent, name, None)
if obj is None:
continue
if isinstance(obj, property):
obj = obj.fget
# Skip if the parent's definition is decorated with `do_not_doc_inheritable`
# or `for_subclass_implementers`
if hasattr(obj, _DO_NOT_DOC_INHERITABLE):
return True
if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS):
return True
# No blockng decorators --> don't skip
return False

View File

@ -1,220 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for documentation control decorators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import googletest
from tensorflow.tools.docs import doc_controls
class DocControlsTest(googletest.TestCase):
def test_do_not_generate_docs(self):
@doc_controls.do_not_generate_docs
def dummy_function():
pass
self.assertTrue(doc_controls.should_skip(dummy_function))
def test_do_not_doc_on_method(self):
"""The simple decorator is not aware of inheritance."""
class Parent(object):
@doc_controls.do_not_generate_docs
def my_method(self):
pass
class Child(Parent):
def my_method(self):
pass
class GrandChild(Child):
pass
self.assertTrue(doc_controls.should_skip(Parent.my_method))
self.assertFalse(doc_controls.should_skip(Child.my_method))
self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
self.assertFalse(doc_controls.should_skip_class_attr(Child, 'my_method'))
self.assertFalse(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
def test_do_not_doc_inheritable(self):
class Parent(object):
@doc_controls.do_not_doc_inheritable
def my_method(self):
pass
class Child(Parent):
def my_method(self):
pass
class GrandChild(Child):
pass
self.assertTrue(doc_controls.should_skip(Parent.my_method))
self.assertFalse(doc_controls.should_skip(Child.my_method))
self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
self.assertTrue(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
def test_do_not_doc_inheritable_property(self):
class Parent(object):
@property
@doc_controls.do_not_doc_inheritable
def my_method(self):
pass
class Child(Parent):
@property
def my_method(self):
pass
class GrandChild(Child):
pass
self.assertTrue(doc_controls.should_skip(Parent.my_method))
self.assertFalse(doc_controls.should_skip(Child.my_method))
self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
self.assertTrue(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
def test_do_not_doc_inheritable_staticmethod(self):
class GrandParent(object):
def my_method(self):
pass
class Parent(GrandParent):
@staticmethod
@doc_controls.do_not_doc_inheritable
def my_method():
pass
class Child(Parent):
@staticmethod
def my_method():
pass
class GrandChild(Child):
pass
self.assertFalse(doc_controls.should_skip(GrandParent.my_method))
self.assertTrue(doc_controls.should_skip(Parent.my_method))
self.assertFalse(doc_controls.should_skip(Child.my_method))
self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
self.assertFalse(
doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
self.assertTrue(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
def test_for_subclass_implementers(self):
class GrandParent(object):
def my_method(self):
pass
class Parent(GrandParent):
@doc_controls.for_subclass_implementers
def my_method(self):
pass
class Child(Parent):
pass
class GrandChild(Child):
def my_method(self):
pass
class Grand2Child(Child):
pass
self.assertFalse(
doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
self.assertFalse(doc_controls.should_skip_class_attr(Parent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
self.assertTrue(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
self.assertTrue(
doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
def test_for_subclass_implementers_short_circuit(self):
class GrandParent(object):
@doc_controls.for_subclass_implementers
def my_method(self):
pass
class Parent(GrandParent):
def my_method(self):
pass
class Child(Parent):
@doc_controls.do_not_doc_inheritable
def my_method(self):
pass
class GrandChild(Child):
@doc_controls.for_subclass_implementers
def my_method(self):
pass
class Grand2Child(Child):
pass
self.assertFalse(
doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
self.assertFalse(
doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
self.assertTrue(
doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
if __name__ == '__main__':
googletest.main()