Removing identity backtracking from entropy losses.
PiperOrigin-RevId: 316956157
Change-Id: I91130052e29e69ae131fe8aad0bbd1d4d42b00f1

Parent: 67dd8f02fe
Commit: 56e71dd0e7
@@ -4637,12 +4637,6 @@ def softsign(x):
   return nn.softsign(x)
 
 
-def _backtrack_identity(tensor):
-  while tensor.op.type == 'Identity':
-    tensor = tensor.op.inputs[0]
-  return tensor
-
-
 @keras_export('keras.backend.categorical_crossentropy')
 @dispatch.add_dispatch_support
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
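For orientation, here is a minimal standalone sketch (not TensorFlow source) of what the removed _backtrack_identity helper did: it walked past Identity ops so that a trailing tf.identity could not hide the op that actually produced a tensor. Graph mode is assumed, since eager tensors carry no usable .op metadata.

import tensorflow as tf

with tf.Graph().as_default():
  logits = tf.compat.v1.placeholder(tf.float32, shape=(None, 3))
  probs = tf.nn.softmax(logits)
  wrapped = tf.identity(tf.identity(probs))

  def backtrack_identity(tensor):
    # Follow Identity ops back to their single input until a non-Identity op.
    while tensor.op.type == 'Identity':
      tensor = tensor.op.inputs[0]
    return tensor

  print(wrapped.op.type)                      # Identity
  print(backtrack_identity(wrapped).op.type)  # Softmax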
@@ -4695,17 +4689,16 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.softmax_cross_entropy_with_logits_v2(
-          labels=target, logits=output, axis=axis)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When softmax activation function is used for output operation, we
+    # use logits from the softmax function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.softmax_cross_entropy_with_logits_v2(
+        labels=target, logits=output, axis=axis)
 
   # scale preds so that the class probas of each sample sum to 1
   output = output / math_ops.reduce_sum(output, axis, True)
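The comment kept in this hunk ("to prevent collapsing zero when training", b/117284466) refers to numerical stability: if a softmax probability underflows to zero, its log is -inf, while computing the same loss from the logits stays finite. A rough NumPy illustration with made-up numbers:

import numpy as np

logits = np.array([[100.0, 0.0, -100.0]], dtype=np.float32)
target = np.array([[0.0, 0.0, 1.0]], dtype=np.float32)

shifted = logits - logits.max()
probs = np.exp(shifted) / np.exp(shifted).sum()      # last entry underflows to 0 in float32

with np.errstate(divide='ignore'):
  naive = -(target * np.log(probs)).sum()            # inf, because of log(0)

log_probs = shifted - np.log(np.exp(shifted).sum())  # log-softmax computed from logits
stable = -(target * log_probs).sum()                 # ~200.0, finite
print(naive, stable)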
@@ -4740,17 +4733,16 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   target = ops.convert_to_tensor_v2(target)
   output = ops.convert_to_tensor_v2(output)
 
-  if not from_logits and not isinstance(
-      output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      from_logits = True
+  if (not from_logits and
+      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When softmax activation function is used for output operation, we
+    # use logits from the softmax function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    from_logits = True
 
   if not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
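For reference, a small usage sketch of the two backend functions touched above (example values are made up): categorical_crossentropy expects one-hot targets, sparse_categorical_crossentropy expects integer class indices, and both treat output as probabilities unless from_logits=True.

import tensorflow as tf

probs = tf.constant([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1]])
one_hot = tf.constant([[1., 0., 0.],
                       [0., 1., 0.]])
labels = tf.constant([0, 1])

print(tf.keras.backend.categorical_crossentropy(one_hot, probs))        # ~[0.357, 0.223]
print(tf.keras.backend.sparse_categorical_crossentropy(labels, probs))  # same values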
@@ -4821,15 +4813,14 @@ def binary_crossentropy(target, output, from_logits=False):
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Sigmoid':
-      # When sigmoid activation function is used for output operation, we
-      # use logits from the sigmoid function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Sigmoid'):
+    # When sigmoid activation function is used for output operation, we
+    # use logits from the sigmoid function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
   output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
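The observable effect of the rewritten condition: only a tensor produced directly by a Sigmoid op takes the logits shortcut now, while a tf.identity-wrapped sigmoid output falls through to the clipped-probability path. A graph-mode sketch of just that check (the EagerTensor/Variable guard is omitted):

import tensorflow as tf

with tf.Graph().as_default():
  logits = tf.compat.v1.placeholder(tf.float32, shape=(None, 1))
  direct = tf.sigmoid(logits)
  wrapped = tf.identity(direct)

  def takes_logits_shortcut(output):
    # Mirrors the new check in binary_crossentropy, simplified.
    return output.op.type == 'Sigmoid'

  print(takes_logits_shortcut(direct))   # True
  print(takes_logits_shortcut(wrapped))  # False: Identity wrappers are no longer unwrapped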
@@ -34,6 +34,7 @@ from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 MAE = losses.MeanAbsoluteError
@@ -450,6 +451,19 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
         'Expected a symbolic Tensors or a callable for the loss value'):
       model.add_loss(model.weights[0])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_add_entropy_loss_on_functional_model(self):
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    outputs = testing_utils.Bias()(inputs)
+    model = Model([inputs, targets], outputs)
+    model.add_loss(losses.binary_crossentropy(targets, outputs))
+    model.compile('sgd', run_eagerly=testing_utils.should_run_eagerly())
+    with test.mock.patch.object(logging, 'warning') as mock_log:
+      model.fit([self.x, self.y], batch_size=3, epochs=5)
+      self.assertNotIn('Gradients do not exist for variables',
+                       str(mock_log.call_args))
+
 
 if __name__ == '__main__':
   test.main()
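The new test leans on testing_utils.Bias, a Keras test helper; a layer along the following lines (an assumption about that helper, not its actual source) is enough to reproduce the scenario: a single trainable weight whose gradient must flow through the crossentropy loss added via add_loss.

import tensorflow as tf

class Bias(tf.keras.layers.Layer):
  # Adds one trainable bias to the input; the test asserts that no
  # "Gradients do not exist for variables" warning is logged for it.

  def build(self, input_shape):
    self.bias = self.add_weight('bias', shape=(1,), initializer='zeros')

  def call(self, inputs):
    return inputs + self.bias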