Update usage of tf.keras.losses.BinaryCrossEntropy
PiperOrigin-RevId: 347092623 Change-Id: I956364fdda51f099f950faf411612b8604d7d194
This commit is contained in:
parent
5a51f9bbed
commit
28c7e4d9f2
@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Built-in loss functions.
|
||||
"""
|
||||
"""Built-in loss functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -92,8 +91,8 @@ class Loss(object):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op.
|
||||
"""
|
||||
losses_utils.ReductionV2.validate(reduction)
|
||||
@ -122,15 +121,15 @@ class Loss(object):
|
||||
sparse loss functions such as sparse categorical crossentropy where
|
||||
shape = `[batch_size, d0, .. dN-1]`
|
||||
y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
|
||||
sample_weight: Optional `sample_weight` acts as a
|
||||
coefficient for the loss. If a scalar is provided, then the loss is
|
||||
simply scaled by the given value. If `sample_weight` is a tensor of size
|
||||
`[batch_size]`, then the total loss for each sample of the batch is
|
||||
rescaled by the corresponding element in the `sample_weight` vector. If
|
||||
the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
|
||||
broadcasted to this shape), then each loss element of `y_pred` is scaled
|
||||
sample_weight: Optional `sample_weight` acts as a coefficient for the
|
||||
loss. If a scalar is provided, then the loss is simply scaled by the
|
||||
given value. If `sample_weight` is a tensor of size `[batch_size]`, then
|
||||
the total loss for each sample of the batch is rescaled by the
|
||||
corresponding element in the `sample_weight` vector. If the shape of
|
||||
`sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted to
|
||||
this shape), then each loss element of `y_pred` is scaled
|
||||
by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss
|
||||
functions reduce by 1 dimension, usually axis=-1.)
|
||||
functions reduce by 1 dimension, usually axis=-1.)
|
||||
|
||||
Returns:
|
||||
Weighted loss float `Tensor`. If `reduction` is `NONE`, this has
|
||||
@ -230,8 +229,8 @@ class LossFunctionWrapper(Loss):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: (Optional) name for the loss.
|
||||
**kwargs: The keyword arguments that are passed on to `fn`.
|
||||
"""
|
||||
@ -250,8 +249,7 @@ class LossFunctionWrapper(Loss):
|
||||
Loss values per sample.
|
||||
"""
|
||||
if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
|
||||
y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
|
||||
y_pred, y_true)
|
||||
y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true)
|
||||
ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
|
||||
return ag_fn(y_true, y_pred, **self._fn_kwargs)
|
||||
|
||||
@ -314,8 +312,8 @@ class MeanSquaredError(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'mean_squared_error'.
|
||||
"""
|
||||
super(MeanSquaredError, self).__init__(
|
||||
@ -373,8 +371,8 @@ class MeanAbsoluteError(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'mean_absolute_error'.
|
||||
"""
|
||||
super(MeanAbsoluteError, self).__init__(
|
||||
@ -433,8 +431,8 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to
|
||||
'mean_absolute_percentage_error'.
|
||||
"""
|
||||
@ -494,8 +492,8 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to
|
||||
'mean_squared_logarithmic_error'.
|
||||
"""
|
||||
@ -507,44 +505,64 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
|
||||
class BinaryCrossentropy(LossFunctionWrapper):
|
||||
"""Computes the cross-entropy loss between true labels and predicted labels.
|
||||
|
||||
Use this cross-entropy loss when there are only two label classes (assumed to
|
||||
be 0 and 1). For each example, there should be a single floating-point value
|
||||
per prediction.
|
||||
Use this cross-entropy loss for binary (0 or 1) classification applications.
|
||||
The loss function requires the following inputs:
|
||||
|
||||
In the snippet below, each of the four examples has only a single
|
||||
floating-pointing value, and both `y_pred` and `y_true` have the shape
|
||||
`[batch_size]`.
|
||||
- `y_true` (true label): This is either 0 or 1.
|
||||
- `y_pred` (predicted value): This is the model's prediction, i.e., a single
|
||||
floating-point value which either represents a
|
||||
[logit](https://en.wikipedia.org/wiki/Logit), (i.e., value in [-inf, inf]
|
||||
when `from_logits=True`) or a probability (i.e., value in [0., 1.] when
|
||||
`from_logits=False`).
|
||||
|
||||
Standalone usage:
|
||||
**Recommended Usage:** (set `from_logits=True`)
|
||||
|
||||
>>> y_true = [[0., 1.], [0., 0.]]
|
||||
>>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
|
||||
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy()
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
0.815
|
||||
|
||||
>>> # Calling with 'sample_weight'.
|
||||
>>> bce(y_true, y_pred, sample_weight=[1, 0]).numpy()
|
||||
0.458
|
||||
|
||||
>>> # Using 'sum' reduction type.
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy(
|
||||
... reduction=tf.keras.losses.Reduction.SUM)
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
1.630
|
||||
|
||||
>>> # Using 'none' reduction type.
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy(
|
||||
... reduction=tf.keras.losses.Reduction.NONE)
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
array([0.916 , 0.714], dtype=float32)
|
||||
|
||||
Usage with the `tf.keras` API:
|
||||
With `tf.keras` API:
|
||||
|
||||
```python
|
||||
model.compile(optimizer='sgd', loss=tf.keras.losses.BinaryCrossentropy())
|
||||
model.compile(
|
||||
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
|
||||
....
|
||||
)
|
||||
```
|
||||
|
||||
As a standalone function:
|
||||
|
||||
>>> # Example 1: (batch_size = 1, number of samples = 4)
|
||||
>>> y_true = [0, 1, 0, 0]
|
||||
>>> y_pred = [-18.6, 0.51, 2.94, -12.8]
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
0.865
|
||||
|
||||
>>> # Example 2: (batch_size = 2, number of samples = 4)
|
||||
>>> y_true = [[0, 1], [0, 0]]
|
||||
>>> y_pred = [[-18.6, 0.51], [2.94, -12.8]]
|
||||
>>> # Using default 'auto'/'sum_over_batch_size' reduction type.
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
0.865
|
||||
>>> # Using the 'sample_weight' argument.
|
||||
>>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
|
||||
0.243
|
||||
>>> # Using 'sum' reduction type.
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
|
||||
... reduction=tf.keras.losses.Reduction.SUM)
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
1.730
|
||||
>>> # Using 'none' reduction type.
|
||||
>>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
|
||||
... reduction=tf.keras.losses.Reduction.NONE)
|
||||
>>> bce(y_true, y_pred).numpy()
|
||||
array([0.235, 1.496], dtype=float32)
|
||||
|
||||
**Default Usage:** (set `from_logits=False`)
|
||||
|
||||
>>> # Make the following updates to the above "Recommended Usage" section
|
||||
>>> # 1. Set `from_logits=False`
|
||||
>>> tf.keras.losses.BinaryCrossentropy() # OR ...('from_logits=False')
|
||||
>>> # 2. Update `y_pred` to use probabilities instead of logits
|
||||
>>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -570,8 +588,8 @@ class BinaryCrossentropy(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: (Optional) Name for the op. Defaults to 'binary_crossentropy'.
|
||||
"""
|
||||
super(BinaryCrossentropy, self).__init__(
|
||||
@ -650,8 +668,8 @@ class CategoricalCrossentropy(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'categorical_crossentropy'.
|
||||
"""
|
||||
super(CategoricalCrossentropy, self).__init__(
|
||||
@ -727,8 +745,8 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to
|
||||
'sparse_categorical_crossentropy'.
|
||||
"""
|
||||
@ -791,8 +809,8 @@ class Hinge(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'hinge'.
|
||||
"""
|
||||
super(Hinge, self).__init__(hinge, name=name, reduction=reduction)
|
||||
@ -852,8 +870,8 @@ class SquaredHinge(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'squared_hinge'.
|
||||
"""
|
||||
super(SquaredHinge, self).__init__(
|
||||
@ -912,8 +930,8 @@ class CategoricalHinge(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'categorical_hinge'.
|
||||
"""
|
||||
super(CategoricalHinge, self).__init__(
|
||||
@ -969,8 +987,8 @@ class Poisson(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'poisson'.
|
||||
"""
|
||||
super(Poisson, self).__init__(poisson, name=name, reduction=reduction)
|
||||
@ -1026,8 +1044,8 @@ class LogCosh(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'log_cosh'.
|
||||
"""
|
||||
super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction)
|
||||
@ -1086,8 +1104,8 @@ class KLDivergence(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'kl_divergence'.
|
||||
"""
|
||||
super(KLDivergence, self).__init__(
|
||||
@ -1154,20 +1172,17 @@ class Huber(LossFunctionWrapper):
|
||||
`tf.distribute.Strategy`, outside of built-in training loops such as
|
||||
`tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
|
||||
will raise an error. Please see this custom training [tutorial](
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training)
|
||||
for more details.
|
||||
https://www.tensorflow.org/tutorials/distribute/custom_training) for
|
||||
more details.
|
||||
name: Optional name for the op. Defaults to 'huber_loss'.
|
||||
"""
|
||||
super(Huber, self).__init__(
|
||||
huber, name=name, reduction=reduction, delta=delta)
|
||||
|
||||
|
||||
@keras_export('keras.metrics.mean_squared_error',
|
||||
'keras.metrics.mse',
|
||||
'keras.metrics.MSE',
|
||||
'keras.losses.mean_squared_error',
|
||||
'keras.losses.mse',
|
||||
'keras.losses.MSE')
|
||||
@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse',
|
||||
'keras.metrics.MSE', 'keras.losses.mean_squared_error',
|
||||
'keras.losses.mse', 'keras.losses.MSE')
|
||||
@dispatch.add_dispatch_support
|
||||
def mean_squared_error(y_true, y_pred):
|
||||
"""Computes the mean squared error between labels and predictions.
|
||||
@ -1198,12 +1213,9 @@ def mean_squared_error(y_true, y_pred):
|
||||
return K.mean(math_ops.squared_difference(y_pred, y_true), axis=-1)
|
||||
|
||||
|
||||
@keras_export('keras.metrics.mean_absolute_error',
|
||||
'keras.metrics.mae',
|
||||
'keras.metrics.MAE',
|
||||
'keras.losses.mean_absolute_error',
|
||||
'keras.losses.mae',
|
||||
'keras.losses.MAE')
|
||||
@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae',
|
||||
'keras.metrics.MAE', 'keras.losses.mean_absolute_error',
|
||||
'keras.losses.mae', 'keras.losses.MAE')
|
||||
@dispatch.add_dispatch_support
|
||||
def mean_absolute_error(y_true, y_pred):
|
||||
"""Computes the mean absolute error between labels and predictions.
|
||||
@ -1232,11 +1244,9 @@ def mean_absolute_error(y_true, y_pred):
|
||||
|
||||
|
||||
@keras_export('keras.metrics.mean_absolute_percentage_error',
|
||||
'keras.metrics.mape',
|
||||
'keras.metrics.MAPE',
|
||||
'keras.metrics.mape', 'keras.metrics.MAPE',
|
||||
'keras.losses.mean_absolute_percentage_error',
|
||||
'keras.losses.mape',
|
||||
'keras.losses.MAPE')
|
||||
'keras.losses.mape', 'keras.losses.MAPE')
|
||||
@dispatch.add_dispatch_support
|
||||
def mean_absolute_percentage_error(y_true, y_pred):
|
||||
"""Computes the mean absolute percentage error between `y_true` and `y_pred`.
|
||||
@ -1269,11 +1279,9 @@ def mean_absolute_percentage_error(y_true, y_pred):
|
||||
|
||||
|
||||
@keras_export('keras.metrics.mean_squared_logarithmic_error',
|
||||
'keras.metrics.msle',
|
||||
'keras.metrics.MSLE',
|
||||
'keras.metrics.msle', 'keras.metrics.MSLE',
|
||||
'keras.losses.mean_squared_logarithmic_error',
|
||||
'keras.losses.msle',
|
||||
'keras.losses.MSLE')
|
||||
'keras.losses.msle', 'keras.losses.MSLE')
|
||||
@dispatch.add_dispatch_support
|
||||
def mean_squared_logarithmic_error(y_true, y_pred):
|
||||
"""Computes the mean squared logarithmic error between `y_true` and `y_pred`.
|
||||
@ -1609,12 +1617,9 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
|
||||
|
||||
|
||||
@keras_export('keras.metrics.kl_divergence',
|
||||
'keras.metrics.kullback_leibler_divergence',
|
||||
'keras.metrics.kld',
|
||||
'keras.metrics.KLD',
|
||||
'keras.losses.kl_divergence',
|
||||
'keras.losses.kullback_leibler_divergence',
|
||||
'keras.losses.kld',
|
||||
'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
|
||||
'keras.metrics.KLD', 'keras.losses.kl_divergence',
|
||||
'keras.losses.kullback_leibler_divergence', 'keras.losses.kld',
|
||||
'keras.losses.KLD')
|
||||
@dispatch.add_dispatch_support
|
||||
def kl_divergence(y_true, y_pred):
|
||||
|
Loading…
x
Reference in New Issue
Block a user