From 78abbf16828acf0d59cd54433a29b9ce1c83fe8e Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Fri, 6 Sep 2019 07:46:17 -0700 Subject: [PATCH] Fix loss computation when y_true and y_pred is not same shape. PiperOrigin-RevId: 267595602 --- .../keras/engine/training_dataset_test.py | 21 +++++----- tensorflow/python/keras/losses.py | 5 +++ .../python/keras/utils/tf_utils_test.py | 6 +-- .../kernel_tests/confusion_matrix_test.py | 41 +++---------------- tensorflow/python/kernel_tests/losses_test.py | 2 +- tensorflow/python/ops/confusion_matrix.py | 6 ++- 6 files changed, 30 insertions(+), 51 deletions(-) diff --git a/tensorflow/python/keras/engine/training_dataset_test.py b/tensorflow/python/keras/engine/training_dataset_test.py index fa67f5acdd7..efbfc099941 100644 --- a/tensorflow/python/keras/engine/training_dataset_test.py +++ b/tensorflow/python/keras/engine/training_dataset_test.py @@ -294,7 +294,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): self.w = self.add_weight('w', ()) def call(self, inputs): - return keras.backend.sum(inputs) + self.w * 0 + return keras.backend.sum(inputs, axis=1, keepdims=True) + self.w * 0 model = keras.Sequential([SumLayer(input_shape=(2,))]) model.compile( @@ -317,11 +317,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): history = model.fit(train_dataset, epochs=2, steps_per_epoch=2, verbose=1, validation_data=val_dataset, validation_steps=2) - self.assertListEqual(history.history['loss'], - [inputs[:20].sum() / 2, inputs[20:].sum() / 2]) + self.assertAllClose(history.history['loss'], + [inputs[:20].sum() / 20, inputs[20:].sum() / 20]) # The validation dataset will be reset at the end of each validation run. - self.assertListEqual(history.history['val_loss'], - [inputs[:20].sum() / 2, inputs[:20].sum() / 2]) + self.assertAllClose(history.history['val_loss'], + [inputs[:20].sum() / 20, inputs[:20].sum() / 20]) # Test correctness with dataset reset. train_dataset = dataset_ops.Dataset.from_tensor_slices( @@ -330,10 +330,12 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): (inputs, targets)).batch(10) history = model.fit(train_dataset, epochs=2, verbose=1, validation_data=val_dataset) - self.assertListEqual(history.history['loss'], - [inputs.sum() / 4, inputs.sum() / 4]) - self.assertListEqual(history.history['val_loss'], - [inputs.sum() / 4, inputs.sum() / 4]) + self.assertAllClose( + history.history['loss'], + [inputs.sum() / 40, inputs.sum() / 40]) + self.assertAllClose( + history.history['val_loss'], + [inputs.sum() / 40, inputs.sum() / 40]) @tf_test_util.run_deprecated_v1 def test_dataset_input_shape_validation(self): @@ -456,7 +458,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): lines = capture.output.splitlines() - self.assertIn('1/Unknown', lines[2]) self.assertIn('10/10', lines[-1]) self.assertLen(history.history['loss'], 2) diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 3a23bb66ddb..47448e228e3 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -25,6 +25,7 @@ import six from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils import losses_utils from tensorflow.python.keras.utils import tf_utils @@ -34,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.losses import losses_impl +from tensorflow.python.ops.losses import util as tf_losses_util from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -213,6 +215,9 @@ class LossFunctionWrapper(Loss): Returns: Loss values per sample. """ + if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true): + y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions( + y_pred, y_true) return self.fn(y_true, y_pred, **self._fn_kwargs) def get_config(self): diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py index 902ecf91670..11cd5fe1ff9 100644 --- a/tensorflow/python/keras/utils/tf_utils_test.py +++ b/tensorflow/python/keras/utils/tf_utils_test.py @@ -87,7 +87,7 @@ class TestIsSymbolicTensor(test.TestCase): def __init__(self, input_): self._input = input_ - self.value = ops.convert_to_tensor(42.) + self.value = ops.convert_to_tensor([[42.]]) @property def dtype(self): @@ -123,14 +123,14 @@ class TestIsSymbolicTensor(test.TestCase): # User-land. model = keras.Sequential([ - keras.layers.InputLayer([]), + keras.layers.InputLayer((1,)), PlumbingLayer(Foo), # Makes a `Foo` object. ]) # Let's ensure Keras graph history is preserved by composing the models. model = keras.Model(model.inputs, model(model.outputs)) # Now we instantiate the model and verify we have a `Foo` object, not a # `Tensor`. - y = model(ops.convert_to_tensor(7.)) + y = model(ops.convert_to_tensor([[7.]])) self.assertIsInstance(y, Foo) # Confirm that (custom) loss sees `Foo` instance, not Tensor. obtained_prediction_box = [None] diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index cd6bd29e0de..c1178253a4b 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -22,7 +22,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import confusion_matrix @@ -215,22 +214,6 @@ class ConfusionMatrixTest(test.TestCase): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None) - @test_util.run_deprecated_v1 - def testInvalidRank_predictionsTooBig(self): - labels = np.asarray([1, 2, 3]) - predictions = np.asarray([[1, 2, 3]]) - self.assertRaisesRegexp(ValueError, "an not squeeze dim", - confusion_matrix.confusion_matrix, predictions, - labels) - - @test_util.run_deprecated_v1 - def testInvalidRank_predictionsTooSmall(self): - labels = np.asarray([[1, 2, 3]]) - predictions = np.asarray([1, 2, 3]) - self.assertRaisesRegexp(ValueError, "an not squeeze dim", - confusion_matrix.confusion_matrix, predictions, - labels) - @test_util.run_deprecated_v1 def testInputDifferentSize(self): labels = np.asarray([1, 2]) @@ -454,24 +437,18 @@ class RemoveSqueezableDimensionsTest(test.TestCase): def testUnsqueezableLabels(self): label_values = np.ones(shape=(2, 3, 2)) prediction_values = np.zeros(shape=(2, 3)) - with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"): - confusion_matrix.remove_squeezable_dimensions( - label_values, prediction_values) labels_placeholder = array_ops.placeholder(dtype=dtypes.int32) predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32) - dynamic_labels, dynamic_predictions = ( - confusion_matrix.remove_squeezable_dimensions( - labels_placeholder, predictions_placeholder)) + _, dynamic_predictions = ( + confusion_matrix.remove_squeezable_dimensions(labels_placeholder, + predictions_placeholder)) with self.cached_session(): feed_dict = { labels_placeholder: label_values, predictions_placeholder: prediction_values } - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - r"Can not squeeze dim\[2\]"): - dynamic_labels.eval(feed_dict=feed_dict) self.assertAllEqual( prediction_values, dynamic_predictions.eval(feed_dict=feed_dict)) @@ -479,15 +456,12 @@ class RemoveSqueezableDimensionsTest(test.TestCase): def testUnsqueezablePredictions(self): label_values = np.ones(shape=(2, 3)) prediction_values = np.zeros(shape=(2, 3, 2)) - with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"): - confusion_matrix.remove_squeezable_dimensions( - label_values, prediction_values) labels_placeholder = array_ops.placeholder(dtype=dtypes.int32) predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32) - dynamic_labels, dynamic_predictions = ( - confusion_matrix.remove_squeezable_dimensions( - labels_placeholder, predictions_placeholder)) + dynamic_labels, _ = ( + confusion_matrix.remove_squeezable_dimensions(labels_placeholder, + predictions_placeholder)) with self.cached_session(): feed_dict = { @@ -496,9 +470,6 @@ class RemoveSqueezableDimensionsTest(test.TestCase): } self.assertAllEqual( label_values, dynamic_labels.eval(feed_dict=feed_dict)) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - r"Can not squeeze dim\[2\]"): - dynamic_predictions.eval(feed_dict=feed_dict) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 203ac344ec2..b5f3e317d1c 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -488,7 +488,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): labels = constant_op.constant([[0, 1], [2, 3]]) weights = constant_op.constant(1.2) - with self.assertRaisesRegexp(ValueError, 'dimension'): + with self.assertRaisesRegexp(ValueError, 'mismatch'): losses.sparse_softmax_cross_entropy( labels, logits, weights=weights).eval() diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index bdee7b406f9..3e885975b03 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -67,9 +67,11 @@ def remove_squeezable_dimensions( if (labels_rank is not None) and (predictions_rank is not None): # Use static rank. rank_diff = predictions_rank - labels_rank - if rank_diff == expected_rank_diff + 1: + if (rank_diff == expected_rank_diff + 1 and + predictions_shape.dims[-1].is_compatible_with(1)): predictions = array_ops.squeeze(predictions, [-1]) - elif rank_diff == expected_rank_diff - 1: + elif (rank_diff == expected_rank_diff - 1 and + labels_shape.dims[-1].is_compatible_with(1)): labels = array_ops.squeeze(labels, [-1]) return labels, predictions