Fix a critical bug in add_loss TFOpLayer graph construction that caused
incorrect loss values and backprop issues.

PiperOrigin-RevId: 320257330
Change-Id: I0a030bc7632735b152454657fd15e41539b4e4bd
This commit is contained in:
parent 14b2d686d6
commit d03e29a094
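For context, a minimal sketch of the pattern this commit fixes, written against the public tf.keras API (the regression test below uses the internal input_layer_lib and functional modules instead). Before the fix, the backend crossentropy functions could peel the sigmoid/softmax op off a Keras-built output tensor and rebuild the loss from raw logits, corrupting the TFOpLayer graph that add_loss constructs and yielding incorrect loss values.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, losses

# A functional model whose loss is wired up via add_loss on a
# symbolic crossentropy of two Keras tensors (labels and outputs).
inputs = tf.keras.Input((2,))
labels = tf.keras.Input((1,))
outputs = layers.Dense(1, activation='sigmoid')(inputs)
model = tf.keras.Model([inputs, labels], outputs)
model.add_loss(losses.binary_crossentropy(labels, outputs))
model.compile('adam')

x = np.random.random((2, 2))
y = np.random.random((2, 1))
model.fit([x, y])  # with the fix, this trains against the intended loss
```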
@@ -4690,7 +4690,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
         labels=target, logits=output, axis=axis)
 
   if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax'):
+      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4735,7 +4735,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
 
   if (not from_logits and
       not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax'):
+      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4814,7 +4814,7 @@ def binary_crossentropy(target, output, from_logits=False):
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Sigmoid'):
+      output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
     # When sigmoid activation function is used for output operation, we
     # use logits from the sigmoid function directly to compute loss in order
     # to prevent collapsing zero when training.
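All three hunks add the same guard: skip the logits-unwrapping rewrite whenever the output tensor carries Keras history. Below is a sketch of the rewrite being guarded, assuming (consistently with the comments above) that the backend recovers logits from a bare activation op's input; loss_from_softmax is a hypothetical stand-in, not the backend function itself.

```python
import tensorflow as tf

@tf.function
def loss_from_softmax(target, logits):
  output = tf.nn.softmax(logits)
  # A bare Softmax op with no _keras_history can be peeled off to
  # recover its input logits, so the fused, numerically stable kernel
  # runs on the logits instead of -sum(target * log(output)).
  if (output.op.type == 'Softmax' and
      not hasattr(output, '_keras_history')):
    return tf.nn.softmax_cross_entropy_with_logits(
        labels=target, logits=output.op.inputs[0])
  return -tf.reduce_sum(target * tf.math.log(output), axis=-1)

print(loss_from_softmax(tf.constant([[0., 1.]]), tf.constant([[2., .5]])))
```

For a tensor built by a Keras layer, the hasattr check fails the condition, the rewrite is skipped, and the graph that add_loss recorded stays intact.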
@@ -34,6 +34,7 @@ from tensorflow.python.keras import combinations
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import layers
+from tensorflow.python.keras import losses
 from tensorflow.python.keras import models
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
@@ -1833,6 +1834,37 @@ class AddLossTest(keras_parameterized.TestCase):
 
     self.assertAllClose(model.get_weights(), model2.get_weights())
 
+  def test_add_loss_crossentropy_backtracking(self):
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((1,))
+    outputs = layers.Dense(1, activation='sigmoid')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.binary_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.random((2, 1))
+    model.fit([x, y])
+
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((2,))
+    outputs = layers.Dense(2, activation='softmax')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.categorical_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.random((2, 2))
+    model.fit([x, y])
+
+    inputs = input_layer_lib.Input((2,))
+    labels = input_layer_lib.Input((1,), dtype='int32')
+    outputs = layers.Dense(2, activation='softmax')(inputs)
+    model = functional.Functional([inputs, labels], outputs)
+    model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs))
+    model.compile('adam')
+    x = np.random.random((2, 2))
+    y = np.random.randint(0, 2, size=(2, 1))
+    model.fit([x, y])
+
 
 @combinations.generate(combinations.keras_mode_combinations())
 class WeightAccessTest(keras_parameterized.TestCase):
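The test leans on the property the new guard keys on: tensors produced by Keras layers carry a _keras_history attribute (layer/node bookkeeping), while raw op outputs do not. A quick check of that assumption via the public API:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Built by a Keras layer: has _keras_history, so the backend now
# leaves it alone and the add_loss graph stays intact.
keras_out = layers.Dense(2, activation='softmax')(tf.keras.Input((2,)))
print(hasattr(keras_out, '_keras_history'))  # True

# Built by a raw op inside a graph: no _keras_history, so unwrapping
# the Softmax to reuse its logits remains safe.
@tf.function
def raw_softmax(x):
  out = tf.nn.softmax(x)
  print(hasattr(out, '_keras_history'))  # False (printed at trace time)
  return out

raw_softmax(tf.constant([[1., 2.]]))
```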