diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 5e822f87e8c..e634a2c67cf 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -29,6 +29,7 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_stack +from tensorflow.tools.docs import doc_controls # Allow deprecation warnings to be silenced temporarily with a context manager. @@ -305,8 +306,23 @@ def deprecated(date, instructions, warn_once=True): """ _validate_deprecation_args(date, instructions) - def deprecated_wrapper(func): + def deprecated_wrapper(func_or_class): """Deprecation wrapper.""" + if isinstance(func_or_class, type): + # If a class is deprecated, you actually want to wrap the constructor. + cls = func_or_class + if cls.__new__ is object.__new__: + func = cls.__init__ + constructor_name = '__init__' + else: + func = cls.__new__ + constructor_name = '__new__' + + else: + cls = None + constructor_name = None + func = func_or_class + decorator_utils.validate_callable(func, 'deprecated') @functools.wraps(func) def new_func(*args, **kwargs): # pylint: disable=missing-docstring @@ -322,10 +338,25 @@ def deprecated(date, instructions, warn_once=True): 'in a future version' if date is None else ('after %s' % date), instructions) return func(*args, **kwargs) - return tf_decorator.make_decorator( + + doc_controls.set_deprecated(new_func) + new_func = tf_decorator.make_decorator( func, new_func, 'deprecated', _add_deprecated_function_notice_to_docstring(func.__doc__, date, instructions)) + + if cls is None: + return new_func + else: + # Insert the wrapped function as the constructor + setattr(cls, constructor_name, new_func) + + # And update the docstring of the class. + cls.__doc__ = _add_deprecated_function_notice_to_docstring( + cls.__doc__, date, instructions) + + return cls + return deprecated_wrapper diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index 20c0846cfb8..a8babf3b011 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -19,6 +19,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import enum + from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -95,6 +98,72 @@ class DeprecationTest(test.TestCase): _fn() self.assertEqual(1, mock_warning.call_count) + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_init_class(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated(date, instructions, warn_once=True) + class MyClass(): + """A test class.""" + + def __init__(self, a): + pass + + MyClass("") + self.assertEqual(1, mock_warning.call_count) + MyClass("") + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", MyClass.__doc__) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_new_class(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated(date, instructions, warn_once=True) + class MyStr(str): + + def __new__(cls, value): + return str.__new__(cls, value) + + MyStr("abc") + self.assertEqual(1, mock_warning.call_count) + MyStr("abc") + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", MyStr.__doc__) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_enum(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + @deprecation.deprecated(date, instructions, warn_once=True) + class MyEnum(enum.Enum): + a = 1 + b = 2 + + self.assertIs(MyEnum(1), MyEnum.a) + self.assertEqual(1, mock_warning.call_count) + self.assertIs(MyEnum(2), MyEnum.b) + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", MyEnum.__doc__) + + @test.mock.patch.object(logging, "warning", autospec=True) + def test_deprecated_namedtuple(self, mock_warning): + date = "2016-07-04" + instructions = "This is how you update..." + + mytuple = deprecation.deprecated( + date, instructions, warn_once=True)( + collections.namedtuple("my_tuple", ["field1", "field2"])) + + mytuple(1, 2) + self.assertEqual(1, mock_warning.call_count) + mytuple(3, 4) + self.assertEqual(1, mock_warning.call_count) + self.assertIn("IS DEPRECATED", mytuple.__doc__) + @test.mock.patch.object(logging, "warning", autospec=True) def test_silence(self, mock_warning): date = "2016-07-04" diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 1998e47a3ad..25cece21e62 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -133,18 +133,6 @@ py_library( visibility = ["//visibility:public"], ) -py_test( - name = "doc_controls_test", - size = "small", - srcs = ["doc_controls_test.py"], - python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":doc_controls", - "//tensorflow/python:platform_test", - ], -) - py_test( name = "generate2_test", size = "medium", diff --git a/tensorflow/tools/docs/doc_controls.py b/tensorflow/tools/docs/doc_controls.py index 27a1d2075e9..631ac89d426 100644 --- a/tensorflow/tools/docs/doc_controls.py +++ b/tensorflow/tools/docs/doc_controls.py @@ -18,6 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +_DEPRECATED = "_tf_docs_deprecated" + + +def set_deprecated(obj): + """Explicitly tag an object as deprecated for the doc generator.""" + setattr(obj, _DEPRECATED, None) + return obj + + _DO_NOT_DOC = "_tf_docs_do_not_document" @@ -241,82 +250,3 @@ def for_subclass_implementers(obj): do_not_doc_in_subclasses = for_subclass_implementers - - -def should_skip(obj): - """Returns true if docs generation should be skipped for this object. - - checks for the `do_not_generate_docs` or `do_not_doc_inheritable` decorators. - - Args: - obj: The object to document, or skip. - - Returns: - True if the object should be skipped - """ - # Unwrap fget if the object is a property - if isinstance(obj, property): - obj = obj.fget - - return hasattr(obj, _DO_NOT_DOC) or hasattr(obj, _DO_NOT_DOC_INHERITABLE) - - -def should_skip_class_attr(cls, name): - """Returns true if docs should be skipped for this class attribute. - - Args: - cls: The class the attribute belongs to. - name: The name of the attribute. - - Returns: - True if the attribute should be skipped. - """ - # Get the object with standard lookup, from the nearest - # defining parent. - try: - obj = getattr(cls, name) - except AttributeError: - # Avoid error caused by enum metaclasses in python3 - if name in ("name", "value"): - return True - raise - - # Unwrap fget if the object is a property - if isinstance(obj, property): - obj = obj.fget - - # Skip if the object is decorated with `do_not_generate_docs` or - # `do_not_doc_inheritable` - if should_skip(obj): - return True - - # Use __dict__ lookup to get the version defined in *this* class. - obj = cls.__dict__.get(name, None) - if isinstance(obj, property): - obj = obj.fget - if obj is not None: - # If not none, the object is defined in *this* class. - # Do not skip if decorated with `for_subclass_implementers`. - if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS): - return False - - # for each parent class - for parent in cls.__mro__[1:]: - obj = getattr(parent, name, None) - - if obj is None: - continue - - if isinstance(obj, property): - obj = obj.fget - - # Skip if the parent's definition is decorated with `do_not_doc_inheritable` - # or `for_subclass_implementers` - if hasattr(obj, _DO_NOT_DOC_INHERITABLE): - return True - - if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS): - return True - - # No blockng decorators --> don't skip - return False diff --git a/tensorflow/tools/docs/doc_controls_test.py b/tensorflow/tools/docs/doc_controls_test.py deleted file mode 100644 index d5eb4ffc000..00000000000 --- a/tensorflow/tools/docs/doc_controls_test.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for documentation control decorators.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.platform import googletest -from tensorflow.tools.docs import doc_controls - - -class DocControlsTest(googletest.TestCase): - - def test_do_not_generate_docs(self): - - @doc_controls.do_not_generate_docs - def dummy_function(): - pass - - self.assertTrue(doc_controls.should_skip(dummy_function)) - - def test_do_not_doc_on_method(self): - """The simple decorator is not aware of inheritance.""" - - class Parent(object): - - @doc_controls.do_not_generate_docs - def my_method(self): - pass - - class Child(Parent): - - def my_method(self): - pass - - class GrandChild(Child): - pass - - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertFalse(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertFalse( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_do_not_doc_inheritable(self): - - class Parent(object): - - @doc_controls.do_not_doc_inheritable - def my_method(self): - pass - - class Child(Parent): - - def my_method(self): - pass - - class GrandChild(Child): - pass - - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_do_not_doc_inheritable_property(self): - - class Parent(object): - - @property - @doc_controls.do_not_doc_inheritable - def my_method(self): - pass - - class Child(Parent): - - @property - def my_method(self): - pass - - class GrandChild(Child): - pass - - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_do_not_doc_inheritable_staticmethod(self): - - class GrandParent(object): - - def my_method(self): - pass - - class Parent(GrandParent): - - @staticmethod - @doc_controls.do_not_doc_inheritable - def my_method(): - pass - - class Child(Parent): - - @staticmethod - def my_method(): - pass - - class GrandChild(Child): - pass - - self.assertFalse(doc_controls.should_skip(GrandParent.my_method)) - self.assertTrue(doc_controls.should_skip(Parent.my_method)) - self.assertFalse(doc_controls.should_skip(Child.my_method)) - self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) - - self.assertFalse( - doc_controls.should_skip_class_attr(GrandParent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - - def test_for_subclass_implementers(self): - - class GrandParent(object): - - def my_method(self): - pass - - class Parent(GrandParent): - - @doc_controls.for_subclass_implementers - def my_method(self): - pass - - class Child(Parent): - pass - - class GrandChild(Child): - - def my_method(self): - pass - - class Grand2Child(Child): - pass - - self.assertFalse( - doc_controls.should_skip_class_attr(GrandParent, 'my_method')) - self.assertFalse(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(Grand2Child, 'my_method')) - - def test_for_subclass_implementers_short_circuit(self): - - class GrandParent(object): - - @doc_controls.for_subclass_implementers - def my_method(self): - pass - - class Parent(GrandParent): - - def my_method(self): - pass - - class Child(Parent): - - @doc_controls.do_not_doc_inheritable - def my_method(self): - pass - - class GrandChild(Child): - - @doc_controls.for_subclass_implementers - def my_method(self): - pass - - class Grand2Child(Child): - pass - - self.assertFalse( - doc_controls.should_skip_class_attr(GrandParent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) - self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) - self.assertFalse( - doc_controls.should_skip_class_attr(GrandChild, 'my_method')) - self.assertTrue( - doc_controls.should_skip_class_attr(Grand2Child, 'my_method')) - - -if __name__ == '__main__': - googletest.main()