diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 5e822f87e8c..e634a2c67cf 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -29,6 +29,7 @@ from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util import tf_stack
+from tensorflow.tools.docs import doc_controls
 
 
 # Allow deprecation warnings to be silenced temporarily with a context manager.
@@ -305,8 +306,23 @@ def deprecated(date, instructions, warn_once=True):
   """
   _validate_deprecation_args(date, instructions)
 
-  def deprecated_wrapper(func):
+  def deprecated_wrapper(func_or_class):
     """Deprecation wrapper."""
+    if isinstance(func_or_class, type):
+      # If a class is deprecated, you actually want to wrap the constructor.
+      cls = func_or_class
+      if cls.__new__ is object.__new__:
+        func = cls.__init__
+        constructor_name = '__init__'
+      else:
+        func = cls.__new__
+        constructor_name = '__new__'
+
+    else:
+      cls = None
+      constructor_name = None
+      func = func_or_class
+
     decorator_utils.validate_callable(func, 'deprecated')
     @functools.wraps(func)
     def new_func(*args, **kwargs):  # pylint: disable=missing-docstring
@@ -322,10 +338,25 @@ def deprecated(date, instructions, warn_once=True):
               'in a future version' if date is None else ('after %s' % date),
               instructions)
       return func(*args, **kwargs)
-    return tf_decorator.make_decorator(
+
+    doc_controls.set_deprecated(new_func)
+    new_func = tf_decorator.make_decorator(
         func, new_func, 'deprecated',
         _add_deprecated_function_notice_to_docstring(func.__doc__, date,
                                                      instructions))
+
+    if cls is None:
+      return new_func
+    else:
+      # Insert the wrapped function as the constructor
+      setattr(cls, constructor_name, new_func)
+
+      # And update the docstring of the class.
+      cls.__doc__ = _add_deprecated_function_notice_to_docstring(
+          cls.__doc__, date, instructions)
+
+      return cls
+
   return deprecated_wrapper
 
 
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index 20c0846cfb8..a8babf3b011 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -19,6 +19,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import enum
+
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -95,6 +98,72 @@ class DeprecationTest(test.TestCase):
     _fn()
     self.assertEqual(1, mock_warning.call_count)
 
+  @test.mock.patch.object(logging, "warning", autospec=True)
+  def test_deprecated_init_class(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated(date, instructions, warn_once=True)
+    class MyClass():
+      """A test class."""
+
+      def __init__(self, a):
+        pass
+
+    MyClass("")
+    self.assertEqual(1, mock_warning.call_count)
+    MyClass("")
+    self.assertEqual(1, mock_warning.call_count)
+    self.assertIn("IS DEPRECATED", MyClass.__doc__)
+
+  @test.mock.patch.object(logging, "warning", autospec=True)
+  def test_deprecated_new_class(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated(date, instructions, warn_once=True)
+    class MyStr(str):
+
+      def __new__(cls, value):
+        return str.__new__(cls, value)
+
+    MyStr("abc")
+    self.assertEqual(1, mock_warning.call_count)
+    MyStr("abc")
+    self.assertEqual(1, mock_warning.call_count)
+    self.assertIn("IS DEPRECATED", MyStr.__doc__)
+
+  @test.mock.patch.object(logging, "warning", autospec=True)
+  def test_deprecated_enum(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    @deprecation.deprecated(date, instructions, warn_once=True)
+    class MyEnum(enum.Enum):
+      a = 1
+      b = 2
+
+    self.assertIs(MyEnum(1), MyEnum.a)
+    self.assertEqual(1, mock_warning.call_count)
+    self.assertIs(MyEnum(2), MyEnum.b)
+    self.assertEqual(1, mock_warning.call_count)
+    self.assertIn("IS DEPRECATED", MyEnum.__doc__)
+
+  @test.mock.patch.object(logging, "warning", autospec=True)
+  def test_deprecated_namedtuple(self, mock_warning):
+    date = "2016-07-04"
+    instructions = "This is how you update..."
+
+    mytuple = deprecation.deprecated(
+        date, instructions, warn_once=True)(
+            collections.namedtuple("my_tuple", ["field1", "field2"]))
+
+    mytuple(1, 2)
+    self.assertEqual(1, mock_warning.call_count)
+    mytuple(3, 4)
+    self.assertEqual(1, mock_warning.call_count)
+    self.assertIn("IS DEPRECATED", mytuple.__doc__)
+
   @test.mock.patch.object(logging, "warning", autospec=True)
   def test_silence(self, mock_warning):
     date = "2016-07-04"
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 1998e47a3ad..25cece21e62 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -133,18 +133,6 @@ py_library(
     visibility = ["//visibility:public"],
 )
 
-py_test(
-    name = "doc_controls_test",
-    size = "small",
-    srcs = ["doc_controls_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY2AND3",
-    deps = [
-        ":doc_controls",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
 py_test(
     name = "generate2_test",
     size = "medium",
diff --git a/tensorflow/tools/docs/doc_controls.py b/tensorflow/tools/docs/doc_controls.py
index 27a1d2075e9..631ac89d426 100644
--- a/tensorflow/tools/docs/doc_controls.py
+++ b/tensorflow/tools/docs/doc_controls.py
@@ -18,6 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+_DEPRECATED = "_tf_docs_deprecated"
+
+
+def set_deprecated(obj):
+  """Explicitly tag an object as deprecated for the doc generator."""
+  setattr(obj, _DEPRECATED, None)
+  return obj
+
+
 _DO_NOT_DOC = "_tf_docs_do_not_document"
 
 
@@ -241,82 +250,3 @@ def for_subclass_implementers(obj):
 
 
 do_not_doc_in_subclasses = for_subclass_implementers
-
-
-def should_skip(obj):
-  """Returns true if docs generation should be skipped for this object.
-
-  checks for the `do_not_generate_docs` or `do_not_doc_inheritable` decorators.
-
-  Args:
-    obj: The object to document, or skip.
-
-  Returns:
-    True if the object should be skipped
-  """
-  # Unwrap fget if the object is a property
-  if isinstance(obj, property):
-    obj = obj.fget
-
-  return hasattr(obj, _DO_NOT_DOC) or hasattr(obj, _DO_NOT_DOC_INHERITABLE)
-
-
-def should_skip_class_attr(cls, name):
-  """Returns true if docs should be skipped for this class attribute.
-
-  Args:
-    cls: The class the attribute belongs to.
-    name: The name of the attribute.
-
-  Returns:
-    True if the attribute should be skipped.
-  """
-  # Get the object with standard lookup, from the nearest
-  # defining parent.
-  try:
-    obj = getattr(cls, name)
-  except AttributeError:
-    # Avoid error caused by enum metaclasses in python3
-    if name in ("name", "value"):
-      return True
-    raise
-
-  # Unwrap fget if the object is a property
-  if isinstance(obj, property):
-    obj = obj.fget
-
-  # Skip if the object is decorated with `do_not_generate_docs` or
-  # `do_not_doc_inheritable`
-  if should_skip(obj):
-    return True
-
-  # Use __dict__ lookup to get the version defined in *this* class.
-  obj = cls.__dict__.get(name, None)
-  if isinstance(obj, property):
-    obj = obj.fget
-  if obj is not None:
-    # If not none, the object is defined in *this* class.
-    # Do not skip if decorated with `for_subclass_implementers`.
-    if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS):
-      return False
-
-  # for each parent class
-  for parent in cls.__mro__[1:]:
-    obj = getattr(parent, name, None)
-
-    if obj is None:
-      continue
-
-    if isinstance(obj, property):
-      obj = obj.fget
-
-    # Skip if the parent's definition is decorated with `do_not_doc_inheritable`
-    # or `for_subclass_implementers`
-    if hasattr(obj, _DO_NOT_DOC_INHERITABLE):
-      return True
-
-    if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS):
-      return True
-
-  # No blockng decorators --> don't skip
-  return False
diff --git a/tensorflow/tools/docs/doc_controls_test.py b/tensorflow/tools/docs/doc_controls_test.py
deleted file mode 100644
index d5eb4ffc000..00000000000
--- a/tensorflow/tools/docs/doc_controls_test.py
+++ /dev/null
@@ -1,220 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for documentation control decorators."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.platform import googletest
-from tensorflow.tools.docs import doc_controls
-
-
-class DocControlsTest(googletest.TestCase):
-
-  def test_do_not_generate_docs(self):
-
-    @doc_controls.do_not_generate_docs
-    def dummy_function():
-      pass
-
-    self.assertTrue(doc_controls.should_skip(dummy_function))
-
-  def test_do_not_doc_on_method(self):
-    """The simple decorator is not aware of inheritance."""
-
-    class Parent(object):
-
-      @doc_controls.do_not_generate_docs
-      def my_method(self):
-        pass
-
-    class Child(Parent):
-
-      def my_method(self):
-        pass
-
-    class GrandChild(Child):
-      pass
-
-    self.assertTrue(doc_controls.should_skip(Parent.my_method))
-    self.assertFalse(doc_controls.should_skip(Child.my_method))
-    self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
-
-    self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
-    self.assertFalse(doc_controls.should_skip_class_attr(Child, 'my_method'))
-    self.assertFalse(
-        doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
-
-  def test_do_not_doc_inheritable(self):
-
-    class Parent(object):
-
-      @doc_controls.do_not_doc_inheritable
-      def my_method(self):
-        pass
-
-    class Child(Parent):
-
-      def my_method(self):
-        pass
-
-    class GrandChild(Child):
-      pass
-
-    self.assertTrue(doc_controls.should_skip(Parent.my_method))
-    self.assertFalse(doc_controls.should_skip(Child.my_method))
-    self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
-
-    self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
-    self.assertTrue(
-        doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
-
-  def test_do_not_doc_inheritable_property(self):
-
-    class Parent(object):
-
-      @property
-      @doc_controls.do_not_doc_inheritable
-      def my_method(self):
-        pass
-
-    class Child(Parent):
-
-      @property
-      def my_method(self):
-        pass
-
-    class GrandChild(Child):
-      pass
-
-    self.assertTrue(doc_controls.should_skip(Parent.my_method))
-    self.assertFalse(doc_controls.should_skip(Child.my_method))
-    self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
-
-    self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
-    self.assertTrue(
-        doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
-
-  def test_do_not_doc_inheritable_staticmethod(self):
-
-    class GrandParent(object):
-
-      def my_method(self):
-        pass
-
-    class Parent(GrandParent):
-
-      @staticmethod
-      @doc_controls.do_not_doc_inheritable
-      def my_method():
-        pass
-
-    class Child(Parent):
-
-      @staticmethod
-      def my_method():
-        pass
-
-    class GrandChild(Child):
-      pass
-
-    self.assertFalse(doc_controls.should_skip(GrandParent.my_method))
-    self.assertTrue(doc_controls.should_skip(Parent.my_method))
-    self.assertFalse(doc_controls.should_skip(Child.my_method))
-    self.assertFalse(doc_controls.should_skip(GrandChild.my_method))
-
-    self.assertFalse(
-        doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
-    self.assertTrue(
-        doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
-
-  def test_for_subclass_implementers(self):
-
-    class GrandParent(object):
-
-      def my_method(self):
-        pass
-
-    class Parent(GrandParent):
-
-      @doc_controls.for_subclass_implementers
-      def my_method(self):
-        pass
-
-    class Child(Parent):
-      pass
-
-    class GrandChild(Child):
-
-      def my_method(self):
-        pass
-
-    class Grand2Child(Child):
-      pass
-
-    self.assertFalse(
-        doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
-    self.assertFalse(doc_controls.should_skip_class_attr(Parent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
-    self.assertTrue(
-        doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
-    self.assertTrue(
-        doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
-
-  def test_for_subclass_implementers_short_circuit(self):
-
-    class GrandParent(object):
-
-      @doc_controls.for_subclass_implementers
-      def my_method(self):
-        pass
-
-    class Parent(GrandParent):
-
-      def my_method(self):
-        pass
-
-    class Child(Parent):
-
-      @doc_controls.do_not_doc_inheritable
-      def my_method(self):
-        pass
-
-    class GrandChild(Child):
-
-      @doc_controls.for_subclass_implementers
-      def my_method(self):
-        pass
-
-    class Grand2Child(Child):
-      pass
-
-    self.assertFalse(
-        doc_controls.should_skip_class_attr(GrandParent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method'))
-    self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method'))
-    self.assertFalse(
-        doc_controls.should_skip_class_attr(GrandChild, 'my_method'))
-    self.assertTrue(
-        doc_controls.should_skip_class_attr(Grand2Child, 'my_method'))
-
-
-if __name__ == '__main__':
-  googletest.main()