- Adds MeanSquaredError
V2 loss implementation
- Adds support for the V2 losses in Keras. With the new losses the default loss reduction function in Keras has been changed from `weighted_mean` to `sum_over_batch_size`. PiperOrigin-RevId: 222720535
This commit is contained in:
parent
e4149e99dd
commit
31ac32eb03
@ -30,6 +30,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.utils.io_utils import HDF5Matrix
|
||||
from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
|
||||
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
|
||||
from tensorflow.python.keras.utils.np_utils import normalize
|
||||
from tensorflow.python.keras.utils.np_utils import to_categorical
|
||||
from tensorflow.python.keras.utils.vis_utils import plot_model
|
||||
|
@ -142,6 +142,7 @@ py_library(
|
||||
"regularizers.py",
|
||||
"utils/data_utils.py",
|
||||
"utils/io_utils.py",
|
||||
"utils/losses_utils.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
|
@ -41,6 +41,8 @@ from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.engine.network import Network
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import slice_arrays
|
||||
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import optimizer as tf_optimizer_module
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||
@ -568,16 +570,16 @@ class Model(Network):
|
||||
'" missing from loss dictionary. We assume '
|
||||
'this was done on purpose. The fit and evaluate APIs will not be '
|
||||
'expecting any data to be passed to "' + name + '".')
|
||||
loss_functions.append(losses.get(loss.get(name)))
|
||||
loss_functions.append(training_utils.get_loss_function(loss.get(name)))
|
||||
elif isinstance(loss, list):
|
||||
if len(loss) != len(self.outputs):
|
||||
raise ValueError('When passing a list as loss, '
|
||||
'it should have one entry per model outputs. '
|
||||
'The model has ' + str(len(self.outputs)) +
|
||||
' outputs, but you passed loss=' + str(loss))
|
||||
loss_functions = [losses.get(l) for l in loss]
|
||||
loss_functions = [training_utils.get_loss_function(l) for l in loss]
|
||||
else:
|
||||
loss_function = losses.get(loss)
|
||||
loss_function = training_utils.get_loss_function(loss)
|
||||
loss_functions = [loss_function for _ in range(len(self.outputs))]
|
||||
self.loss_functions = loss_functions
|
||||
|
||||
@ -730,8 +732,21 @@ class Model(Network):
|
||||
mask = masks[i]
|
||||
loss_weight = loss_weights_list[i]
|
||||
with K.name_scope(self.output_names[i] + '_loss'):
|
||||
weighted_loss = training_utils.weighted_masked_objective(loss_fn)
|
||||
output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
|
||||
if isinstance(loss_fn, losses.Loss):
|
||||
if mask is not None:
|
||||
mask = math_ops.cast(mask, y_pred.dtype)
|
||||
# Update weights with mask.
|
||||
if sample_weight is None:
|
||||
sample_weight = mask
|
||||
else:
|
||||
# Update dimensions of weights to match with mask if possible.
|
||||
mask, _, sample_weight = squeeze_or_expand_dimensions(
|
||||
mask, None, sample_weight)
|
||||
sample_weight *= mask
|
||||
output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
|
||||
else:
|
||||
weighted_loss = training_utils.weighted_masked_objective(loss_fn)
|
||||
output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
|
||||
|
||||
if len(self.outputs) > 1:
|
||||
# Keep track of the un-aggregated loss result tensor.
|
||||
@ -739,8 +754,10 @@ class Model(Network):
|
||||
'_loss'] = output_loss
|
||||
|
||||
# Keep track of stateful result tensor and function for the loss.
|
||||
loss_name = loss_fn.name if isinstance(
|
||||
loss_fn, losses.Loss) else loss_fn.__name__
|
||||
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
|
||||
loss_fn, name=loss_fn.__name__)
|
||||
loss_fn, name=loss_name)
|
||||
result_tensor = training_utils.call_metric_function(
|
||||
mean_wrapped_loss,
|
||||
y_true,
|
||||
|
@ -31,9 +31,11 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import callbacks as cbks
|
||||
from tensorflow.python.keras import losses as losses_module
|
||||
from tensorflow.python.keras import metrics as metrics_module
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
@ -128,11 +130,24 @@ def _model_loss(model,
|
||||
else:
|
||||
weights = None
|
||||
mask = masks[i]
|
||||
|
||||
weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
|
||||
with backend.name_scope(model.output_names[i] + '_loss'):
|
||||
output_loss = weighted_masked_fn(
|
||||
targets[i], outs[i], weights, mask=mask)
|
||||
if isinstance(loss_fn, losses_module.Loss):
|
||||
if mask is not None:
|
||||
mask = math_ops.cast(mask, outs[i].dtype)
|
||||
# Update weights with mask.
|
||||
if weights is None:
|
||||
weights = mask
|
||||
else:
|
||||
# Update dimensions of weights to match with mask if possible.
|
||||
mask, _, weights = squeeze_or_expand_dimensions(
|
||||
mask, None, weights)
|
||||
weights *= mask
|
||||
output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
|
||||
else:
|
||||
weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
|
||||
output_loss = weighted_masked_fn(
|
||||
targets[i], outs[i], weights, mask=mask)
|
||||
|
||||
# If the number of outputs is 1 then we don't append the loss metric
|
||||
# associated with each model output. When there are multiple outputs
|
||||
# associated with a model, each output's loss is calculated and returned
|
||||
@ -351,8 +366,10 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
|
||||
output_loss_metrics = []
|
||||
for i in range(len(model.outputs)):
|
||||
loss_fn = model.loss_functions[i]
|
||||
loss_name = loss_fn.name if isinstance(
|
||||
loss_fn, losses_module.Loss) else loss_fn.__name__
|
||||
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
|
||||
loss_fn, name=loss_fn.__name__)
|
||||
loss_fn, name=loss_name)
|
||||
output_loss_metrics.append(mean_wrapped_loss)
|
||||
|
||||
num_samples = 0
|
||||
@ -744,8 +761,10 @@ def fit_loop(model,
|
||||
output_loss_metrics = []
|
||||
for i in range(len(model.outputs)):
|
||||
loss_fn = model.loss_functions[i]
|
||||
loss_name = loss_fn.name if isinstance(
|
||||
loss_fn, losses_module.Loss) else loss_fn.__name__
|
||||
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
|
||||
loss_fn, name=loss_fn.__name__)
|
||||
loss_fn, name=loss_name)
|
||||
output_loss_metrics.append(mean_wrapped_loss)
|
||||
|
||||
callbacks.on_train_begin()
|
||||
|
@ -600,6 +600,34 @@ class TrainingTest(test.TestCase):
|
||||
np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'), epochs=10)
|
||||
self.assertTrue('Epoch 5/10' in mock_stdout.getvalue())
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_training_with_loss_instance(self):
|
||||
a = keras.layers.Input(shape=(3,), name='input_a')
|
||||
b = keras.layers.Input(shape=(3,), name='input_b')
|
||||
|
||||
dense = keras.layers.Dense(4, name='dense')
|
||||
c = dense(a)
|
||||
d = dense(b)
|
||||
e = keras.layers.Dropout(0.5, name='dropout')(c)
|
||||
|
||||
model = keras.models.Model([a, b], [d, e])
|
||||
loss_weights = [1., 0.5]
|
||||
model.compile(
|
||||
RMSPropOptimizer(learning_rate=0.001),
|
||||
loss=keras.losses.MeanSquaredError(),
|
||||
metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
|
||||
loss_weights=loss_weights)
|
||||
|
||||
input_a_np = np.random.random((10, 3))
|
||||
input_b_np = np.random.random((10, 3))
|
||||
|
||||
output_d_np = np.random.random((10, 4))
|
||||
output_e_np = np.random.random((10, 4))
|
||||
|
||||
model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
|
||||
epochs=1,
|
||||
batch_size=5)
|
||||
|
||||
|
||||
class TestExceptionsAndWarnings(test.TestCase):
|
||||
|
||||
@ -1918,7 +1946,7 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
|
||||
w = np.array([[3., 4.], [1., 2.]])
|
||||
outs = model.evaluate(x, y, sample_weight=w)
|
||||
self.assertArrayNear(outs, [0.3, 0.7, 0.3], .001)
|
||||
self.assertArrayNear(outs, [0.75, 0.7, 0.3], .001)
|
||||
|
||||
# Verify that metric value is same with arbitrary weights and batch size.
|
||||
x = np.random.random((50, 2, 1))
|
||||
@ -1988,7 +2016,7 @@ class TestTrainingWithMetrics(test.TestCase):
|
||||
# verify that masking is combined with sample weights.
|
||||
w = np.array([3, 2, 4])
|
||||
scores = model.train_on_batch(x, y, sample_weight=w)
|
||||
self.assertArrayNear(scores, [0.2, 0.8], 0.1)
|
||||
self.assertArrayNear(scores, [0.3328, 0.8], 0.001)
|
||||
|
||||
def test_add_metric_with_tensor_on_model_in_graph_mode(self):
|
||||
with self.cached_session():
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras import metrics as metrics_module
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import weights_broadcast_ops
|
||||
@ -632,15 +633,14 @@ def weighted_masked_objective(fn):
|
||||
weights = mask
|
||||
else:
|
||||
# Update dimensions of weights to match with mask if possible.
|
||||
mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
|
||||
mask, None, weights)
|
||||
mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
|
||||
weights *= mask
|
||||
|
||||
# Apply sample weighting.
|
||||
if weights is not None:
|
||||
|
||||
# Update dimensions of weights to match with values if possible.
|
||||
score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
|
||||
score_array, _, weights = squeeze_or_expand_dimensions(
|
||||
score_array, None, weights)
|
||||
try:
|
||||
# Broadcast weights if possible.
|
||||
@ -838,12 +838,22 @@ def call_metric_function(metric_fn, y_true, y_pred, weights=None, mask=None):
|
||||
return metric_fn(y_true, y_pred, sample_weight=mask)
|
||||
|
||||
# Update dimensions of weights to match with mask.
|
||||
mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
|
||||
mask, None, weights)
|
||||
mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
|
||||
weights *= mask
|
||||
return metric_fn(y_true, y_pred, sample_weight=weights)
|
||||
|
||||
|
||||
def get_loss_function(loss):
|
||||
"""Returns the loss function corresponding to the given loss input."""
|
||||
if loss is None or isinstance(loss, losses.Loss):
|
||||
return loss
|
||||
|
||||
# TODO(psv): After we have added all V2 losses, update this function.
|
||||
if loss in ['mse', 'MSE', 'mean_squared_error']:
|
||||
return losses.MeanSquaredError()
|
||||
return losses.get(loss)
|
||||
|
||||
|
||||
def validate_iterator_input(x, y, sample_weight, validation_split=None):
|
||||
"""Validates user input arguments when a dataset iterator is passed.
|
||||
|
||||
|
@ -19,17 +19,120 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.utils.losses_utils import compute_weighted_loss
|
||||
from tensorflow.python.keras.utils.losses_utils import ReductionV2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops.losses import losses_impl
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
class Loss(object):
|
||||
"""Loss base class.
|
||||
|
||||
To be implemented by subclasses:
|
||||
* `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.
|
||||
|
||||
Example subclass implementation:
|
||||
```
|
||||
class MeanSquaredError(Loss):
|
||||
def call(self, y_true, y_pred):
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
y_true = math_ops.cast(y_true, y_pred.dtype)
|
||||
return K.mean(math_ops.square(y_pred - y_true), axis=-1)
|
||||
```
|
||||
|
||||
Args:
|
||||
reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is
|
||||
`SUM_OVER_BATCH_SIZE`.
|
||||
name: Optional name for the op.
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=ReductionV2.SUM_OVER_BATCH_SIZE, name=None):
|
||||
self.reduction = reduction
|
||||
self.name = name
|
||||
|
||||
def __call__(self, y_true, y_pred, sample_weight=None):
|
||||
"""Invokes the `Loss` instance.
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values.
|
||||
y_pred: The predicted values.
|
||||
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
|
||||
as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
|
||||
coefficient for the loss. If a scalar is provided, then the loss is
|
||||
simply scaled by the given value. If `sample_weight` is a tensor of size
|
||||
`[batch_size]`, then the total loss for each sample of the batch is
|
||||
rescaled by the corresponding element in the `sample_weight` vector. If
|
||||
the shape of `sample_weight` matches the shape of `y_pred`, then the
|
||||
loss of each measurable element of `y_pred` is scaled by the
|
||||
corresponding value of `sample_weight`.
|
||||
|
||||
Returns:
|
||||
Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
|
||||
shape as `y_true`; otherwise, it is scalar.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shape of `sample_weight` is invalid.
|
||||
"""
|
||||
with ops.name_scope(self.name, format(self.__class__.__name__),
|
||||
(y_pred, y_true, sample_weight)):
|
||||
losses = self.call(y_true, y_pred)
|
||||
return compute_weighted_loss(
|
||||
losses, sample_weight, reduction=self.reduction)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
"""Instantiates a `Loss` from its config (output of `get_config()`).
|
||||
|
||||
Args:
|
||||
config: Output of `get_config()`.
|
||||
|
||||
Returns:
|
||||
A `Loss` instance.
|
||||
"""
|
||||
return cls(**config)
|
||||
|
||||
def get_config(self):
|
||||
return {'reduction': self.reduction, 'name': self.name}
|
||||
|
||||
@abc.abstractmethod
|
||||
def call(self, y_true, y_pred):
|
||||
"""Invokes the `Loss` instance.
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values, with the same shape as 'y_pred'.
|
||||
y_pred: The predicted values.
|
||||
"""
|
||||
NotImplementedError('Must be implemented in subclasses.')
|
||||
|
||||
|
||||
class MeanSquaredError(Loss):
|
||||
"""Computes the mean of squares of errors between labels and predictions."""
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
"""Invokes the `MeanSquaredError` instance.
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values.
|
||||
y_pred: The predicted values.
|
||||
|
||||
Returns:
|
||||
Mean squared error losses.
|
||||
"""
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
y_true = math_ops.cast(y_true, y_pred.dtype)
|
||||
return mean_squared_error(y_true, y_pred)
|
||||
|
||||
|
||||
@tf_export('keras.metrics.mean_squared_error',
|
||||
'keras.metrics.mse',
|
||||
'keras.metrics.MSE',
|
||||
|
@ -24,6 +24,9 @@ import shutil
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
try:
|
||||
@ -138,5 +141,98 @@ class KerasLossesTest(test.TestCase):
|
||||
loaded_model.predict(np.random.rand(128, 2))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MeanSquaredErrorTest(test.TestCase):
|
||||
|
||||
def test_config(self):
|
||||
mse_obj = keras.losses.MeanSquaredError(
|
||||
reduction=keras.losses.ReductionV2.SUM, name='mse_1')
|
||||
self.assertEqual(mse_obj.name, 'mse_1')
|
||||
self.assertEqual(mse_obj.reduction, keras.losses.ReductionV2.SUM)
|
||||
|
||||
def test_all_correct_unweighted(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
|
||||
loss = mse_obj(y_true, y_true)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
def test_unweighted(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = mse_obj(y_true, y_pred)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 49.5, 3)
|
||||
|
||||
def test_scalar_weighted(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = mse_obj(y_true, y_pred, sample_weight=2.3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 113.85, 3)
|
||||
|
||||
def test_sample_weighted(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
|
||||
loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 767.8 / 6, 3)
|
||||
|
||||
def test_timestep_weighted(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3, 1),
|
||||
dtype=dtypes.float32)
|
||||
sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
|
||||
loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 587 / 6, 3)
|
||||
|
||||
def test_zero_weighted(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = mse_obj(y_true, y_pred, sample_weight=0)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
|
||||
|
||||
def test_invalid_sample_weight(self):
|
||||
mse_obj = keras.losses.MeanSquaredError()
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3, 1))
|
||||
sample_weight = constant_op.constant([3, 6, 5, 0], shape=(2, 2))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r'Shapes \(2, 2\) and \(2, 3\) are incompatible'):
|
||||
mse_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||
|
||||
def test_no_reduction(self):
|
||||
mse_obj = keras.losses.MeanSquaredError(
|
||||
reduction=keras.losses.ReductionV2.NONE)
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = mse_obj(y_true, y_pred, sample_weight=2.3)
|
||||
loss = self.evaluate(loss)
|
||||
self.assertArrayNear(loss, [84.3333, 143.3666], 1e-3)
|
||||
|
||||
def test_sum_reduction(self):
|
||||
mse_obj = keras.losses.MeanSquaredError(
|
||||
reduction=keras.losses.ReductionV2.SUM)
|
||||
y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
|
||||
y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
|
||||
shape=(2, 3),
|
||||
dtype=dtypes.float32)
|
||||
loss = mse_obj(y_true, y_pred, sample_weight=2.3)
|
||||
self.assertAlmostEqual(self.evaluate(loss), 227.69998, 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -48,9 +48,9 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy
|
||||
from tensorflow.python.keras.losses import squared_hinge
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import confusion_matrix
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -172,77 +172,6 @@ def weakmethod(method):
|
||||
return inner
|
||||
|
||||
|
||||
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
|
||||
"""Squeeze or expand last dimension if needed.
|
||||
|
||||
1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
|
||||
(using `confusion_matrix.remove_squeezable_dimensions`).
|
||||
2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
|
||||
from the new rank of `y_pred`.
|
||||
If `sample_weight` is scalar, it is kept scalar.
|
||||
|
||||
This will use static shape if available. Otherwise, it will add graph
|
||||
operations, which could result in a performance hit.
|
||||
|
||||
Args:
|
||||
y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
|
||||
y_true: Optional label `Tensor` whose dimensions match `y_pred`.
|
||||
sample_weight: Optional weight scalar or `Tensor` whose dimensions match
|
||||
`y_pred`.
|
||||
|
||||
Returns:
|
||||
Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
|
||||
the last dimension squeezed,
|
||||
`sample_weight` could be extended by one dimension.
|
||||
"""
|
||||
if y_true is not None:
|
||||
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
|
||||
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
|
||||
y_true, y_pred)
|
||||
|
||||
if sample_weight is None:
|
||||
return y_pred, y_true, None
|
||||
|
||||
sample_weight = ops.convert_to_tensor(sample_weight)
|
||||
weights_shape = sample_weight.get_shape()
|
||||
weights_rank = weights_shape.ndims
|
||||
if weights_rank == 0: # If weights is scalar, do nothing.
|
||||
return y_pred, y_true, sample_weight
|
||||
|
||||
y_pred_shape = y_pred.get_shape()
|
||||
y_pred_rank = y_pred_shape.ndims
|
||||
if (y_pred_rank is not None) and (weights_rank is not None):
|
||||
# Use static rank.
|
||||
if weights_rank - y_pred_rank == 1:
|
||||
sample_weight = array_ops.squeeze(sample_weight, [-1])
|
||||
elif y_pred_rank - weights_rank == 1:
|
||||
sample_weight = array_ops.expand_dims(sample_weight, [-1])
|
||||
return y_pred, y_true, sample_weight
|
||||
|
||||
# Use dynamic rank.
|
||||
weights_rank_tensor = array_ops.rank(sample_weight)
|
||||
rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
|
||||
maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
|
||||
|
||||
def _maybe_expand_weights():
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(rank_diff,
|
||||
-1), lambda: array_ops.expand_dims(sample_weight, [-1]),
|
||||
lambda: sample_weight)
|
||||
|
||||
def _maybe_adjust_weights():
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
|
||||
_maybe_expand_weights)
|
||||
|
||||
# squeeze or expand last dim of `sample_weight` if its rank differs by 1
|
||||
# from the new rank of `y_pred`.
|
||||
sample_weight = control_flow_ops.cond(
|
||||
math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
|
||||
_maybe_adjust_weights)
|
||||
return y_pred, y_true, sample_weight
|
||||
|
||||
|
||||
class _ConfusionMatrix(Enum):
|
||||
TRUE_POSITIVES = 'tp'
|
||||
FALSE_POSITIVES = 'fp'
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.utils.io_utils import HDF5Matrix
|
||||
from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
|
||||
from tensorflow.python.keras.utils.layer_utils import get_source_inputs
|
||||
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
|
||||
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
|
||||
from tensorflow.python.keras.utils.np_utils import normalize
|
||||
from tensorflow.python.keras.utils.np_utils import to_categorical
|
||||
|
213
tensorflow/python/keras/utils/losses_utils.py
Normal file
213
tensorflow/python/keras/utils/losses_utils.py
Normal file
@ -0,0 +1,213 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
"""Utilities related to loss functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import confusion_matrix
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import weights_broadcast_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export('losses.Reduction', 'keras.losses.Reduction', v1=[])
|
||||
class ReductionV2(object):
|
||||
"""Types of loss reduction.
|
||||
|
||||
Contains the following values:
|
||||
`NONE`: Un-reduced weighted losses with the same shape as input.
|
||||
`SUM`: Scalar sum of weighted losses.
|
||||
`SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
|
||||
"""
|
||||
|
||||
NONE = None
|
||||
SUM = 'sum'
|
||||
SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
|
||||
|
||||
@classmethod
|
||||
def all(cls):
|
||||
return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, key):
|
||||
if key not in cls.all():
|
||||
raise ValueError('Invalid Reduction Key %s.' % key)
|
||||
|
||||
|
||||
def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
|
||||
"""Squeeze or expand last dimension if needed.
|
||||
|
||||
1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
|
||||
(using `confusion_matrix.remove_squeezable_dimensions`).
|
||||
2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
|
||||
from the new rank of `y_pred`.
|
||||
If `sample_weight` is scalar, it is kept scalar.
|
||||
|
||||
This will use static shape if available. Otherwise, it will add graph
|
||||
operations, which could result in a performance hit.
|
||||
|
||||
Args:
|
||||
y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
|
||||
y_true: Optional label `Tensor` whose dimensions match `y_pred`.
|
||||
sample_weight: Optional weight scalar or `Tensor` whose dimensions match
|
||||
`y_pred`.
|
||||
|
||||
Returns:
|
||||
Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
|
||||
the last dimension squeezed,
|
||||
`sample_weight` could be extended by one dimension.
|
||||
"""
|
||||
if y_true is not None:
|
||||
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
|
||||
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
|
||||
y_true, y_pred)
|
||||
|
||||
if sample_weight is None:
|
||||
return y_pred, y_true, None
|
||||
|
||||
sample_weight = ops.convert_to_tensor(sample_weight)
|
||||
weights_shape = sample_weight.get_shape()
|
||||
weights_rank = weights_shape.ndims
|
||||
if weights_rank == 0: # If weights is scalar, do nothing.
|
||||
return y_pred, y_true, sample_weight
|
||||
|
||||
y_pred_shape = y_pred.get_shape()
|
||||
y_pred_rank = y_pred_shape.ndims
|
||||
if (y_pred_rank is not None) and (weights_rank is not None):
|
||||
# Use static rank.
|
||||
if weights_rank - y_pred_rank == 1:
|
||||
sample_weight = array_ops.squeeze(sample_weight, [-1])
|
||||
elif y_pred_rank - weights_rank == 1:
|
||||
sample_weight = array_ops.expand_dims(sample_weight, [-1])
|
||||
return y_pred, y_true, sample_weight
|
||||
|
||||
# Use dynamic rank.
|
||||
weights_rank_tensor = array_ops.rank(sample_weight)
|
||||
rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
|
||||
maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
|
||||
|
||||
def _maybe_expand_weights():
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(rank_diff,
|
||||
-1), lambda: array_ops.expand_dims(sample_weight, [-1]),
|
||||
lambda: sample_weight)
|
||||
|
||||
def _maybe_adjust_weights():
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
|
||||
_maybe_expand_weights)
|
||||
|
||||
# squeeze or expand last dim of `sample_weight` if its rank differs by 1
|
||||
# from the new rank of `y_pred`.
|
||||
sample_weight = control_flow_ops.cond(
|
||||
math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
|
||||
_maybe_adjust_weights)
|
||||
return y_pred, y_true, sample_weight
|
||||
|
||||
|
||||
def _safe_mean(losses, num_present):
|
||||
"""Computes a safe mean of the losses.
|
||||
|
||||
Args:
|
||||
losses: `Tensor` whose elements contain individual loss measurements.
|
||||
num_present: The number of measurable elements in `losses`.
|
||||
|
||||
Returns:
|
||||
A scalar representing the mean of `losses`. If `num_present` is zero,
|
||||
then zero is returned.
|
||||
"""
|
||||
total_loss = math_ops.reduce_sum(losses)
|
||||
return math_ops.div_no_nan(total_loss, num_present, name='value')
|
||||
|
||||
|
||||
def _num_elements(losses):
|
||||
"""Computes the number of elements in `losses` tensor."""
|
||||
with ops.name_scope(None, 'num_elements', values=[losses]) as scope:
|
||||
return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
|
||||
|
||||
|
||||
def _reduce_weighted_loss(weighted_losses,
|
||||
reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
|
||||
"""Reduces the individual weighted loss measurements."""
|
||||
if reduction == ReductionV2.NONE:
|
||||
loss = weighted_losses
|
||||
else:
|
||||
loss = math_ops.reduce_sum(weighted_losses)
|
||||
if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
|
||||
loss = _safe_mean(loss, _num_elements(weighted_losses))
|
||||
return loss
|
||||
|
||||
|
||||
def compute_weighted_loss(losses,
|
||||
sample_weight=None,
|
||||
reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
name=None):
|
||||
"""Computes the weighted loss.
|
||||
|
||||
Args:
|
||||
losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
|
||||
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
|
||||
`losses`, or be broadcastable to `losses`.
|
||||
reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is
|
||||
`SUM_OVER_BATCH_SIZE`.
|
||||
name: Optional name for the op.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shape of `sample_weight` is not compatible with `losses`.
|
||||
|
||||
Returns:
|
||||
Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
|
||||
`NONE`, this has the same shape as `losses`; otherwise, it is scalar.
|
||||
"""
|
||||
ReductionV2.validate(reduction)
|
||||
if sample_weight is None:
|
||||
sample_weight = 1.0
|
||||
with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)):
|
||||
# Save the `reduction` argument for loss normalization when distributing
|
||||
# to multiple replicas.
|
||||
# TODO(josh11b): Associate it with the returned op for more precision.
|
||||
ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
|
||||
|
||||
# Update dimensions of `sample_weight` to match with `losses` if possible.
|
||||
losses, _, sample_weight = squeeze_or_expand_dimensions(
|
||||
losses, None, sample_weight)
|
||||
losses = ops.convert_to_tensor(losses)
|
||||
input_dtype = losses.dtype
|
||||
losses = math_ops.to_float(losses)
|
||||
sample_weight = math_ops.to_float(sample_weight)
|
||||
|
||||
try:
|
||||
# Broadcast weights if possible.
|
||||
sample_weight = weights_broadcast_ops.broadcast_weights(
|
||||
sample_weight, losses)
|
||||
except ValueError:
|
||||
# Reduce values to same ndim as weight array.
|
||||
ndim = K.ndim(losses)
|
||||
weight_ndim = K.ndim(sample_weight)
|
||||
losses = K.mean(losses, axis=list(range(weight_ndim, ndim)))
|
||||
|
||||
sample_weight.get_shape().assert_is_compatible_with(losses.get_shape())
|
||||
weighted_losses = math_ops.multiply(losses, sample_weight)
|
||||
# Apply reduction function to the individual weighted losses.
|
||||
loss = _reduce_weighted_loss(weighted_losses, reduction)
|
||||
# Convert the result back to the input type.
|
||||
loss = math_ops.cast(loss, input_dtype)
|
||||
return loss
|
@ -33,32 +33,8 @@ from tensorflow.python.util.deprecation import deprecated_argument_lookup
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("losses.Reduction", v1=[])
|
||||
class ReductionV2(object):
|
||||
"""Types of loss reduction.
|
||||
|
||||
Contains the following values:
|
||||
`NONE`: Un-reduced weighted losses with the same shape as input.
|
||||
`SUM`: Scalar sum of weighted losses.
|
||||
`SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
SUM = "weighted_sum"
|
||||
SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
|
||||
|
||||
@classmethod
|
||||
def all(cls):
|
||||
return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, key):
|
||||
if key not in cls.all():
|
||||
raise ValueError("Invalid Reduction Key %s." % key)
|
||||
|
||||
|
||||
@tf_export(v1=["losses.Reduction"])
|
||||
class Reduction(ReductionV2):
|
||||
class Reduction(object):
|
||||
"""Types of loss reduction.
|
||||
|
||||
Contains the following values:
|
||||
@ -71,6 +47,9 @@ class Reduction(ReductionV2):
|
||||
`SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
SUM = "weighted_sum"
|
||||
SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
|
||||
MEAN = "weighted_mean"
|
||||
SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
|
||||
SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS
|
||||
|
@ -1,7 +1,6 @@
|
||||
path: "tensorflow.losses.Reduction"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.Reduction\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.ReductionV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "MEAN"
|
||||
|
@ -0,0 +1,28 @@
|
||||
path: "tensorflow.keras.losses.Reduction"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.losses_utils.ReductionV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "NONE"
|
||||
mtype: "<type \'NoneType\'>"
|
||||
}
|
||||
member {
|
||||
name: "SUM"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "SUM_OVER_BATCH_SIZE"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "all"
|
||||
argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "validate"
|
||||
argspec: "args=[\'cls\', \'key\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.keras.losses"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Reduction"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "KLD"
|
||||
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -1,10 +1,10 @@
|
||||
path: "tensorflow.losses.Reduction"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.ReductionV2\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.losses_utils.ReductionV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "NONE"
|
||||
mtype: "<type \'str\'>"
|
||||
mtype: "<type \'NoneType\'>"
|
||||
}
|
||||
member {
|
||||
name: "SUM"
|
||||
|
Loading…
x
Reference in New Issue
Block a user