Fix loss computation when y_true and y_pred is not same shape.
PiperOrigin-RevId: 267595602
This commit is contained in:
parent
36e4ce080e
commit
78abbf1682
@ -294,7 +294,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
||||
self.w = self.add_weight('w', ())
|
||||
|
||||
def call(self, inputs):
|
||||
return keras.backend.sum(inputs) + self.w * 0
|
||||
return keras.backend.sum(inputs, axis=1, keepdims=True) + self.w * 0
|
||||
|
||||
model = keras.Sequential([SumLayer(input_shape=(2,))])
|
||||
model.compile(
|
||||
@ -317,11 +317,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
||||
history = model.fit(train_dataset,
|
||||
epochs=2, steps_per_epoch=2, verbose=1,
|
||||
validation_data=val_dataset, validation_steps=2)
|
||||
self.assertListEqual(history.history['loss'],
|
||||
[inputs[:20].sum() / 2, inputs[20:].sum() / 2])
|
||||
self.assertAllClose(history.history['loss'],
|
||||
[inputs[:20].sum() / 20, inputs[20:].sum() / 20])
|
||||
# The validation dataset will be reset at the end of each validation run.
|
||||
self.assertListEqual(history.history['val_loss'],
|
||||
[inputs[:20].sum() / 2, inputs[:20].sum() / 2])
|
||||
self.assertAllClose(history.history['val_loss'],
|
||||
[inputs[:20].sum() / 20, inputs[:20].sum() / 20])
|
||||
|
||||
# Test correctness with dataset reset.
|
||||
train_dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||
@ -330,10 +330,12 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
||||
(inputs, targets)).batch(10)
|
||||
history = model.fit(train_dataset,
|
||||
epochs=2, verbose=1, validation_data=val_dataset)
|
||||
self.assertListEqual(history.history['loss'],
|
||||
[inputs.sum() / 4, inputs.sum() / 4])
|
||||
self.assertListEqual(history.history['val_loss'],
|
||||
[inputs.sum() / 4, inputs.sum() / 4])
|
||||
self.assertAllClose(
|
||||
history.history['loss'],
|
||||
[inputs.sum() / 40, inputs.sum() / 40])
|
||||
self.assertAllClose(
|
||||
history.history['val_loss'],
|
||||
[inputs.sum() / 40, inputs.sum() / 40])
|
||||
|
||||
@tf_test_util.run_deprecated_v1
|
||||
def test_dataset_input_shape_validation(self):
|
||||
@ -456,7 +458,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
||||
|
||||
lines = capture.output.splitlines()
|
||||
|
||||
self.assertIn('1/Unknown', lines[2])
|
||||
self.assertIn('10/10', lines[-1])
|
||||
|
||||
self.assertLen(history.history['loss'], 2)
|
||||
|
@ -25,6 +25,7 @@ import six
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import smart_cond
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils import losses_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
@ -34,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops.losses import losses_impl
|
||||
from tensorflow.python.ops.losses import util as tf_losses_util
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
@ -213,6 +215,9 @@ class LossFunctionWrapper(Loss):
|
||||
Returns:
|
||||
Loss values per sample.
|
||||
"""
|
||||
if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
|
||||
y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
|
||||
y_pred, y_true)
|
||||
return self.fn(y_true, y_pred, **self._fn_kwargs)
|
||||
|
||||
def get_config(self):
|
||||
|
@ -87,7 +87,7 @@ class TestIsSymbolicTensor(test.TestCase):
|
||||
|
||||
def __init__(self, input_):
|
||||
self._input = input_
|
||||
self.value = ops.convert_to_tensor(42.)
|
||||
self.value = ops.convert_to_tensor([[42.]])
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
@ -123,14 +123,14 @@ class TestIsSymbolicTensor(test.TestCase):
|
||||
|
||||
# User-land.
|
||||
model = keras.Sequential([
|
||||
keras.layers.InputLayer([]),
|
||||
keras.layers.InputLayer((1,)),
|
||||
PlumbingLayer(Foo), # Makes a `Foo` object.
|
||||
])
|
||||
# Let's ensure Keras graph history is preserved by composing the models.
|
||||
model = keras.Model(model.inputs, model(model.outputs))
|
||||
# Now we instantiate the model and verify we have a `Foo` object, not a
|
||||
# `Tensor`.
|
||||
y = model(ops.convert_to_tensor(7.))
|
||||
y = model(ops.convert_to_tensor([[7.]]))
|
||||
self.assertIsInstance(y, Foo)
|
||||
# Confirm that (custom) loss sees `Foo` instance, not Tensor.
|
||||
obtained_prediction_box = [None]
|
||||
|
@ -22,7 +22,6 @@ import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import confusion_matrix
|
||||
@ -215,22 +214,6 @@ class ConfusionMatrixTest(test.TestCase):
|
||||
self._testConfMatrix(
|
||||
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidRank_predictionsTooBig(self):
|
||||
labels = np.asarray([1, 2, 3])
|
||||
predictions = np.asarray([[1, 2, 3]])
|
||||
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
|
||||
confusion_matrix.confusion_matrix, predictions,
|
||||
labels)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInvalidRank_predictionsTooSmall(self):
|
||||
labels = np.asarray([[1, 2, 3]])
|
||||
predictions = np.asarray([1, 2, 3])
|
||||
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
|
||||
confusion_matrix.confusion_matrix, predictions,
|
||||
labels)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testInputDifferentSize(self):
|
||||
labels = np.asarray([1, 2])
|
||||
@ -454,24 +437,18 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
|
||||
def testUnsqueezableLabels(self):
|
||||
label_values = np.ones(shape=(2, 3, 2))
|
||||
prediction_values = np.zeros(shape=(2, 3))
|
||||
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
|
||||
confusion_matrix.remove_squeezable_dimensions(
|
||||
label_values, prediction_values)
|
||||
|
||||
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||
dynamic_labels, dynamic_predictions = (
|
||||
confusion_matrix.remove_squeezable_dimensions(
|
||||
labels_placeholder, predictions_placeholder))
|
||||
_, dynamic_predictions = (
|
||||
confusion_matrix.remove_squeezable_dimensions(labels_placeholder,
|
||||
predictions_placeholder))
|
||||
|
||||
with self.cached_session():
|
||||
feed_dict = {
|
||||
labels_placeholder: label_values,
|
||||
predictions_placeholder: prediction_values
|
||||
}
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
r"Can not squeeze dim\[2\]"):
|
||||
dynamic_labels.eval(feed_dict=feed_dict)
|
||||
self.assertAllEqual(
|
||||
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
|
||||
|
||||
@ -479,15 +456,12 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
|
||||
def testUnsqueezablePredictions(self):
|
||||
label_values = np.ones(shape=(2, 3))
|
||||
prediction_values = np.zeros(shape=(2, 3, 2))
|
||||
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
|
||||
confusion_matrix.remove_squeezable_dimensions(
|
||||
label_values, prediction_values)
|
||||
|
||||
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||
dynamic_labels, dynamic_predictions = (
|
||||
confusion_matrix.remove_squeezable_dimensions(
|
||||
labels_placeholder, predictions_placeholder))
|
||||
dynamic_labels, _ = (
|
||||
confusion_matrix.remove_squeezable_dimensions(labels_placeholder,
|
||||
predictions_placeholder))
|
||||
|
||||
with self.cached_session():
|
||||
feed_dict = {
|
||||
@ -496,9 +470,6 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
|
||||
}
|
||||
self.assertAllEqual(
|
||||
label_values, dynamic_labels.eval(feed_dict=feed_dict))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
r"Can not squeeze dim\[2\]"):
|
||||
dynamic_predictions.eval(feed_dict=feed_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -488,7 +488,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
||||
labels = constant_op.constant([[0, 1], [2, 3]])
|
||||
weights = constant_op.constant(1.2)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'dimension'):
|
||||
with self.assertRaisesRegexp(ValueError, 'mismatch'):
|
||||
losses.sparse_softmax_cross_entropy(
|
||||
labels, logits, weights=weights).eval()
|
||||
|
||||
|
@ -67,9 +67,11 @@ def remove_squeezable_dimensions(
|
||||
if (labels_rank is not None) and (predictions_rank is not None):
|
||||
# Use static rank.
|
||||
rank_diff = predictions_rank - labels_rank
|
||||
if rank_diff == expected_rank_diff + 1:
|
||||
if (rank_diff == expected_rank_diff + 1 and
|
||||
predictions_shape.dims[-1].is_compatible_with(1)):
|
||||
predictions = array_ops.squeeze(predictions, [-1])
|
||||
elif rank_diff == expected_rank_diff - 1:
|
||||
elif (rank_diff == expected_rank_diff - 1 and
|
||||
labels_shape.dims[-1].is_compatible_with(1)):
|
||||
labels = array_ops.squeeze(labels, [-1])
|
||||
return labels, predictions
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user