diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index 47cd01b924f..3b9fa1b230b 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -30,6 +30,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.keras.utils.np_utils import normalize from tensorflow.python.keras.utils.np_utils import to_categorical from tensorflow.python.keras.utils.vis_utils import plot_model diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 540dd03768f..fa1cad2359d 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -142,6 +142,7 @@ py_library( "regularizers.py", "utils/data_utils.py", "utils/io_utils.py", + "utils/losses_utils.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 888d8eb9420..8c564ed61b2 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -41,6 +41,8 @@ from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils.generic_utils import slice_arrays +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions +from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from 
tensorflow.python.training.checkpointable import base as checkpointable @@ -568,16 +570,16 @@ class Model(Network): '" missing from loss dictionary. We assume ' 'this was done on purpose. The fit and evaluate APIs will not be ' 'expecting any data to be passed to "' + name + '".') - loss_functions.append(losses.get(loss.get(name))) + loss_functions.append(training_utils.get_loss_function(loss.get(name))) elif isinstance(loss, list): if len(loss) != len(self.outputs): raise ValueError('When passing a list as loss, ' 'it should have one entry per model outputs. ' 'The model has ' + str(len(self.outputs)) + ' outputs, but you passed loss=' + str(loss)) - loss_functions = [losses.get(l) for l in loss] + loss_functions = [training_utils.get_loss_function(l) for l in loss] else: - loss_function = losses.get(loss) + loss_function = training_utils.get_loss_function(loss) loss_functions = [loss_function for _ in range(len(self.outputs))] self.loss_functions = loss_functions @@ -730,8 +732,21 @@ class Model(Network): mask = masks[i] loss_weight = loss_weights_list[i] with K.name_scope(self.output_names[i] + '_loss'): - weighted_loss = training_utils.weighted_masked_objective(loss_fn) - output_loss = weighted_loss(y_true, y_pred, sample_weight, mask) + if isinstance(loss_fn, losses.Loss): + if mask is not None: + mask = math_ops.cast(mask, y_pred.dtype) + # Update weights with mask. + if sample_weight is None: + sample_weight = mask + else: + # Update dimensions of weights to match with mask if possible. + mask, _, sample_weight = squeeze_or_expand_dimensions( + mask, None, sample_weight) + sample_weight *= mask + output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + else: + weighted_loss = training_utils.weighted_masked_objective(loss_fn) + output_loss = weighted_loss(y_true, y_pred, sample_weight, mask) if len(self.outputs) > 1: # Keep track of the un-aggregated loss result tensor. 
@@ -739,8 +754,10 @@ class Model(Network): '_loss'] = output_loss # Keep track of stateful result tensor and function for the loss. + loss_name = loss_fn.name if isinstance( + loss_fn, losses.Loss) else loss_fn.__name__ mean_wrapped_loss = metrics_module.MeanMetricWrapper( - loss_fn, name=loss_fn.__name__) + loss_fn, name=loss_name) result_tensor = training_utils.call_metric_function( mean_wrapped_loss, y_true, diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index b2dace84aa3..cd85c365db4 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -31,9 +31,11 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras import losses as losses_module from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging @@ -128,11 +130,24 @@ def _model_loss(model, else: weights = None mask = masks[i] - - weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn) with backend.name_scope(model.output_names[i] + '_loss'): - output_loss = weighted_masked_fn( - targets[i], outs[i], weights, mask=mask) + if isinstance(loss_fn, losses_module.Loss): + if mask is not None: + mask = math_ops.cast(mask, outs[i].dtype) + # Update weights with mask. + if weights is None: + weights = mask + else: + # Update dimensions of weights to match with mask if possible. 
+ mask, _, weights = squeeze_or_expand_dimensions( + mask, None, weights) + weights *= mask + output_loss = loss_fn(targets[i], outs[i], sample_weight=weights) + else: + weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn) + output_loss = weighted_masked_fn( + targets[i], outs[i], weights, mask=mask) + # If the number of outputs is 1 then we don't append the loss metric # associated with each model output. When there are multiple outputs # associated with a model, each output's loss is calculated and returned @@ -351,8 +366,10 @@ def iterator_test_loop(model, inputs, steps, verbose=0): output_loss_metrics = [] for i in range(len(model.outputs)): loss_fn = model.loss_functions[i] + loss_name = loss_fn.name if isinstance( + loss_fn, losses_module.Loss) else loss_fn.__name__ mean_wrapped_loss = metrics_module.MeanMetricWrapper( - loss_fn, name=loss_fn.__name__) + loss_fn, name=loss_name) output_loss_metrics.append(mean_wrapped_loss) num_samples = 0 @@ -744,8 +761,10 @@ def fit_loop(model, output_loss_metrics = [] for i in range(len(model.outputs)): loss_fn = model.loss_functions[i] + loss_name = loss_fn.name if isinstance( + loss_fn, losses_module.Loss) else loss_fn.__name__ mean_wrapped_loss = metrics_module.MeanMetricWrapper( - loss_fn, name=loss_fn.__name__) + loss_fn, name=loss_name) output_loss_metrics.append(mean_wrapped_loss) callbacks.on_train_begin() diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 1009ef71387..97dfe6d9003 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -600,6 +600,34 @@ class TrainingTest(test.TestCase): np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'), epochs=10) self.assertTrue('Epoch 5/10' in mock_stdout.getvalue()) + @tf_test_util.run_in_graph_and_eager_modes + def test_training_with_loss_instance(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = 
keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + loss_weights = [1., 0.5] + model.compile( + RMSPropOptimizer(learning_rate=0.001), + loss=keras.losses.MeanSquaredError(), + metrics=[metrics_module.CategoricalAccuracy(), 'mae'], + loss_weights=loss_weights) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + model.fit([input_a_np, input_b_np], [output_d_np, output_e_np], + epochs=1, + batch_size=5) + class TestExceptionsAndWarnings(test.TestCase): @@ -1918,7 +1946,7 @@ class TestTrainingWithMetrics(test.TestCase): w = np.array([[3., 4.], [1., 2.]]) outs = model.evaluate(x, y, sample_weight=w) - self.assertArrayNear(outs, [0.3, 0.7, 0.3], .001) + self.assertArrayNear(outs, [0.75, 0.7, 0.3], .001) # Verify that metric value is same with arbitrary weights and batch size. x = np.random.random((50, 2, 1)) @@ -1988,7 +2016,7 @@ class TestTrainingWithMetrics(test.TestCase): # verify that masking is combined with sample weights. 
w = np.array([3, 2, 4]) scores = model.train_on_batch(x, y, sample_weight=w) - self.assertArrayNear(scores, [0.2, 0.8], 0.1) + self.assertArrayNear(scores, [0.3328, 0.8], 0.001) def test_add_metric_with_tensor_on_model_in_graph_mode(self): with self.cached_session(): diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 1735db8b6b9..347582aa95a 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -35,6 +35,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import weights_broadcast_ops @@ -632,15 +633,14 @@ def weighted_masked_objective(fn): weights = mask else: # Update dimensions of weights to match with mask if possible. - mask, _, weights = metrics_module.squeeze_or_expand_dimensions( - mask, None, weights) + mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights) weights *= mask # Apply sample weighting. if weights is not None: # Update dimensions of weights to match with values if possible. - score_array, _, weights = metrics_module.squeeze_or_expand_dimensions( + score_array, _, weights = squeeze_or_expand_dimensions( score_array, None, weights) try: # Broadcast weights if possible. @@ -838,12 +838,22 @@ def call_metric_function(metric_fn, y_true, y_pred, weights=None, mask=None): return metric_fn(y_true, y_pred, sample_weight=mask) # Update dimensions of weights to match with mask. 
- mask, _, weights = metrics_module.squeeze_or_expand_dimensions( - mask, None, weights) + mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights) weights *= mask return metric_fn(y_true, y_pred, sample_weight=weights) +def get_loss_function(loss): + """Returns the loss function corresponding to the given loss input.""" + if loss is None or isinstance(loss, losses.Loss): + return loss + + # TODO(psv): After we have added all V2 losses, update this function. + if loss in ['mse', 'MSE', 'mean_squared_error']: + return losses.MeanSquaredError() + return losses.get(loss) + + def validate_iterator_input(x, y, sample_weight, validation_split=None): """Validates user input arguments when a dataset iterator is passed. diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index f871ee409ec..0e274d4d501 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -19,17 +19,120 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc + import six +from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.losses_utils import compute_weighted_loss +from tensorflow.python.keras.utils.losses_utils import ReductionV2 from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.losses import losses_impl from tensorflow.python.util.tf_export import tf_export +class Loss(object): + """Loss base class. + + To be implemented by subclasses: + * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`. 
+ + Example subclass implementation: + ``` + class MeanSquaredError(Loss): + def call(self, y_true, y_pred): + y_pred = ops.convert_to_tensor(y_pred) + y_true = math_ops.cast(y_true, y_pred.dtype) + return K.mean(math_ops.square(y_pred - y_true), axis=-1) + ``` + + Args: + reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is + `SUM_OVER_BATCH_SIZE`. + name: Optional name for the op. + """ + + def __init__(self, reduction=ReductionV2.SUM_OVER_BATCH_SIZE, name=None): + self.reduction = reduction + self.name = name + + def __call__(self, y_true, y_pred, sample_weight=None): + """Invokes the `Loss` instance. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + sample_weight: Optional `Tensor` whose rank is either 0, or the same rank + as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a + coefficient for the loss. If a scalar is provided, then the loss is + simply scaled by the given value. If `sample_weight` is a tensor of size + `[batch_size]`, then the total loss for each sample of the batch is + rescaled by the corresponding element in the `sample_weight` vector. If + the shape of `sample_weight` matches the shape of `y_pred`, then the + loss of each measurable element of `y_pred` is scaled by the + corresponding value of `sample_weight`. + + Returns: + Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the + shape returned by `call()` (per-sample losses); otherwise, it is scalar. + + Raises: + ValueError: If the shape of `sample_weight` is invalid. + """ + with ops.name_scope(self.name, format(self.__class__.__name__), + (y_pred, y_true, sample_weight)): + losses = self.call(y_true, y_pred) + return compute_weighted_loss( + losses, sample_weight, reduction=self.reduction) + + @classmethod + def from_config(cls, config): + """Instantiates a `Loss` from its config (output of `get_config()`). + + Args: + config: Output of `get_config()`. + + Returns: + A `Loss` instance. 
+ """ + return cls(**config) + + def get_config(self): + return {'reduction': self.reduction, 'name': self.name} + + @abc.abstractmethod + def call(self, y_true, y_pred): + """Computes the unreduced losses; subclasses must override this. + + Args: + y_true: Ground truth values, with the same shape as 'y_pred'. + y_pred: The predicted values. + """ + raise NotImplementedError('Must be implemented in subclasses.') + + +class MeanSquaredError(Loss): + """Computes the mean of squares of errors between labels and predictions.""" + + def call(self, y_true, y_pred): + """Invokes the `MeanSquaredError` instance. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Mean squared error losses. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = math_ops.cast(y_true, y_pred.dtype) + return mean_squared_error(y_true, y_pred) + + @tf_export('keras.metrics.mean_squared_error', 'keras.metrics.mse', 'keras.metrics.MSE', diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index c7015270acc..b056f920abf 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -24,6 +24,9 @@ import shutil import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.platform import test try: @@ -138,5 +141,98 @@ class KerasLossesTest(test.TestCase): loaded_model.predict(np.random.rand(128, 2)) +@test_util.run_all_in_graph_and_eager_modes +class MeanSquaredErrorTest(test.TestCase): + + def test_config(self): + mse_obj = keras.losses.MeanSquaredError( + reduction=keras.losses.ReductionV2.SUM, name='mse_1') + self.assertEqual(mse_obj.name, 'mse_1') + self.assertEqual(mse_obj.reduction, keras.losses.ReductionV2.SUM) + + def test_all_correct_unweighted(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = 
constant_op.constant([4, 8, 12, 8, 1, 3], 
shape=(2, 3)) + loss = mse_obj(y_true, y_true) + self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) + + def test_unweighted(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = mse_obj(y_true, y_pred) + self.assertAlmostEqual(self.evaluate(loss), 49.5, 3) + + def test_scalar_weighted(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = mse_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(self.evaluate(loss), 113.85, 3) + + def test_sample_weighted(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1)) + loss = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), 767.8 / 6, 3) + + def test_timestep_weighted(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3, 1), + dtype=dtypes.float32) + sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3)) + loss = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), 587 / 6, 3) + + def test_zero_weighted(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = mse_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) + + def 
test_invalid_sample_weight(self): + mse_obj = keras.losses.MeanSquaredError() + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3, 1)) + sample_weight = constant_op.constant([3, 6, 5, 0], shape=(2, 2)) + with self.assertRaisesRegexp( + ValueError, r'Shapes \(2, 2\) and \(2, 3\) are incompatible'): + mse_obj(y_true, y_pred, sample_weight=sample_weight) + + def test_no_reduction(self): + mse_obj = keras.losses.MeanSquaredError( + reduction=keras.losses.ReductionV2.NONE) + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = mse_obj(y_true, y_pred, sample_weight=2.3) + loss = self.evaluate(loss) + self.assertArrayNear(loss, [84.3333, 143.3666], 1e-3) + + def test_sum_reduction(self): + mse_obj = keras.losses.MeanSquaredError( + reduction=keras.losses.ReductionV2.SUM) + y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], + shape=(2, 3), + dtype=dtypes.float32) + loss = mse_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(self.evaluate(loss), 227.69998, 3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 668c56243b0..1ddeb0bee78 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -48,9 +48,9 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras.losses import squared_hinge from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from 
tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -172,77 +172,6 @@ def weakmethod(method): return inner -def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): - """Squeeze or expand last dimension if needed. - - 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 - (using `confusion_matrix.remove_squeezable_dimensions`). - 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 - from the new rank of `y_pred`. - If `sample_weight` is scalar, it is kept scalar. - - This will use static shape if available. Otherwise, it will add graph - operations, which could result in a performance hit. - - Args: - y_pred: Predicted values, a `Tensor` of arbitrary dimensions. - y_true: Optional label `Tensor` whose dimensions match `y_pred`. - sample_weight: Optional weight scalar or `Tensor` whose dimensions match - `y_pred`. - - Returns: - Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has - the last dimension squeezed, - `sample_weight` could be extended by one dimension. - """ - if y_true is not None: - # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1 - y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( - y_true, y_pred) - - if sample_weight is None: - return y_pred, y_true, None - - sample_weight = ops.convert_to_tensor(sample_weight) - weights_shape = sample_weight.get_shape() - weights_rank = weights_shape.ndims - if weights_rank == 0: # If weights is scalar, do nothing. - return y_pred, y_true, sample_weight - - y_pred_shape = y_pred.get_shape() - y_pred_rank = y_pred_shape.ndims - if (y_pred_rank is not None) and (weights_rank is not None): - # Use static rank. 
- if weights_rank - y_pred_rank == 1: - sample_weight = array_ops.squeeze(sample_weight, [-1]) - elif y_pred_rank - weights_rank == 1: - sample_weight = array_ops.expand_dims(sample_weight, [-1]) - return y_pred, y_true, sample_weight - - # Use dynamic rank. - weights_rank_tensor = array_ops.rank(sample_weight) - rank_diff = weights_rank_tensor - array_ops.rank(y_pred) - maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) - - def _maybe_expand_weights(): - return control_flow_ops.cond( - math_ops.equal(rank_diff, - -1), lambda: array_ops.expand_dims(sample_weight, [-1]), - lambda: sample_weight) - - def _maybe_adjust_weights(): - return control_flow_ops.cond( - math_ops.equal(rank_diff, 1), maybe_squeeze_weights, - _maybe_expand_weights) - - # squeeze or expand last dim of `sample_weight` if its rank differs by 1 - # from the new rank of `y_pred`. - sample_weight = control_flow_ops.cond( - math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, - _maybe_adjust_weights) - return y_pred, y_true, sample_weight - - class _ConfusionMatrix(Enum): TRUE_POSITIVES = 'tp' FALSE_POSITIVES = 'fp' diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py index 8939044f71d..61940ad789c 100644 --- a/tensorflow/python/keras/utils/__init__.py +++ b/tensorflow/python/keras/utils/__init__.py @@ -34,6 +34,7 @@ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model from tensorflow.python.keras.utils.layer_utils import get_source_inputs +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model from tensorflow.python.keras.utils.np_utils import normalize from tensorflow.python.keras.utils.np_utils import to_categorical diff --git 
a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py new file mode 100644 index 00000000000..d11d7853569 --- /dev/null +++ b/tensorflow/python/keras/utils/losses_utils.py @@ -0,0 +1,213 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Utilities related to loss functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.keras import backend as K +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import confusion_matrix +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export('losses.Reduction', 'keras.losses.Reduction', v1=[]) +class ReductionV2(object): + """Types of loss reduction. + + Contains the following values: + `NONE`: Un-reduced weighted losses with the same shape as input. + `SUM`: Scalar sum of weighted losses. + `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses. 
+ """ + + NONE = None + SUM = 'sum' + SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' + + @classmethod + def all(cls): + return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE) + + @classmethod + def validate(cls, key): + if key not in cls.all(): + raise ValueError('Invalid Reduction Key %s.' % key) + + +def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): + """Squeeze or expand last dimension if needed. + + 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 + (using `confusion_matrix.remove_squeezable_dimensions`). + 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 + from the new rank of `y_pred`. + If `sample_weight` is scalar, it is kept scalar. + + This will use static shape if available. Otherwise, it will add graph + operations, which could result in a performance hit. + + Args: + y_pred: Predicted values, a `Tensor` of arbitrary dimensions. + y_true: Optional label `Tensor` whose dimensions match `y_pred`. + sample_weight: Optional weight scalar or `Tensor` whose dimensions match + `y_pred`. + + Returns: + Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has + the last dimension squeezed, + `sample_weight` could be extended by one dimension. + """ + if y_true is not None: + # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1 + y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( + y_true, y_pred) + + if sample_weight is None: + return y_pred, y_true, None + + sample_weight = ops.convert_to_tensor(sample_weight) + weights_shape = sample_weight.get_shape() + weights_rank = weights_shape.ndims + if weights_rank == 0: # If weights is scalar, do nothing. + return y_pred, y_true, sample_weight + + y_pred_shape = y_pred.get_shape() + y_pred_rank = y_pred_shape.ndims + if (y_pred_rank is not None) and (weights_rank is not None): + # Use static rank. 
+ if weights_rank - y_pred_rank == 1: + sample_weight = array_ops.squeeze(sample_weight, [-1]) + elif y_pred_rank - weights_rank == 1: + sample_weight = array_ops.expand_dims(sample_weight, [-1]) + return y_pred, y_true, sample_weight + + # Use dynamic rank. + weights_rank_tensor = array_ops.rank(sample_weight) + rank_diff = weights_rank_tensor - array_ops.rank(y_pred) + maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) + + def _maybe_expand_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, + -1), lambda: array_ops.expand_dims(sample_weight, [-1]), + lambda: sample_weight) + + def _maybe_adjust_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, 1), maybe_squeeze_weights, + _maybe_expand_weights) + + # squeeze or expand last dim of `sample_weight` if its rank differs by 1 + # from the new rank of `y_pred`. + sample_weight = control_flow_ops.cond( + math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, + _maybe_adjust_weights) + return y_pred, y_true, sample_weight + + +def _safe_mean(losses, num_present): + """Computes a safe mean of the losses. + + Args: + losses: `Tensor` whose elements contain individual loss measurements. + num_present: The number of measurable elements in `losses`. + + Returns: + A scalar representing the mean of `losses`. If `num_present` is zero, + then zero is returned. 
+ """ + total_loss = math_ops.reduce_sum(losses) + return math_ops.div_no_nan(total_loss, num_present, name='value') + + +def _num_elements(losses): + """Computes the number of elements in `losses` tensor.""" + with ops.name_scope(None, 'num_elements', values=[losses]) as scope: + return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype) + + +def _reduce_weighted_loss(weighted_losses, + reduction=ReductionV2.SUM_OVER_BATCH_SIZE): + """Reduces the individual weighted loss measurements.""" + if reduction == ReductionV2.NONE: + loss = weighted_losses + else: + loss = math_ops.reduce_sum(weighted_losses) + if reduction == ReductionV2.SUM_OVER_BATCH_SIZE: + loss = _safe_mean(loss, _num_elements(weighted_losses)) + return loss + + +def compute_weighted_loss(losses, + sample_weight=None, + reduction=ReductionV2.SUM_OVER_BATCH_SIZE, + name=None): + """Computes the weighted loss. + + Args: + losses: `Tensor` of shape `[batch_size, d1, ... dN]`. + sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as + `losses`, or be broadcastable to `losses`. + reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is + `SUM_OVER_BATCH_SIZE`. + name: Optional name for the op. + + Raises: + ValueError: If the shape of `sample_weight` is not compatible with `losses`. + + Returns: + Weighted loss `Tensor` of the same type as `losses`. If `reduction` is + `NONE`, this has the same shape as `losses`; otherwise, it is scalar. + """ + ReductionV2.validate(reduction) + if sample_weight is None: + sample_weight = 1.0 + with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)): + # Save the `reduction` argument for loss normalization when distributing + # to multiple replicas. + # TODO(josh11b): Associate it with the returned op for more precision. + ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access + + # Update dimensions of `sample_weight` to match with `losses` if possible. 
+ losses, _, sample_weight = squeeze_or_expand_dimensions( + losses, None, sample_weight) + losses = ops.convert_to_tensor(losses) + input_dtype = losses.dtype + losses = math_ops.to_float(losses) + sample_weight = math_ops.to_float(sample_weight) + + try: + # Broadcast weights if possible. + sample_weight = weights_broadcast_ops.broadcast_weights( + sample_weight, losses) + except ValueError: + # Reduce values to same ndim as weight array. + ndim = K.ndim(losses) + weight_ndim = K.ndim(sample_weight) + losses = K.mean(losses, axis=list(range(weight_ndim, ndim))) + + sample_weight.get_shape().assert_is_compatible_with(losses.get_shape()) + weighted_losses = math_ops.multiply(losses, sample_weight) + # Apply reduction function to the individual weighted losses. + loss = _reduce_weighted_loss(weighted_losses, reduction) + # Convert the result back to the input type. + loss = math_ops.cast(loss, input_dtype) + return loss diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 0a5b511f820..7c52b28b391 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -33,32 +33,8 @@ from tensorflow.python.util.deprecation import deprecated_argument_lookup from tensorflow.python.util.tf_export import tf_export -@tf_export("losses.Reduction", v1=[]) -class ReductionV2(object): - """Types of loss reduction. - - Contains the following values: - `NONE`: Un-reduced weighted losses with the same shape as input. - `SUM`: Scalar sum of weighted losses. - `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses. - """ - - NONE = "none" - SUM = "weighted_sum" - SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size" - - @classmethod - def all(cls): - return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE) - - @classmethod - def validate(cls, key): - if key not in cls.all(): - raise ValueError("Invalid Reduction Key %s." 
% key) - - @tf_export(v1=["losses.Reduction"]) -class Reduction(ReductionV2): +class Reduction(object): """Types of loss reduction. Contains the following values: @@ -71,6 +47,9 @@ class Reduction(ReductionV2): `SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`. """ + NONE = "none" + SUM = "weighted_sum" + SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size" MEAN = "weighted_mean" SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights" SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS diff --git a/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt index b2adb52660f..258ad5047eb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt @@ -1,7 +1,6 @@ path: "tensorflow.losses.Reduction" tf_class { is_instance: "" - is_instance: "" is_instance: "" member { name: "MEAN" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-reduction.pbtxt new file mode 100644 index 00000000000..031d9b171fb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-reduction.pbtxt @@ -0,0 +1,28 @@ +path: "tensorflow.keras.losses.Reduction" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "NONE" + mtype: "" + } + member { + name: "SUM" + mtype: "" + } + member { + name: "SUM_OVER_BATCH_SIZE" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "all" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "validate" + argspec: "args=[\'cls\', \'key\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt index eca6b915388..8618c6f1c7c 100644 --- 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.losses" tf_module { + member { + name: "Reduction" + mtype: "" + } member_method { name: "KLD" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt index 6a44e4ce66c..ad72e3194a8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.losses.Reduction" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "NONE" - mtype: "" + mtype: "" } member { name: "SUM"