- Adds MeanSquaredError V2 loss implementation

- Adds support for the V2 losses in Keras.
With the new losses the default loss reduction function in Keras has been changed from `weighted_mean` to `sum_over_batch_size`.

PiperOrigin-RevId: 222720535
This commit is contained in:
Pavithra Vijay 2018-11-25 00:47:04 -08:00 committed by TensorFlower Gardener
parent e4149e99dd
commit 31ac32eb03
16 changed files with 547 additions and 119 deletions

View File

@ -30,6 +30,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.io_utils import HDF5Matrix
from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.keras.utils.np_utils import normalize
from tensorflow.python.keras.utils.np_utils import to_categorical
from tensorflow.python.keras.utils.vis_utils import plot_model

View File

@ -142,6 +142,7 @@ py_library(
"regularizers.py",
"utils/data_utils.py",
"utils/io_utils.py",
"utils/losses_utils.py",
],
srcs_version = "PY2AND3",
deps = [

View File

@ -41,6 +41,8 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
@ -568,16 +570,16 @@ class Model(Network):
'" missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not be '
'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
loss_functions.append(training_utils.get_loss_function(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
raise ValueError('When passing a list as loss, '
'it should have one entry per model outputs. '
'The model has ' + str(len(self.outputs)) +
' outputs, but you passed loss=' + str(loss))
loss_functions = [losses.get(l) for l in loss]
loss_functions = [training_utils.get_loss_function(l) for l in loss]
else:
loss_function = losses.get(loss)
loss_function = training_utils.get_loss_function(loss)
loss_functions = [loss_function for _ in range(len(self.outputs))]
self.loss_functions = loss_functions
@ -730,8 +732,21 @@ class Model(Network):
mask = masks[i]
loss_weight = loss_weights_list[i]
with K.name_scope(self.output_names[i] + '_loss'):
weighted_loss = training_utils.weighted_masked_objective(loss_fn)
output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
if isinstance(loss_fn, losses.Loss):
if mask is not None:
mask = math_ops.cast(mask, y_pred.dtype)
# Update weights with mask.
if sample_weight is None:
sample_weight = mask
else:
# Update dimensions of weights to match with mask if possible.
mask, _, sample_weight = squeeze_or_expand_dimensions(
mask, None, sample_weight)
sample_weight *= mask
output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
else:
weighted_loss = training_utils.weighted_masked_objective(loss_fn)
output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
if len(self.outputs) > 1:
# Keep track of the un-aggregated loss result tensor.
@ -739,8 +754,10 @@ class Model(Network):
'_loss'] = output_loss
# Keep track of stateful result tensor and function for the loss.
loss_name = loss_fn.name if isinstance(
loss_fn, losses.Loss) else loss_fn.__name__
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
loss_fn, name=loss_fn.__name__)
loss_fn, name=loss_name)
result_tensor = training_utils.call_metric_function(
mean_wrapped_loss,
y_true,

View File

@ -31,9 +31,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import losses as losses_module
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
@ -128,11 +130,24 @@ def _model_loss(model,
else:
weights = None
mask = masks[i]
weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
with backend.name_scope(model.output_names[i] + '_loss'):
output_loss = weighted_masked_fn(
targets[i], outs[i], weights, mask=mask)
if isinstance(loss_fn, losses_module.Loss):
if mask is not None:
mask = math_ops.cast(mask, outs[i].dtype)
# Update weights with mask.
if weights is None:
weights = mask
else:
# Update dimensions of weights to match with mask if possible.
mask, _, weights = squeeze_or_expand_dimensions(
mask, None, weights)
weights *= mask
output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
else:
weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
output_loss = weighted_masked_fn(
targets[i], outs[i], weights, mask=mask)
# If the number of outputs is 1 then we don't append the loss metric
# associated with each model output. When there are multiple outputs
# associated with a model, each output's loss is calculated and returned
@ -351,8 +366,10 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
output_loss_metrics = []
for i in range(len(model.outputs)):
loss_fn = model.loss_functions[i]
loss_name = loss_fn.name if isinstance(
loss_fn, losses_module.Loss) else loss_fn.__name__
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
loss_fn, name=loss_fn.__name__)
loss_fn, name=loss_name)
output_loss_metrics.append(mean_wrapped_loss)
num_samples = 0
@ -744,8 +761,10 @@ def fit_loop(model,
output_loss_metrics = []
for i in range(len(model.outputs)):
loss_fn = model.loss_functions[i]
loss_name = loss_fn.name if isinstance(
loss_fn, losses_module.Loss) else loss_fn.__name__
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
loss_fn, name=loss_fn.__name__)
loss_fn, name=loss_name)
output_loss_metrics.append(mean_wrapped_loss)
callbacks.on_train_begin()

View File

@ -600,6 +600,34 @@ class TrainingTest(test.TestCase):
np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'), epochs=10)
self.assertTrue('Epoch 5/10' in mock_stdout.getvalue())
@tf_test_util.run_in_graph_and_eager_modes
def test_training_with_loss_instance(self):
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
dense = keras.layers.Dense(4, name='dense')
c = dense(a)
d = dense(b)
e = keras.layers.Dropout(0.5, name='dropout')(c)
model = keras.models.Model([a, b], [d, e])
loss_weights = [1., 0.5]
model.compile(
RMSPropOptimizer(learning_rate=0.001),
loss=keras.losses.MeanSquaredError(),
metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
loss_weights=loss_weights)
input_a_np = np.random.random((10, 3))
input_b_np = np.random.random((10, 3))
output_d_np = np.random.random((10, 4))
output_e_np = np.random.random((10, 4))
model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
epochs=1,
batch_size=5)
class TestExceptionsAndWarnings(test.TestCase):
@ -1918,7 +1946,7 @@ class TestTrainingWithMetrics(test.TestCase):
w = np.array([[3., 4.], [1., 2.]])
outs = model.evaluate(x, y, sample_weight=w)
self.assertArrayNear(outs, [0.3, 0.7, 0.3], .001)
self.assertArrayNear(outs, [0.75, 0.7, 0.3], .001)
# Verify that metric value is same with arbitrary weights and batch size.
x = np.random.random((50, 2, 1))
@ -1988,7 +2016,7 @@ class TestTrainingWithMetrics(test.TestCase):
# verify that masking is combined with sample weights.
w = np.array([3, 2, 4])
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.2, 0.8], 0.1)
self.assertArrayNear(scores, [0.3328, 0.8], 0.001)
def test_add_metric_with_tensor_on_model_in_graph_mode(self):
with self.cached_session():

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
@ -632,15 +633,14 @@ def weighted_masked_objective(fn):
weights = mask
else:
# Update dimensions of weights to match with mask if possible.
mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
mask, None, weights)
mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
weights *= mask
# Apply sample weighting.
if weights is not None:
# Update dimensions of weights to match with values if possible.
score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
score_array, _, weights = squeeze_or_expand_dimensions(
score_array, None, weights)
try:
# Broadcast weights if possible.
@ -838,12 +838,22 @@ def call_metric_function(metric_fn, y_true, y_pred, weights=None, mask=None):
return metric_fn(y_true, y_pred, sample_weight=mask)
# Update dimensions of weights to match with mask.
mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
mask, None, weights)
mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
weights *= mask
return metric_fn(y_true, y_pred, sample_weight=weights)
def get_loss_function(loss):
"""Returns the loss function corresponding to the given loss input."""
if loss is None or isinstance(loss, losses.Loss):
return loss
# TODO(psv): After we have added all V2 losses, update this function.
if loss in ['mse', 'MSE', 'mean_squared_error']:
return losses.MeanSquaredError()
return losses.get(loss)
def validate_iterator_input(x, y, sample_weight, validation_split=None):
"""Validates user input arguments when a dataset iterator is passed.

View File

@ -19,17 +19,120 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.losses_utils import compute_weighted_loss
from tensorflow.python.keras.utils.losses_utils import ReductionV2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.util.tf_export import tf_export
class Loss(object):
"""Loss base class.
To be implemented by subclasses:
* `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.
Example subclass implementation:
```
class MeanSquaredError(Loss):
def call(self, y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
return K.mean(math_ops.square(y_pred - y_true), axis=-1)
```
Args:
reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is
`SUM_OVER_BATCH_SIZE`.
name: Optional name for the op.
"""
def __init__(self, reduction=ReductionV2.SUM_OVER_BATCH_SIZE, name=None):
self.reduction = reduction
self.name = name
def __call__(self, y_true, y_pred, sample_weight=None):
"""Invokes the `Loss` instance.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
coefficient for the loss. If a scalar is provided, then the loss is
simply scaled by the given value. If `sample_weight` is a tensor of size
`[batch_size]`, then the total loss for each sample of the batch is
rescaled by the corresponding element in the `sample_weight` vector. If
the shape of `sample_weight` matches the shape of `y_pred`, then the
loss of each measurable element of `y_pred` is scaled by the
corresponding value of `sample_weight`.
Returns:
Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
shape as `y_true`; otherwise, it is scalar.
Raises:
ValueError: If the shape of `sample_weight` is invalid.
"""
with ops.name_scope(self.name, format(self.__class__.__name__),
(y_pred, y_true, sample_weight)):
losses = self.call(y_true, y_pred)
return compute_weighted_loss(
losses, sample_weight, reduction=self.reduction)
@classmethod
def from_config(cls, config):
"""Instantiates a `Loss` from its config (output of `get_config()`).
Args:
config: Output of `get_config()`.
Returns:
A `Loss` instance.
"""
return cls(**config)
def get_config(self):
return {'reduction': self.reduction, 'name': self.name}
@abc.abstractmethod
def call(self, y_true, y_pred):
"""Invokes the `Loss` instance.
Args:
y_true: Ground truth values, with the same shape as 'y_pred'.
y_pred: The predicted values.
"""
NotImplementedError('Must be implemented in subclasses.')
class MeanSquaredError(Loss):
"""Computes the mean of squares of errors between labels and predictions."""
def call(self, y_true, y_pred):
"""Invokes the `MeanSquaredError` instance.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Mean squared error losses.
"""
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
return mean_squared_error(y_true, y_pred)
@tf_export('keras.metrics.mean_squared_error',
'keras.metrics.mse',
'keras.metrics.MSE',

View File

@ -24,6 +24,9 @@ import shutil
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
try:
@ -138,5 +141,98 @@ class KerasLossesTest(test.TestCase):
loaded_model.predict(np.random.rand(128, 2))
@test_util.run_all_in_graph_and_eager_modes
class MeanSquaredErrorTest(test.TestCase):
def test_config(self):
mse_obj = keras.losses.MeanSquaredError(
reduction=keras.losses.ReductionV2.SUM, name='mse_1')
self.assertEqual(mse_obj.name, 'mse_1')
self.assertEqual(mse_obj.reduction, keras.losses.ReductionV2.SUM)
def test_all_correct_unweighted(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
loss = mse_obj(y_true, y_true)
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
def test_unweighted(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = mse_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), 49.5, 3)
def test_scalar_weighted(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = mse_obj(y_true, y_pred, sample_weight=2.3)
self.assertAlmostEqual(self.evaluate(loss), 113.85, 3)
def test_sample_weighted(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAlmostEqual(self.evaluate(loss), 767.8 / 6, 3)
def test_timestep_weighted(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3, 1),
dtype=dtypes.float32)
sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAlmostEqual(self.evaluate(loss), 587 / 6, 3)
def test_zero_weighted(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = mse_obj(y_true, y_pred, sample_weight=0)
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
def test_invalid_sample_weight(self):
mse_obj = keras.losses.MeanSquaredError()
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3, 1))
sample_weight = constant_op.constant([3, 6, 5, 0], shape=(2, 2))
with self.assertRaisesRegexp(
ValueError, r'Shapes \(2, 2\) and \(2, 3\) are incompatible'):
mse_obj(y_true, y_pred, sample_weight=sample_weight)
def test_no_reduction(self):
mse_obj = keras.losses.MeanSquaredError(
reduction=keras.losses.ReductionV2.NONE)
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = mse_obj(y_true, y_pred, sample_weight=2.3)
loss = self.evaluate(loss)
self.assertArrayNear(loss, [84.3333, 143.3666], 1e-3)
def test_sum_reduction(self):
mse_obj = keras.losses.MeanSquaredError(
reduction=keras.losses.ReductionV2.SUM)
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
shape=(2, 3),
dtype=dtypes.float32)
loss = mse_obj(y_true, y_pred, sample_weight=2.3)
self.assertAlmostEqual(self.evaluate(loss), 227.69998, 3)
if __name__ == '__main__':
test.main()

View File

@ -48,9 +48,9 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@ -172,77 +172,6 @@ def weakmethod(method):
return inner
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.
1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
(using `confusion_matrix.remove_squeezable_dimensions`).
2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
from the new rank of `y_pred`.
If `sample_weight` is scalar, it is kept scalar.
This will use static shape if available. Otherwise, it will add graph
operations, which could result in a performance hit.
Args:
y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
y_true: Optional label `Tensor` whose dimensions match `y_pred`.
sample_weight: Optional weight scalar or `Tensor` whose dimensions match
`y_pred`.
Returns:
Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
the last dimension squeezed,
`sample_weight` could be extended by one dimension.
"""
if y_true is not None:
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
y_true, y_pred)
if sample_weight is None:
return y_pred, y_true, None
sample_weight = ops.convert_to_tensor(sample_weight)
weights_shape = sample_weight.get_shape()
weights_rank = weights_shape.ndims
if weights_rank == 0: # If weights is scalar, do nothing.
return y_pred, y_true, sample_weight
y_pred_shape = y_pred.get_shape()
y_pred_rank = y_pred_shape.ndims
if (y_pred_rank is not None) and (weights_rank is not None):
# Use static rank.
if weights_rank - y_pred_rank == 1:
sample_weight = array_ops.squeeze(sample_weight, [-1])
elif y_pred_rank - weights_rank == 1:
sample_weight = array_ops.expand_dims(sample_weight, [-1])
return y_pred, y_true, sample_weight
# Use dynamic rank.
weights_rank_tensor = array_ops.rank(sample_weight)
rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
def _maybe_expand_weights():
return control_flow_ops.cond(
math_ops.equal(rank_diff,
-1), lambda: array_ops.expand_dims(sample_weight, [-1]),
lambda: sample_weight)
def _maybe_adjust_weights():
return control_flow_ops.cond(
math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
_maybe_expand_weights)
# squeeze or expand last dim of `sample_weight` if its rank differs by 1
# from the new rank of `y_pred`.
sample_weight = control_flow_ops.cond(
math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
_maybe_adjust_weights)
return y_pred, y_true, sample_weight
class _ConfusionMatrix(Enum):
TRUE_POSITIVES = 'tp'
FALSE_POSITIVES = 'fp'

View File

@ -34,6 +34,7 @@ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.io_utils import HDF5Matrix
from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
from tensorflow.python.keras.utils.layer_utils import get_source_inputs
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
from tensorflow.python.keras.utils.np_utils import normalize
from tensorflow.python.keras.utils.np_utils import to_categorical

View File

@ -0,0 +1,213 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to loss functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export('losses.Reduction', 'keras.losses.Reduction', v1=[])
class ReductionV2(object):
"""Types of loss reduction.
Contains the following values:
`NONE`: Un-reduced weighted losses with the same shape as input.
`SUM`: Scalar sum of weighted losses.
`SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
"""
NONE = None
SUM = 'sum'
SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
@classmethod
def all(cls):
return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
@classmethod
def validate(cls, key):
if key not in cls.all():
raise ValueError('Invalid Reduction Key %s.' % key)
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.
1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
(using `confusion_matrix.remove_squeezable_dimensions`).
2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
from the new rank of `y_pred`.
If `sample_weight` is scalar, it is kept scalar.
This will use static shape if available. Otherwise, it will add graph
operations, which could result in a performance hit.
Args:
y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
y_true: Optional label `Tensor` whose dimensions match `y_pred`.
sample_weight: Optional weight scalar or `Tensor` whose dimensions match
`y_pred`.
Returns:
Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
the last dimension squeezed,
`sample_weight` could be extended by one dimension.
"""
if y_true is not None:
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
y_true, y_pred)
if sample_weight is None:
return y_pred, y_true, None
sample_weight = ops.convert_to_tensor(sample_weight)
weights_shape = sample_weight.get_shape()
weights_rank = weights_shape.ndims
if weights_rank == 0: # If weights is scalar, do nothing.
return y_pred, y_true, sample_weight
y_pred_shape = y_pred.get_shape()
y_pred_rank = y_pred_shape.ndims
if (y_pred_rank is not None) and (weights_rank is not None):
# Use static rank.
if weights_rank - y_pred_rank == 1:
sample_weight = array_ops.squeeze(sample_weight, [-1])
elif y_pred_rank - weights_rank == 1:
sample_weight = array_ops.expand_dims(sample_weight, [-1])
return y_pred, y_true, sample_weight
# Use dynamic rank.
weights_rank_tensor = array_ops.rank(sample_weight)
rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
def _maybe_expand_weights():
return control_flow_ops.cond(
math_ops.equal(rank_diff,
-1), lambda: array_ops.expand_dims(sample_weight, [-1]),
lambda: sample_weight)
def _maybe_adjust_weights():
return control_flow_ops.cond(
math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
_maybe_expand_weights)
# squeeze or expand last dim of `sample_weight` if its rank differs by 1
# from the new rank of `y_pred`.
sample_weight = control_flow_ops.cond(
math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
_maybe_adjust_weights)
return y_pred, y_true, sample_weight
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
Args:
losses: `Tensor` whose elements contain individual loss measurements.
num_present: The number of measurable elements in `losses`.
Returns:
A scalar representing the mean of `losses`. If `num_present` is zero,
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
return math_ops.div_no_nan(total_loss, num_present, name='value')
def _num_elements(losses):
"""Computes the number of elements in `losses` tensor."""
with ops.name_scope(None, 'num_elements', values=[losses]) as scope:
return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
def _reduce_weighted_loss(weighted_losses,
reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
"""Reduces the individual weighted loss measurements."""
if reduction == ReductionV2.NONE:
loss = weighted_losses
else:
loss = math_ops.reduce_sum(weighted_losses)
if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
loss = _safe_mean(loss, _num_elements(weighted_losses))
return loss
def compute_weighted_loss(losses,
sample_weight=None,
reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
name=None):
"""Computes the weighted loss.
Args:
losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
`losses`, or be broadcastable to `losses`.
reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is
`SUM_OVER_BATCH_SIZE`.
name: Optional name for the op.
Raises:
ValueError: If the shape of `sample_weight` is not compatible with `losses`.
Returns:
Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
`NONE`, this has the same shape as `losses`; otherwise, it is scalar.
"""
ReductionV2.validate(reduction)
if sample_weight is None:
sample_weight = 1.0
with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)):
# Save the `reduction` argument for loss normalization when distributing
# to multiple replicas.
# TODO(josh11b): Associate it with the returned op for more precision.
ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
# Update dimensions of `sample_weight` to match with `losses` if possible.
losses, _, sample_weight = squeeze_or_expand_dimensions(
losses, None, sample_weight)
losses = ops.convert_to_tensor(losses)
input_dtype = losses.dtype
losses = math_ops.to_float(losses)
sample_weight = math_ops.to_float(sample_weight)
try:
# Broadcast weights if possible.
sample_weight = weights_broadcast_ops.broadcast_weights(
sample_weight, losses)
except ValueError:
# Reduce values to same ndim as weight array.
ndim = K.ndim(losses)
weight_ndim = K.ndim(sample_weight)
losses = K.mean(losses, axis=list(range(weight_ndim, ndim)))
sample_weight.get_shape().assert_is_compatible_with(losses.get_shape())
weighted_losses = math_ops.multiply(losses, sample_weight)
# Apply reduction function to the individual weighted losses.
loss = _reduce_weighted_loss(weighted_losses, reduction)
# Convert the result back to the input type.
loss = math_ops.cast(loss, input_dtype)
return loss

View File

@ -33,32 +33,8 @@ from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@tf_export("losses.Reduction", v1=[])
class ReductionV2(object):
"""Types of loss reduction.
Contains the following values:
`NONE`: Un-reduced weighted losses with the same shape as input.
`SUM`: Scalar sum of weighted losses.
`SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
"""
NONE = "none"
SUM = "weighted_sum"
SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
@classmethod
def all(cls):
return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
@classmethod
def validate(cls, key):
if key not in cls.all():
raise ValueError("Invalid Reduction Key %s." % key)
@tf_export(v1=["losses.Reduction"])
class Reduction(ReductionV2):
class Reduction(object):
"""Types of loss reduction.
Contains the following values:
@ -71,6 +47,9 @@ class Reduction(ReductionV2):
`SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`.
"""
NONE = "none"
SUM = "weighted_sum"
SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
MEAN = "weighted_mean"
SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS

View File

@ -1,7 +1,6 @@
path: "tensorflow.losses.Reduction"
tf_class {
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.Reduction\'>"
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.ReductionV2\'>"
is_instance: "<type \'object\'>"
member {
name: "MEAN"

View File

@ -0,0 +1,28 @@
path: "tensorflow.keras.losses.Reduction"
tf_class {
is_instance: "<class \'tensorflow.python.keras.utils.losses_utils.ReductionV2\'>"
is_instance: "<type \'object\'>"
member {
name: "NONE"
mtype: "<type \'NoneType\'>"
}
member {
name: "SUM"
mtype: "<type \'str\'>"
}
member {
name: "SUM_OVER_BATCH_SIZE"
mtype: "<type \'str\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "all"
argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "validate"
argspec: "args=[\'cls\', \'key\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.keras.losses"
tf_module {
member {
name: "Reduction"
mtype: "<type \'type\'>"
}
member_method {
name: "KLD"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"

View File

@ -1,10 +1,10 @@
path: "tensorflow.losses.Reduction"
tf_class {
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.ReductionV2\'>"
is_instance: "<class \'tensorflow.python.keras.utils.losses_utils.ReductionV2\'>"
is_instance: "<type \'object\'>"
member {
name: "NONE"
mtype: "<type \'str\'>"
mtype: "<type \'NoneType\'>"
}
member {
name: "SUM"