From d03e29a094ce5ff9af2c0b147538279962a2805b Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Wed, 8 Jul 2020 13:47:06 -0700
Subject: [PATCH] Fix critical bug with `add_loss` TFOpLayer graph construction
 that caused incorrect loss values and backprop issues

PiperOrigin-RevId: 320257330
Change-Id: I0a030bc7632735b152454657fd15e41539b4e4bd
---
 tensorflow/python/keras/backend.py            |  6 ++--
 .../python/keras/engine/functional_test.py    | 32 +++++++++++++++++++
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 9330425272f..a02c62be842 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4690,7 +4690,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
         labels=target, logits=output, axis=axis)
 
   if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax'):
+      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4735,7 +4735,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
 
   if (not from_logits and
       not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax'):
+      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4814,7 +4814,7 @@ def binary_crossentropy(target, output, from_logits=False):
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Sigmoid'):
+      output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
     # When sigmoid activation function is used for output operation, we
     # use logits from the sigmoid function directly to compute loss in order
     # to prevent collapsing zero when training.
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index 24b0e147b97..68667899903 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.keras import combinations
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import layers
+from tensorflow.python.keras import losses
 from tensorflow.python.keras import models
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
@@ -1833,6 +1834,37 @@ class AddLossTest(keras_parameterized.TestCase):
 
     self.assertAllClose(model.get_weights(), model2.get_weights())
 
+  def test_add_loss_crossentropy_backtracking(self):
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((1,))
+    outputs = layers.Dense(1, activation='sigmoid')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.binary_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.random((2, 1))
+    model.fit([x, y])
+
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((2,))
+    outputs = layers.Dense(2, activation='softmax')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.categorical_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.random((2, 2))
+    model.fit([x, y])
+
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((1,), dtype='int32')
+    outputs = layers.Dense(2, activation='softmax')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.randint(0, 2, size=(2, 1))
+    model.fit([x, y])
+
 
 @combinations.generate(combinations.keras_mode_combinations())
 class WeightAccessTest(keras_parameterized.TestCase):
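For context (not part of the patch itself): the backend crossentropy helpers above rewrite `output` to the logits tensor feeding its `Softmax`/`Sigmoid` op, so the loss is computed from logits directly. When `output` is a Keras symbolic tensor (it carries `_keras_history`), that rewrite detached the loss from the functional graph that `add_loss` backtracks through, producing the incorrect loss values and backprop issues named in the subject; the new `not hasattr(output, '_keras_history')` guard limits the rewrite to plain TF tensors. Below is a minimal sketch of the now-supported pattern using the public `tf.keras` API rather than the internal test modules; it mirrors the sigmoid case from the test above and assumes TensorFlow 2.x with this fix applied:

    import numpy as np
    import tensorflow as tf

    # Two symbolic inputs: features, plus labels that feed only the loss.
    inputs = tf.keras.Input((2,))
    labels = tf.keras.Input((1,))
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(inputs)
    model = tf.keras.Model([inputs, labels], outputs)

    # `outputs` carries `_keras_history`, so binary_crossentropy keeps the
    # sigmoid output instead of backtracking to its logits input, and the
    # added loss stays connected to the model's graph.
    model.add_loss(tf.keras.losses.binary_crossentropy(labels, outputs))
    model.compile('adam')

    x = np.random.random((8, 2))
    y = np.random.randint(0, 2, size=(8, 1)).astype('float32')
    model.fit([x, y], epochs=1)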