Add examples to layer losses, metrics, add_metric, model reset_metrics APIs.

PiperOrigin-RevId: 306448789
Change-Id: I8a21d9802ab7c0bf371ab219768609f33cf1146a
Pavithra Vijay 2020-04-14 09:01:32 -07:00 committed by TensorFlower Gardener
parent 86592330b1
commit 22573b59c0
2 changed files with 111 additions and 9 deletions

@@ -123,8 +123,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
precision is used with a `tf.keras.mixed_precision.experimental.Policy`,
this is instead just the dtype of the layer's weights, as the computations
are done in a different dtype.
losses: List of losses added to this layer (via `self.add_loss()`).
metrics: List of metrics added to this layer (via `self.add_metric()`).
trainable_weights: List of variables to be included in backprop.
non_trainable_weights: List of variables that should not be
included in backprop.
@@ -1140,12 +1138,42 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@property
def losses(self):
"""Losses which are associated with this `Layer`.
"""List of losses added using the `add_loss()` API.
Variable regularization tensors are created when this property is accessed,
so it is eager safe: accessing `losses` under a `tf.GradientTape` will
propagate gradients back to the corresponding variables.
Examples:
>>> class MyLayer(tf.keras.layers.Layer):
...   def call(self, inputs):
...     self.add_loss(tf.abs(tf.reduce_mean(inputs)))
...     return inputs
>>> l = MyLayer()
>>> l(np.ones((10, 1)))
>>> l.losses
[1.0]
>>> inputs = tf.keras.Input(shape=(10,))
>>> x = tf.keras.layers.Dense(10)(inputs)
>>> outputs = tf.keras.layers.Dense(1)(x)
>>> model = tf.keras.Model(inputs, outputs)
>>> # Activity regularization.
>>> model.add_loss(tf.abs(tf.reduce_mean(x)))
>>> model.losses
[<tf.Tensor 'Abs:0' shape=() dtype=float32>]
>>> inputs = tf.keras.Input(shape=(10,))
>>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
>>> x = d(inputs)
>>> outputs = tf.keras.layers.Dense(1)(x)
>>> model = tf.keras.Model(inputs, outputs)
>>> # Weight regularization.
>>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
>>> model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
Returns:
A list of tensors.
"""
@@ -1215,11 +1243,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
```python
inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(10)(inputs)
d = tf.keras.layers.Dense(10)
x = d(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
# Weight regularization.
model.add_loss(lambda: tf.reduce_mean(x.kernel))
model.add_loss(lambda: tf.reduce_mean(d.kernel))
```
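The zero-argument callable form shown above defers evaluation: the penalty is recomputed from the current value of `d.kernel` every time the losses are collected. A small illustrative check, assuming TF 2.x eager execution:

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(10,))
d = tf.keras.layers.Dense(10, kernel_initializer='ones')
x = d(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# Weight regularization via a callable: evaluated lazily, not at add time.
model.add_loss(lambda: tf.reduce_mean(d.kernel))

print(model.losses)             # single loss, 1.0 with the all-ones kernel
d.kernel.assign(d.kernel * 2.)
print(model.losses)             # recomputed from the updated kernel (2.0)
```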
The `get_losses_for` method allows you to retrieve the losses relevant to a
@@ -1302,7 +1331,21 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
@property
def metrics(self):
"""List of `tf.keras.metrics.Metric` instances tracked by the layer."""
"""List of metrics added using the `add_metric()` API.
Example:
>>> input = tf.keras.layers.Input(shape=(3,))
>>> d = tf.keras.layers.Dense(2)
>>> output = d(input)
>>> d.add_metric(tf.reduce_max(output), name='max')
>>> d.add_metric(tf.reduce_min(output), name='min')
>>> [m.name for m in d.metrics]
['max', 'min']
Returns:
A list of `Metric` objects.
"""
collected_metrics = []
all_layers = self._gather_unique_layers()
for layer in all_layers:
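As the gathering loop above suggests, `metrics` also collects metrics added by nested sublayers. A hedged sketch, assuming TF 2.x; the layer, model, and metric names below are made up for illustration.

```python
import tensorflow as tf

class InnerLayer(tf.keras.layers.Layer):
  """Adds a metric from inside `call()`."""

  def call(self, inputs):
    self.add_metric(tf.reduce_max(inputs), name='inner_max')
    return inputs

class OuterModel(tf.keras.Model):
  def __init__(self):
    super(OuterModel, self).__init__()
    self.inner = InnerLayer()
    self.out = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.out(self.inner(inputs))

model = OuterModel()
_ = model(tf.ones((2, 3)))
# The inner layer's metric surfaces on the enclosing model.
print([m.name for m in model.metrics])  # expected: ['inner_max']
```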
@@ -1313,6 +1356,48 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
def add_metric(self, value, name=None, **kwargs):
"""Adds metric tensor to the layer.
This method can be used inside the `call()` method of a subclassed layer
or model.
```python
class MyMetricLayer(tf.keras.layers.Layer):
  def __init__(self):
    super(MyMetricLayer, self).__init__(name='my_metric_layer')
    self.mean = tf.keras.metrics.Mean(name='metric_1')

  def call(self, inputs):
    # Provide the same name as the metric instance created in __init__.
    self.add_metric(self.mean(inputs), name='metric_1')
    self.add_metric(tf.reduce_sum(inputs), name='metric_2')
    return inputs
```
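One possible way to exercise the layer above (a sketch, assuming TF 2.x eager execution): metric state accumulates across calls and can be read back from `metrics`.

```python
import tensorflow as tf

layer = MyMetricLayer()
_ = layer(tf.ones((2, 4)))        # mean of inputs: 1.0, sum: 8.0
_ = layer(3.0 * tf.ones((2, 4)))  # mean of inputs: 3.0, sum: 24.0

# Both values are tracked as running means across the two calls.
print({m.name: float(m.result()) for m in layer.metrics})
# e.g. {'metric_1': 2.0, 'metric_2': 16.0}
```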
This method can also be called directly on a Functional Model during
construction. In this case, any tensor passed to this Model must
be symbolic and be able to be traced back to the model's `Input`s. These
metrics become part of the model's topology and are tracked when you
save the model via `save()`.
```python
inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(10)(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.add_metric(tf.reduce_sum(x), name='metric_1')
```
Note: Calling `add_metric()` with the result of a metric object on a
Functional Model, as shown in the example below, is not supported. This is
because we cannot trace the metric result tensor back to the model's inputs.
```python
inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(10)(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
```
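For the supported Functional-model case (a plain symbolic tensor, as in the earlier snippet), the added metric is tracked alongside the compiled metrics. A rough end-to-end sketch, assuming TF 2.x:

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(10)(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# Symbolic tensor traced back to `inputs`: this is the supported pattern.
model.add_metric(tf.reduce_sum(x), name='metric_1')

model.compile(optimizer='adam', loss='mse')
history = model.fit(np.ones((4, 10)), np.ones((4, 1)), epochs=1, verbose=0)
print(sorted(history.history))  # expected to include 'loss' and 'metric_1'
```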
Args:
value: Metric tensor.
name: String metric name.

@@ -397,8 +397,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
def metrics(self):
"""Returns the model's metrics added using `compile`, `add_metric` APIs.
Note: `metrics` are available only after a `keras.Model` has been
trained/evaluated on actual data.
Note: Metrics passed to `compile()` are available only after a `keras.Model`
has been trained/evaluated on actual data.
Examples:
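A sketch of the behavior the note describes, assuming TF 2.x: compiled metrics are only built once real data has passed through `fit` or `evaluate`.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Before training, the compiled 'mae' metric has not been built yet.
print([m.name for m in model.metrics])

x = np.random.random((2, 3))
y = np.random.random((2, 2))
_ = model.fit(x, y, verbose=0)

# After fitting on real data, the compiled metrics are available.
print([m.name for m in model.metrics])  # e.g. [..., 'mae']
```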
@@ -1375,7 +1375,24 @@ class Model(network.Network, version_utils.ModelVersionSelector):
return tf_utils.to_numpy_or_python_type(all_outputs)
def reset_metrics(self):
"""Resets the state of metrics."""
"""Resets the state of all the metrics in the model.
Examples:
>>> inputs = tf.keras.layers.Input(shape=(3,))
>>> outputs = tf.keras.layers.Dense(2)(inputs)
>>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
>>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
>>> x = np.random.random((2, 3))
>>> y = np.random.randint(0, 2, (2, 2))
>>> _ = model.fit(x, y, verbose=0)
>>> assert all(float(m.result()) for m in model.metrics)
>>> model.reset_metrics()
>>> assert all(float(m.result()) == 0 for m in model.metrics)
"""
for m in self.metrics:
  m.reset_states()
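Beyond the doctest above, `reset_metrics()` is typically useful in hand-written loops built on `train_on_batch`, where metric state would otherwise keep accumulating across epochs. A sketch, assuming TF 2.x:

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

x = np.random.random((8, 3)).astype('float32')
y = np.random.random((8, 2)).astype('float32')

for epoch in range(2):
  model.reset_metrics()  # start each epoch with fresh metric state
  for start in range(0, len(x), 4):
    # reset_metrics=False keeps running averages within the epoch.
    results = model.train_on_batch(
        x[start:start + 4], y[start:start + 4], reset_metrics=False)
  print('epoch', epoch, '-> [loss, mae]:', results)
```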