Fix loss computation when y_true and y_pred do not have the same shape.
PiperOrigin-RevId: 267595602
This commit is contained in:
parent
36e4ce080e
commit
78abbf1682
@ -294,7 +294,7 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
|||||||
self.w = self.add_weight('w', ())
|
self.w = self.add_weight('w', ())
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
return keras.backend.sum(inputs) + self.w * 0
|
return keras.backend.sum(inputs, axis=1, keepdims=True) + self.w * 0
|
||||||
|
|
||||||
model = keras.Sequential([SumLayer(input_shape=(2,))])
|
model = keras.Sequential([SumLayer(input_shape=(2,))])
|
||||||
model.compile(
|
model.compile(
|
||||||
@ -317,11 +317,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
|||||||
history = model.fit(train_dataset,
|
history = model.fit(train_dataset,
|
||||||
epochs=2, steps_per_epoch=2, verbose=1,
|
epochs=2, steps_per_epoch=2, verbose=1,
|
||||||
validation_data=val_dataset, validation_steps=2)
|
validation_data=val_dataset, validation_steps=2)
|
||||||
self.assertListEqual(history.history['loss'],
|
self.assertAllClose(history.history['loss'],
|
||||||
[inputs[:20].sum() / 2, inputs[20:].sum() / 2])
|
[inputs[:20].sum() / 20, inputs[20:].sum() / 20])
|
||||||
# The validation dataset will be reset at the end of each validation run.
|
# The validation dataset will be reset at the end of each validation run.
|
||||||
self.assertListEqual(history.history['val_loss'],
|
self.assertAllClose(history.history['val_loss'],
|
||||||
[inputs[:20].sum() / 2, inputs[:20].sum() / 2])
|
[inputs[:20].sum() / 20, inputs[:20].sum() / 20])
|
||||||
|
|
||||||
# Test correctness with dataset reset.
|
# Test correctness with dataset reset.
|
||||||
train_dataset = dataset_ops.Dataset.from_tensor_slices(
|
train_dataset = dataset_ops.Dataset.from_tensor_slices(
|
||||||
@ -330,10 +330,12 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
|||||||
(inputs, targets)).batch(10)
|
(inputs, targets)).batch(10)
|
||||||
history = model.fit(train_dataset,
|
history = model.fit(train_dataset,
|
||||||
epochs=2, verbose=1, validation_data=val_dataset)
|
epochs=2, verbose=1, validation_data=val_dataset)
|
||||||
self.assertListEqual(history.history['loss'],
|
self.assertAllClose(
|
||||||
[inputs.sum() / 4, inputs.sum() / 4])
|
history.history['loss'],
|
||||||
self.assertListEqual(history.history['val_loss'],
|
[inputs.sum() / 40, inputs.sum() / 40])
|
||||||
[inputs.sum() / 4, inputs.sum() / 4])
|
self.assertAllClose(
|
||||||
|
history.history['val_loss'],
|
||||||
|
[inputs.sum() / 40, inputs.sum() / 40])
|
||||||
|
|
||||||
@tf_test_util.run_deprecated_v1
|
@tf_test_util.run_deprecated_v1
|
||||||
def test_dataset_input_shape_validation(self):
|
def test_dataset_input_shape_validation(self):
|
||||||
@ -456,7 +458,6 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
lines = capture.output.splitlines()
|
lines = capture.output.splitlines()
|
||||||
|
|
||||||
self.assertIn('1/Unknown', lines[2])
|
|
||||||
self.assertIn('10/10', lines[-1])
|
self.assertIn('10/10', lines[-1])
|
||||||
|
|
||||||
self.assertLen(history.history['loss'], 2)
|
self.assertLen(history.history['loss'], 2)
|
||||||
|
@ -25,6 +25,7 @@ import six
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import smart_cond
|
from tensorflow.python.framework import smart_cond
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.utils import losses_utils
|
from tensorflow.python.keras.utils import losses_utils
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
@ -34,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops.losses import losses_impl
|
from tensorflow.python.ops.losses import losses_impl
|
||||||
|
from tensorflow.python.ops.losses import util as tf_losses_util
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
from tensorflow.tools.docs import doc_controls
|
from tensorflow.tools.docs import doc_controls
|
||||||
|
|
||||||
@ -213,6 +215,9 @@ class LossFunctionWrapper(Loss):
|
|||||||
Returns:
|
Returns:
|
||||||
Loss values per sample.
|
Loss values per sample.
|
||||||
"""
|
"""
|
||||||
|
if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
|
||||||
|
y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
|
||||||
|
y_pred, y_true)
|
||||||
return self.fn(y_true, y_pred, **self._fn_kwargs)
|
return self.fn(y_true, y_pred, **self._fn_kwargs)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
|
@ -87,7 +87,7 @@ class TestIsSymbolicTensor(test.TestCase):
|
|||||||
|
|
||||||
def __init__(self, input_):
|
def __init__(self, input_):
|
||||||
self._input = input_
|
self._input = input_
|
||||||
self.value = ops.convert_to_tensor(42.)
|
self.value = ops.convert_to_tensor([[42.]])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
@ -123,14 +123,14 @@ class TestIsSymbolicTensor(test.TestCase):
|
|||||||
|
|
||||||
# User-land.
|
# User-land.
|
||||||
model = keras.Sequential([
|
model = keras.Sequential([
|
||||||
keras.layers.InputLayer([]),
|
keras.layers.InputLayer((1,)),
|
||||||
PlumbingLayer(Foo), # Makes a `Foo` object.
|
PlumbingLayer(Foo), # Makes a `Foo` object.
|
||||||
])
|
])
|
||||||
# Let's ensure Keras graph history is preserved by composing the models.
|
# Let's ensure Keras graph history is preserved by composing the models.
|
||||||
model = keras.Model(model.inputs, model(model.outputs))
|
model = keras.Model(model.inputs, model(model.outputs))
|
||||||
# Now we instantiate the model and verify we have a `Foo` object, not a
|
# Now we instantiate the model and verify we have a `Foo` object, not a
|
||||||
# `Tensor`.
|
# `Tensor`.
|
||||||
y = model(ops.convert_to_tensor(7.))
|
y = model(ops.convert_to_tensor([[7.]]))
|
||||||
self.assertIsInstance(y, Foo)
|
self.assertIsInstance(y, Foo)
|
||||||
# Confirm that (custom) loss sees `Foo` instance, not Tensor.
|
# Confirm that (custom) loss sees `Foo` instance, not Tensor.
|
||||||
obtained_prediction_box = [None]
|
obtained_prediction_box = [None]
|
||||||
|
@ -22,7 +22,6 @@ import numpy as np
|
|||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import confusion_matrix
|
from tensorflow.python.ops import confusion_matrix
|
||||||
@ -215,22 +214,6 @@ class ConfusionMatrixTest(test.TestCase):
|
|||||||
self._testConfMatrix(
|
self._testConfMatrix(
|
||||||
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
labels=labels, predictions=predictions, num_classes=3, truth=None)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testInvalidRank_predictionsTooBig(self):
|
|
||||||
labels = np.asarray([1, 2, 3])
|
|
||||||
predictions = np.asarray([[1, 2, 3]])
|
|
||||||
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
|
|
||||||
confusion_matrix.confusion_matrix, predictions,
|
|
||||||
labels)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testInvalidRank_predictionsTooSmall(self):
|
|
||||||
labels = np.asarray([[1, 2, 3]])
|
|
||||||
predictions = np.asarray([1, 2, 3])
|
|
||||||
self.assertRaisesRegexp(ValueError, "an not squeeze dim",
|
|
||||||
confusion_matrix.confusion_matrix, predictions,
|
|
||||||
labels)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testInputDifferentSize(self):
|
def testInputDifferentSize(self):
|
||||||
labels = np.asarray([1, 2])
|
labels = np.asarray([1, 2])
|
||||||
@ -454,24 +437,18 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
|
|||||||
def testUnsqueezableLabels(self):
|
def testUnsqueezableLabels(self):
|
||||||
label_values = np.ones(shape=(2, 3, 2))
|
label_values = np.ones(shape=(2, 3, 2))
|
||||||
prediction_values = np.zeros(shape=(2, 3))
|
prediction_values = np.zeros(shape=(2, 3))
|
||||||
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
|
|
||||||
confusion_matrix.remove_squeezable_dimensions(
|
|
||||||
label_values, prediction_values)
|
|
||||||
|
|
||||||
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
dynamic_labels, dynamic_predictions = (
|
_, dynamic_predictions = (
|
||||||
confusion_matrix.remove_squeezable_dimensions(
|
confusion_matrix.remove_squeezable_dimensions(labels_placeholder,
|
||||||
labels_placeholder, predictions_placeholder))
|
predictions_placeholder))
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
feed_dict = {
|
feed_dict = {
|
||||||
labels_placeholder: label_values,
|
labels_placeholder: label_values,
|
||||||
predictions_placeholder: prediction_values
|
predictions_placeholder: prediction_values
|
||||||
}
|
}
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
|
||||||
r"Can not squeeze dim\[2\]"):
|
|
||||||
dynamic_labels.eval(feed_dict=feed_dict)
|
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
|
prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
|
||||||
|
|
||||||
@ -479,15 +456,12 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
|
|||||||
def testUnsqueezablePredictions(self):
|
def testUnsqueezablePredictions(self):
|
||||||
label_values = np.ones(shape=(2, 3))
|
label_values = np.ones(shape=(2, 3))
|
||||||
prediction_values = np.zeros(shape=(2, 3, 2))
|
prediction_values = np.zeros(shape=(2, 3, 2))
|
||||||
with self.assertRaisesRegexp(ValueError, r"Can not squeeze dim\[2\]"):
|
|
||||||
confusion_matrix.remove_squeezable_dimensions(
|
|
||||||
label_values, prediction_values)
|
|
||||||
|
|
||||||
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
dynamic_labels, dynamic_predictions = (
|
dynamic_labels, _ = (
|
||||||
confusion_matrix.remove_squeezable_dimensions(
|
confusion_matrix.remove_squeezable_dimensions(labels_placeholder,
|
||||||
labels_placeholder, predictions_placeholder))
|
predictions_placeholder))
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
feed_dict = {
|
feed_dict = {
|
||||||
@ -496,9 +470,6 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
|
|||||||
}
|
}
|
||||||
self.assertAllEqual(
|
self.assertAllEqual(
|
||||||
label_values, dynamic_labels.eval(feed_dict=feed_dict))
|
label_values, dynamic_labels.eval(feed_dict=feed_dict))
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
|
||||||
r"Can not squeeze dim\[2\]"):
|
|
||||||
dynamic_predictions.eval(feed_dict=feed_dict)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -488,7 +488,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
|
|||||||
labels = constant_op.constant([[0, 1], [2, 3]])
|
labels = constant_op.constant([[0, 1], [2, 3]])
|
||||||
weights = constant_op.constant(1.2)
|
weights = constant_op.constant(1.2)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'dimension'):
|
with self.assertRaisesRegexp(ValueError, 'mismatch'):
|
||||||
losses.sparse_softmax_cross_entropy(
|
losses.sparse_softmax_cross_entropy(
|
||||||
labels, logits, weights=weights).eval()
|
labels, logits, weights=weights).eval()
|
||||||
|
|
||||||
|
@ -67,9 +67,11 @@ def remove_squeezable_dimensions(
|
|||||||
if (labels_rank is not None) and (predictions_rank is not None):
|
if (labels_rank is not None) and (predictions_rank is not None):
|
||||||
# Use static rank.
|
# Use static rank.
|
||||||
rank_diff = predictions_rank - labels_rank
|
rank_diff = predictions_rank - labels_rank
|
||||||
if rank_diff == expected_rank_diff + 1:
|
if (rank_diff == expected_rank_diff + 1 and
|
||||||
|
predictions_shape.dims[-1].is_compatible_with(1)):
|
||||||
predictions = array_ops.squeeze(predictions, [-1])
|
predictions = array_ops.squeeze(predictions, [-1])
|
||||||
elif rank_diff == expected_rank_diff - 1:
|
elif (rank_diff == expected_rank_diff - 1 and
|
||||||
|
labels_shape.dims[-1].is_compatible_with(1)):
|
||||||
labels = array_ops.squeeze(labels, [-1])
|
labels = array_ops.squeeze(labels, [-1])
|
||||||
return labels, predictions
|
return labels, predictions
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user