Fix critical bug with add_loss TFOpLayer graph construction

The bug caused incorrect loss values and backpropagation issues.

PiperOrigin-RevId: 320257330
Change-Id: I0a030bc7632735b152454657fd15e41539b4e4bd
Authored by Francois Chollet on 2020-07-08 13:47:06 -07:00; committed by Geeta Chavan
parent 14b2d686d6
commit d03e29a094
2 changed files with 35 additions and 3 deletions
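
The pattern affected is an add_loss term built by calling one of the backend crossentropy functions on symbolic Keras tensors, which routes the raw TF ops through TFOpLayer wrappers during functional graph construction. Below is a minimal sketch of that usage, written against the public tf.keras API rather than the internal modules used by the test; names and shapes are illustrative and mirror the new test added in this change.

import numpy as np
import tensorflow as tf

# Functional model whose loss is attached via add_loss rather than compile(loss=...).
inputs = tf.keras.Input((2,))
labels = tf.keras.Input((1,))
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(inputs)
model = tf.keras.Model([inputs, labels], outputs)

# Calling the crossentropy function on symbolic tensors creates TF ops that
# Keras wraps in op layers; this is the graph-construction path being fixed.
model.add_loss(tf.keras.losses.binary_crossentropy(labels, outputs))
model.compile('adam')

x = np.random.random((2, 2))
y = np.random.random((2, 1))
model.fit([x, y])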


@@ -4690,7 +4690,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
         labels=target, logits=output, axis=axis)
 
   if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax'):
+      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4735,7 +4735,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   if (not from_logits and
       not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax'):
+      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4814,7 +4814,7 @@ def binary_crossentropy(target, output, from_logits=False):
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Sigmoid'):
+      output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
     # When sigmoid activation function is used for output operation, we
     # use logits from the sigmoid function directly to compute loss in order
     # to prevent collapsing zero when training.
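
All three hunks add the same guard: the logits shortcut, in which output is replaced by the input of its Softmax/Sigmoid op, is now skipped when the tensor carries _keras_history, i.e. when it was produced inside a Keras functional graph (including by the implicit TFOpLayer wrappers), where unwrapping the op breaks add_loss graph construction. A hypothetical standalone helper, not part of the actual file, expressing the new condition:

from tensorflow.python.framework import ops
from tensorflow.python.ops import variables as variables_module


def _should_unwrap_to_logits(output, op_type):
  """Hypothetical helper mirroring the guard added to the three loss functions.

  The activation output is only replaced by its op's logits input when it is a
  plain graph tensor. Tensors with _keras_history were produced by Keras
  layers, and unwrapping them corrupts add_loss graph construction.
  """
  return (not isinstance(output, (ops.EagerTensor, variables_module.Variable))
          and output.op.type == op_type
          and not hasattr(output, '_keras_history'))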


@@ -34,6 +34,7 @@ from tensorflow.python.keras import combinations
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import layers
+from tensorflow.python.keras import losses
 from tensorflow.python.keras import models
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
@@ -1833,6 +1834,37 @@ class AddLossTest(keras_parameterized.TestCase):
     self.assertAllClose(model.get_weights(), model2.get_weights())
 
+  def test_add_loss_crossentropy_backtracking(self):
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((1,))
+    outputs = layers.Dense(1, activation='sigmoid')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.binary_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.random((2, 1))
+    model.fit([x, y])
+
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((2,))
+    outputs = layers.Dense(2, activation='softmax')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.categorical_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.random((2, 2))
+    model.fit([x, y])
+
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((1,), dtype='int32')
+    outputs = layers.Dense(2, activation='softmax')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.randint(0, 2, size=(2, 1))
+    model.fit([x, y])
+
 
 @combinations.generate(combinations.keras_mode_combinations())
 class WeightAccessTest(keras_parameterized.TestCase):