Removing identity backtracking from entropy losses.
PiperOrigin-RevId: 316956157 Change-Id: I91130052e29e69ae131fe8aad0bbd1d4d42b00f1
This commit is contained in:
parent
67dd8f02fe
commit
56e71dd0e7
@ -4637,12 +4637,6 @@ def softsign(x):
|
||||
return nn.softsign(x)
|
||||
|
||||
|
||||
def _backtrack_identity(tensor):
|
||||
while tensor.op.type == 'Identity':
|
||||
tensor = tensor.op.inputs[0]
|
||||
return tensor
|
||||
|
||||
|
||||
@keras_export('keras.backend.categorical_crossentropy')
|
||||
@dispatch.add_dispatch_support
|
||||
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
||||
@ -4695,17 +4689,16 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
||||
return nn.softmax_cross_entropy_with_logits_v2(
|
||||
labels=target, logits=output, axis=axis)
|
||||
|
||||
if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
|
||||
output = _backtrack_identity(output)
|
||||
if output.op.type == 'Softmax':
|
||||
# When softmax activation function is used for output operation, we
|
||||
# use logits from the softmax function directly to compute loss in order
|
||||
# to prevent collapsing zero when training.
|
||||
# See b/117284466
|
||||
assert len(output.op.inputs) == 1
|
||||
output = output.op.inputs[0]
|
||||
return nn.softmax_cross_entropy_with_logits_v2(
|
||||
labels=target, logits=output, axis=axis)
|
||||
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
||||
output.op.type == 'Softmax'):
|
||||
# When softmax activation function is used for output operation, we
|
||||
# use logits from the softmax function directly to compute loss in order
|
||||
# to prevent collapsing zero when training.
|
||||
# See b/117284466
|
||||
assert len(output.op.inputs) == 1
|
||||
output = output.op.inputs[0]
|
||||
return nn.softmax_cross_entropy_with_logits_v2(
|
||||
labels=target, logits=output, axis=axis)
|
||||
|
||||
# scale preds so that the class probas of each sample sum to 1
|
||||
output = output / math_ops.reduce_sum(output, axis, True)
|
||||
@ -4740,17 +4733,16 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
||||
target = ops.convert_to_tensor_v2(target)
|
||||
output = ops.convert_to_tensor_v2(output)
|
||||
|
||||
if not from_logits and not isinstance(
|
||||
output, (ops.EagerTensor, variables_module.Variable)):
|
||||
output = _backtrack_identity(output)
|
||||
if output.op.type == 'Softmax':
|
||||
# When softmax activation function is used for output operation, we
|
||||
# use logits from the softmax function directly to compute loss in order
|
||||
# to prevent collapsing zero when training.
|
||||
# See b/117284466
|
||||
assert len(output.op.inputs) == 1
|
||||
output = output.op.inputs[0]
|
||||
from_logits = True
|
||||
if (not from_logits and
|
||||
not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
||||
output.op.type == 'Softmax'):
|
||||
# When softmax activation function is used for output operation, we
|
||||
# use logits from the softmax function directly to compute loss in order
|
||||
# to prevent collapsing zero when training.
|
||||
# See b/117284466
|
||||
assert len(output.op.inputs) == 1
|
||||
output = output.op.inputs[0]
|
||||
from_logits = True
|
||||
|
||||
if not from_logits:
|
||||
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
|
||||
@ -4821,15 +4813,14 @@ def binary_crossentropy(target, output, from_logits=False):
|
||||
if from_logits:
|
||||
return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
|
||||
|
||||
if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
|
||||
output = _backtrack_identity(output)
|
||||
if output.op.type == 'Sigmoid':
|
||||
# When sigmoid activation function is used for output operation, we
|
||||
# use logits from the sigmoid function directly to compute loss in order
|
||||
# to prevent collapsing zero when training.
|
||||
assert len(output.op.inputs) == 1
|
||||
output = output.op.inputs[0]
|
||||
return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
|
||||
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
||||
output.op.type == 'Sigmoid'):
|
||||
# When sigmoid activation function is used for output operation, we
|
||||
# use logits from the sigmoid function directly to compute loss in order
|
||||
# to prevent collapsing zero when training.
|
||||
assert len(output.op.inputs) == 1
|
||||
output = output.op.inputs[0]
|
||||
return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
|
||||
|
||||
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
|
||||
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
|
||||
MAE = losses.MeanAbsoluteError
|
||||
@ -450,6 +451,19 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
|
||||
'Expected a symbolic Tensors or a callable for the loss value'):
|
||||
model.add_loss(model.weights[0])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_add_entropy_loss_on_functional_model(self):
|
||||
inputs = Input(shape=(1,))
|
||||
targets = Input(shape=(1,))
|
||||
outputs = testing_utils.Bias()(inputs)
|
||||
model = Model([inputs, targets], outputs)
|
||||
model.add_loss(losses.binary_crossentropy(targets, outputs))
|
||||
model.compile('sgd', run_eagerly=testing_utils.should_run_eagerly())
|
||||
with test.mock.patch.object(logging, 'warning') as mock_log:
|
||||
model.fit([self.x, self.y], batch_size=3, epochs=5)
|
||||
self.assertNotIn('Gradients do not exist for variables',
|
||||
str(mock_log.call_args))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user