Fix loss computation when y_true and y_pred are not the same shape.

PiperOrigin-RevId: 267595602
This commit is contained in:
Zhenyu Tan 2019-09-06 07:46:17 -07:00 committed by TensorFlower Gardener
parent 36e4ce080e
commit 78abbf1682
6 changed files with 30 additions and 51 deletions

View File

@ -294,7 +294,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
self.w = self.add_weight('w', ()) self.w = self.add_weight('w', ())
def call(self, inputs): def call(self, inputs):
return keras.backend.sum(inputs) + self.w * 0 return keras.backend.sum(inputs, axis=1, keepdims=True) + self.w * 0
model = keras.Sequential([SumLayer(input_shape=(2,))]) model = keras.Sequential([SumLayer(input_shape=(2,))])
model.compile( model.compile(
@ -317,11 +317,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
history = model.fit(train_dataset, history = model.fit(train_dataset,
epochs=2, steps_per_epoch=2, verbose=1, epochs=2, steps_per_epoch=2, verbose=1,
validation_data=val_dataset, validation_steps=2) validation_data=val_dataset, validation_steps=2)
self.assertListEqual(history.history['loss'], self.assertAllClose(history.history['loss'],
[inputs[:20].sum() / 2, inputs[20:].sum() / 2]) [inputs[:20].sum() / 20, inputs[20:].sum() / 20])
# The validation dataset will be reset at the end of each validation run. # The validation dataset will be reset at the end of each validation run.
self.assertListEqual(history.history['val_loss'], self.assertAllClose(history.history['val_loss'],
[inputs[:20].sum() / 2, inputs[:20].sum() / 2]) [inputs[:20].sum() / 20, inputs[:20].sum() / 20])
# Test correctness with dataset reset. # Test correctness with dataset reset.
train_dataset = dataset_ops.Dataset.from_tensor_slices( train_dataset = dataset_ops.Dataset.from_tensor_slices(
@ -330,10 +330,12 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
(inputs, targets)).batch(10) (inputs, targets)).batch(10)
history = model.fit(train_dataset, history = model.fit(train_dataset,
epochs=2, verbose=1, validation_data=val_dataset) epochs=2, verbose=1, validation_data=val_dataset)
self.assertListEqual(history.history['loss'], self.assertAllClose(
[inputs.sum() / 4, inputs.sum() / 4]) history.history['loss'],
self.assertListEqual(history.history['val_loss'], [inputs.sum() / 40, inputs.sum() / 40])
[inputs.sum() / 4, inputs.sum() / 4]) self.assertAllClose(
history.history['val_loss'],
[inputs.sum() / 40, inputs.sum() / 40])
@tf_test_util.run_deprecated_v1 @tf_test_util.run_deprecated_v1
def test_dataset_input_shape_validation(self): def test_dataset_input_shape_validation(self):
@ -456,7 +458,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
lines = capture.output.splitlines() lines = capture.output.splitlines()
self.assertIn('1/Unknown', lines[2])
self.assertIn('10/10', lines[-1]) self.assertIn('10/10', lines[-1])
self.assertLen(history.history['loss'], 2) self.assertLen(history.history['loss'], 2)

View File

@ -25,6 +25,7 @@ import six
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import losses_utils from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import tf_utils
@ -34,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util as tf_losses_util
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_controls
@ -213,6 +215,9 @@ class LossFunctionWrapper(Loss):
Returns: Returns:
Loss values per sample. Loss values per sample.
""" """
if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
y_pred, y_true)
return self.fn(y_true, y_pred, **self._fn_kwargs) return self.fn(y_true, y_pred, **self._fn_kwargs)
def get_config(self): def get_config(self):

View File

@ -87,7 +87,7 @@ class TestIsSymbolicTensor(test.TestCase):
def __init__(self, input_): def __init__(self, input_):
self._input = input_ self._input = input_
self.value = ops.convert_to_tensor(42.) self.value = ops.convert_to_tensor([[42.]])
@property @property
def dtype(self): def dtype(self):
@ -123,14 +123,14 @@ class TestIsSymbolicTensor(test.TestCase):
# User-land. # User-land.
model = keras.Sequential([ model = keras.Sequential([
keras.layers.InputLayer([]), keras.layers.InputLayer((1,)),
PlumbingLayer(Foo), # Makes a `Foo` object. PlumbingLayer(Foo), # Makes a `Foo` object.
]) ])
# Let's ensure Keras graph history is preserved by composing the models. # Let's ensure Keras graph history is preserved by composing the models.
model = keras.Model(model.inputs, model(model.outputs)) model = keras.Model(model.inputs, model(model.outputs))
# Now we instantiate the model and verify we have a `Foo` object, not a # Now we instantiate the model and verify we have a `Foo` object, not a
# `Tensor`. # `Tensor`.
y = model(ops.convert_to_tensor(7.)) y = model(ops.convert_to_tensor([[7.]]))
self.assertIsInstance(y, Foo) self.assertIsInstance(y, Foo)
# Confirm that (custom) loss sees `Foo` instance, not Tensor. # Confirm that (custom) loss sees `Foo` instance, not Tensor.
obtained_prediction_box = [None] obtained_prediction_box = [None]

View File

@ -22,7 +22,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import confusion_matrix
@ -215,22 +214,6 @@ class ConfusionMatrixTest(test.TestCase):
self._testConfMatrix( self._testConfMatrix(
labels=labels, predictions=predictions, num_classes=3, truth=None) labels=labels, predictions=predictions, num_classes=3, truth=None)
@test_util.run_deprecated_v1
def testInvalidRank_predictionsTooBig(self):
labels = np.asarray([1, 2, 3])
predictions = np.asarray([[1, 2, 3]])
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
confusion_matrix.confusion_matrix, predictions,
labels)
@test_util.run_deprecated_v1
def testInvalidRank_predictionsTooSmall(self):
labels = np.asarray([[1, 2, 3]])
predictions = np.asarray([1, 2, 3])
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
confusion_matrix.confusion_matrix, predictions,
labels)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testInputDifferentSize(self): def testInputDifferentSize(self):
labels = np.asarray([1, 2]) labels = np.asarray([1, 2])
@ -454,24 +437,18 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
def testUnsqueezableLabels(self): def testUnsqueezableLabels(self):
label_values = np.ones(shape=(2, 3, 2)) label_values = np.ones(shape=(2, 3, 2))
prediction_values = np.zeros(shape=(2, 3)) prediction_values = np.zeros(shape=(2, 3))
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values)
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32) labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32) predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = ( _, dynamic_predictions = (
confusion_matrix.remove_squeezable_dimensions( confusion_matrix.remove_squeezable_dimensions(labels_placeholder,
labels_placeholder, predictions_placeholder)) predictions_placeholder))
with self.cached_session(): with self.cached_session():
feed_dict = { feed_dict = {
labels_placeholder: label_values, labels_placeholder: label_values,
predictions_placeholder: prediction_values predictions_placeholder: prediction_values
} }
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
r"Can not squeeze dim\[2\]"):
dynamic_labels.eval(feed_dict=feed_dict)
self.assertAllEqual( self.assertAllEqual(
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict)) prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
@ -479,15 +456,12 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
def testUnsqueezablePredictions(self): def testUnsqueezablePredictions(self):
label_values = np.ones(shape=(2, 3)) label_values = np.ones(shape=(2, 3))
prediction_values = np.zeros(shape=(2, 3, 2)) prediction_values = np.zeros(shape=(2, 3, 2))
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
confusion_matrix.remove_squeezable_dimensions(
label_values, prediction_values)
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32) labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32) predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
dynamic_labels, dynamic_predictions = ( dynamic_labels, _ = (
confusion_matrix.remove_squeezable_dimensions( confusion_matrix.remove_squeezable_dimensions(labels_placeholder,
labels_placeholder, predictions_placeholder)) predictions_placeholder))
with self.cached_session(): with self.cached_session():
feed_dict = { feed_dict = {
@ -496,9 +470,6 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
} }
self.assertAllEqual( self.assertAllEqual(
label_values, dynamic_labels.eval(feed_dict=feed_dict)) label_values, dynamic_labels.eval(feed_dict=feed_dict))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
r"Can not squeeze dim\[2\]"):
dynamic_predictions.eval(feed_dict=feed_dict)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -488,7 +488,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[0, 1], [2, 3]]) labels = constant_op.constant([[0, 1], [2, 3]])
weights = constant_op.constant(1.2) weights = constant_op.constant(1.2)
with self.assertRaisesRegexp(ValueError, 'dimension'): with self.assertRaisesRegexp(ValueError, 'mismatch'):
losses.sparse_softmax_cross_entropy( losses.sparse_softmax_cross_entropy(
labels, logits, weights=weights).eval() labels, logits, weights=weights).eval()

View File

@ -67,9 +67,11 @@ def remove_squeezable_dimensions(
if (labels_rank is not None) and (predictions_rank is not None): if (labels_rank is not None) and (predictions_rank is not None):
# Use static rank. # Use static rank.
rank_diff = predictions_rank - labels_rank rank_diff = predictions_rank - labels_rank
if rank_diff == expected_rank_diff + 1: if (rank_diff == expected_rank_diff + 1 and
predictions_shape.dims[-1].is_compatible_with(1)):
predictions = array_ops.squeeze(predictions, [-1]) predictions = array_ops.squeeze(predictions, [-1])
elif rank_diff == expected_rank_diff - 1: elif (rank_diff == expected_rank_diff - 1 and
labels_shape.dims[-1].is_compatible_with(1)):
labels = array_ops.squeeze(labels, [-1]) labels = array_ops.squeeze(labels, [-1])
return labels, predictions return labels, predictions