Updates the Keras weight/variable docstrings to explicitly call out that they won't track the weights/variables of nested tf.Modules, because this is a common point of user confusion and we don't currently have a viable approach to change this.

PiperOrigin-RevId: 330850040
Change-Id: I12cbfc39c6fae11545256083c691bf913cb11f07
This commit is contained in:
Tomer Kaftan 2020-09-09 19:37:50 -07:00 committed by TensorFlower Gardener
parent 4698aad6fe
commit 9d36befed5
2 changed files with 15 additions and 0 deletions

View File

@ -1320,6 +1320,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Trainable weights are updated via gradient descent during training.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of trainable variables.
"""
@ -1336,6 +1339,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Non-trainable weights are *not* updated during training. They are expected
to be updated manually in `call()`.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of non-trainable variables.
"""
@ -1354,6 +1360,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def weights(self):
"""Returns the list of all layer variables/weights.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of variables.
"""
@ -2251,6 +2260,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Alias of `self.weights`.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of variables.
"""

View File

@ -2298,6 +2298,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
def weights(self):
"""Returns the list of all layer variables/weights.
Note: This will not track the weights of nested `tf.Modules` that are not
themselves Keras layers.
Returns:
A list of variables.
"""