Removing identity backtracking from entropy losses.

PiperOrigin-RevId: 316956157
Change-Id: I91130052e29e69ae131fe8aad0bbd1d4d42b00f1
Pavithra Vijay 2020-06-17 13:42:03 -07:00 committed by TensorFlower Gardener
parent 67dd8f02fe
commit 56e71dd0e7
2 changed files with 42 additions and 37 deletions
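For context, here is a small graph-mode sketch of the behavioral difference this commit introduces. The tensors and values are illustrative, not part of the commit; the path descriptions in the comments follow the hunks below. After the change, only an output produced directly by a Softmax (or Sigmoid) op is unwrapped to its logits; an Identity-wrapped activation is no longer backtracked and falls through to the clipped-probabilities path.

import tensorflow as tf
from tensorflow.keras import backend as K

# Build in graph mode so `output.op` is inspectable; eager tensors skip the
# op-type check entirely.
with tf.Graph().as_default():
  target = tf.constant([[1.0, 0.0, 0.0]])
  logits = tf.constant([[2.0, 1.0, 0.1]])

  probs = tf.nn.softmax(logits)   # probs.op.type == 'Softmax'
  wrapped = tf.identity(probs)    # wrapped.op.type == 'Identity'

  # Direct Softmax output: still rewritten to the fused logits-based op.
  loss_direct = K.categorical_crossentropy(target, probs)

  # Identity-wrapped output: before this commit it was backtracked to the
  # Softmax; now it goes through reduce_sum + clip + log on the probabilities.
  # Numerically both losses agree; only the internal graph differs.
  loss_wrapped = K.categorical_crossentropy(target, wrapped)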


@@ -4637,12 +4637,6 @@ def softsign(x):
   return nn.softsign(x)
 
 
-def _backtrack_identity(tensor):
-  while tensor.op.type == 'Identity':
-    tensor = tensor.op.inputs[0]
-  return tensor
-
-
 @keras_export('keras.backend.categorical_crossentropy')
 @dispatch.add_dispatch_support
 def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@@ -4695,17 +4689,16 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.softmax_cross_entropy_with_logits_v2(
-          labels=target, logits=output, axis=axis)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When softmax activation function is used for output operation, we
+    # use logits from the softmax function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.softmax_cross_entropy_with_logits_v2(
+        labels=target, logits=output, axis=axis)
 
   # scale preds so that the class probas of each sample sum to 1
   output = output / math_ops.reduce_sum(output, axis, True)
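A quick illustration of the fast path this hunk preserves (example values are mine, not from the commit): when the output tensor is produced by a Softmax op in a graph, the backend recovers the logits and calls the fused nn.softmax_cross_entropy_with_logits_v2; either way, passing probabilities or passing logits with from_logits=True should agree numerically.

import tensorflow as tf
from tensorflow.keras import backend as K

target = tf.constant([[0.0, 1.0, 0.0]])
logits = tf.constant([[0.5, 2.0, 0.3]])

loss_from_probs = K.categorical_crossentropy(target, tf.nn.softmax(logits))
loss_from_logits = K.categorical_crossentropy(target, logits, from_logits=True)
# Both are ~ -log(softmax(logits)[1]); the logits path skips the clip/log step.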
@@ -4740,17 +4733,16 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   target = ops.convert_to_tensor_v2(target)
   output = ops.convert_to_tensor_v2(output)
 
-  if not from_logits and not isinstance(
-      output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Softmax':
-      # When softmax activation function is used for output operation, we
-      # use logits from the softmax function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      # See b/117284466
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      from_logits = True
+  if (not from_logits and
+      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Softmax'):
+    # When softmax activation function is used for output operation, we
+    # use logits from the softmax function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    # See b/117284466
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    from_logits = True
 
   if not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
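The sparse variant in this hunk does the same unwrapping but sets from_logits = True and continues instead of returning early. A minimal usage sketch (illustrative values, not from the commit):

import tensorflow as tf
from tensorflow.keras import backend as K

target = tf.constant([2, 0])          # integer class ids, not one-hot vectors
logits = tf.constant([[0.1, 0.2, 3.0],
                      [2.5, 0.3, 0.2]])

loss_a = K.sparse_categorical_crossentropy(target, tf.nn.softmax(logits))
loss_b = K.sparse_categorical_crossentropy(target, logits, from_logits=True)
# loss_a and loss_b should be (almost) identical, one scalar per row.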
@@ -4821,15 +4813,14 @@ def binary_crossentropy(target, output, from_logits=False):
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
-  if not isinstance(output, (ops.EagerTensor, variables_module.Variable)):
-    output = _backtrack_identity(output)
-    if output.op.type == 'Sigmoid':
-      # When sigmoid activation function is used for output operation, we
-      # use logits from the sigmoid function directly to compute loss in order
-      # to prevent collapsing zero when training.
-      assert len(output.op.inputs) == 1
-      output = output.op.inputs[0]
-      return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
+  if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+      output.op.type == 'Sigmoid'):
+    # When sigmoid activation function is used for output operation, we
+    # use logits from the sigmoid function directly to compute loss in order
+    # to prevent collapsing zero when training.
+    assert len(output.op.inputs) == 1
+    output = output.op.inputs[0]
+    return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
   epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
   output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
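And the binary counterpart touched by this hunk, with a sigmoid output (again an illustrative sketch, not from the commit):

import tensorflow as tf
from tensorflow.keras import backend as K

target = tf.constant([[1.0], [0.0]])
logits = tf.constant([[1.2], [-0.7]])

loss_a = K.binary_crossentropy(target, tf.sigmoid(logits))
loss_b = K.binary_crossentropy(target, logits, from_logits=True)
# Both compute -[y*log(p) + (1-y)*log(1-p)] elementwise; the logits variant
# uses the fused nn.sigmoid_cross_entropy_with_logits op directly.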


@@ -34,6 +34,7 @@ from tensorflow.python.keras import testing_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 MAE = losses.MeanAbsoluteError
@@ -450,6 +451,19 @@ class TestAddLossCorrectness(keras_parameterized.TestCase):
         'Expected a symbolic Tensors or a callable for the loss value'):
       model.add_loss(model.weights[0])
 
+  @keras_parameterized.run_all_keras_modes
+  def test_add_entropy_loss_on_functional_model(self):
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    outputs = testing_utils.Bias()(inputs)
+    model = Model([inputs, targets], outputs)
+    model.add_loss(losses.binary_crossentropy(targets, outputs))
+    model.compile('sgd', run_eagerly=testing_utils.should_run_eagerly())
+    with test.mock.patch.object(logging, 'warning') as mock_log:
+      model.fit([self.x, self.y], batch_size=3, epochs=5)
+      self.assertNotIn('Gradients do not exist for variables',
+                       str(mock_log.call_args))
+
 
 if __name__ == '__main__':
   test.main()
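The new test adds a symbolic entropy loss to a functional model via add_loss and asserts that fitting does not log a 'Gradients do not exist for variables' warning. Roughly the same scenario with the public API would look like the sketch below; the Dense layer and the toy arrays are stand-ins for the internal testing_utils.Bias layer and the fixture's self.x / self.y, and are assumptions, not part of the commit.

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(1,))
targets = tf.keras.Input(shape=(1,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model([inputs, targets], outputs)

# Add the entropy loss as a symbolic tensor, mirroring the new test case.
model.add_loss(tf.keras.losses.binary_crossentropy(targets, outputs))
model.compile('sgd')

x = np.array([[0.], [1.], [2.]], dtype='float32')
y = np.array([[0.], [1.], [1.]], dtype='float32')
model.fit([x, y], batch_size=3, epochs=1)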