From 6a115c01e2c55ac7a69dc211ed2bc433e3584444 Mon Sep 17 00:00:00 2001 From: Yash Katariya <yashkatariya@google.com> Date: Thu, 5 Nov 2020 15:49:30 -0800 Subject: [PATCH] Adding support for overriding `do_not_doc_in_subclasses` decorator. This is useful to document `call` methods on `tf.keras.Model` class and all its child classes. Sequential won't document that method because of `do_not_doc_inheritable` decorator added in this CL. PiperOrigin-RevId: 340943392 Change-Id: I80bcbce117f14eb098236d6cdf0fbeab0ad7e720 --- tensorflow/python/keras/engine/functional.py | 2 + tensorflow/python/keras/engine/training.py | 6 ++ tensorflow/tools/docs/doc_controls.py | 65 ++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 84db63b273e..942f8035530 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest +from tensorflow.tools.docs import doc_controls # pylint: disable=g-classes-have-attributes @@ -403,6 +404,7 @@ class Functional(training_lib.Model): return nest.map_structure(lambda t: getattr(t, '_keras_mask', None), output_tensors) + @doc_controls.do_not_doc_inheritable def call(self, inputs, training=None, mask=None): """Calls the model on new inputs. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 07ea021e69e..a644b9511d2 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -422,6 +422,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): 'the correct dtype).') super(Model, self).build(input_shape) + @doc_controls.doc_in_current_and_subclasses def call(self, inputs, training=None, mask=None): """Calls the model on new inputs. @@ -429,6 +430,11 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): all ops in the graph to the new inputs (e.g. build a new computational graph from the provided inputs). + Note: This method should not be called directly. It is only meant to be + overridden when subclassing `tf.keras.Model`. + To call a model on an input, always use the `__call__` method, + i.e. `model(inputs)`, which relies on the underlying `call` method. + Arguments: inputs: A tensor or list of tensors. training: Boolean or boolean scalar tensor, indicating whether to run diff --git a/tensorflow/tools/docs/doc_controls.py b/tensorflow/tools/docs/doc_controls.py index 4899e4c7b1a..e6075249b2d 100644 --- a/tensorflow/tools/docs/doc_controls.py +++ b/tensorflow/tools/docs/doc_controls.py @@ -286,3 +286,68 @@ def doc_private(obj: T) -> T: setattr(obj, _DOC_PRIVATE, None) return obj + + +_DOC_IN_CURRENT_AND_SUBCLASSES = "_tf_docs_doc_in_current_and_subclasses" + + +def doc_in_current_and_subclasses(obj: T) -> T: + """Overrides `do_not_doc_in_subclasses` decorator. + + If this decorator is set on a child class's method whose parent's method + contains `do_not_doc_in_subclasses`, then that will be overriden and the + child method will get documented. All classes inherting from the child will + also document that method. + + For example: + + ``` + class Parent: + @do_not_doc_in_subclasses + def method1(self): + pass + def method2(self): + pass + + class Child1(Parent): + @doc_in_current_and_subclasses + def method1(self): + pass + def method2(self): + pass + + class Child2(Parent): + def method1(self): + pass + def method2(self): + pass + + class Child11(Child1): + pass + ``` + + This will produce the following docs: + + ``` + /Parent.md + # method1 + # method2 + /Child1.md + # method1 + # method2 + /Child2.md + # method2 + /Child11.md + # method1 + # method2 + ``` + + Args: + obj: The class-attribute to hide from the generated docs. + + Returns: + obj + """ + + setattr(obj, _DOC_IN_CURRENT_AND_SUBCLASSES, None) + return obj