Make the recovery of logits from softmax and sigmoid activations more robust.

Logits are now cached on the post-activation Tensors and used by the crossentropy
losses whenever they are available. This allows logits to be recovered in eager mode
as well as inside a tf.function, whereas previously they could only be recovered
inside a tf.function.
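
For concreteness, a minimal usage sketch (not part of this change) of the eager path
this enables, written against the public tf.keras API; `_keras_logits` is the private
cache added here:

import tensorflow as tf

logits = tf.constant([[-30., 30.]])
y_true = tf.constant([[0., 1.]])

# The activation now stashes its input on the returned Tensor as
# `_keras_logits`, in eager mode as well as inside a tf.function.
y_pred = tf.keras.activations.sigmoid(logits)

# binary_crossentropy detects the cached logits and computes the loss from
# them directly, instead of from the (possibly saturated) probabilities.
loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)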

The only case this mechanism does not cover (and in which the existing fallback
mechanism is used instead) is when the activation and the loss function are separated
by a tf.function boundary.
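
A hypothetical sketch of that boundary case (the Python-level attribute does not
survive the return from the traced function, so the loss falls back to the
pre-existing behavior):

import tensorflow as tf

@tf.function
def predict(logits):
  # The cache is set on the symbolic Tensor inside the traced function,
  # not on the concrete value returned to the eager caller.
  return tf.keras.activations.softmax(logits)

y_pred = predict(tf.constant([[-5., 0., 5.]]))  # plain EagerTensor again
y_true = tf.constant([[0., 0., 1.]])
# No `_keras_logits` attribute here, so the existing fallback is used.
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)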

Also fixes the Loss class to only autograph the loss function when not executing
eagerly. This allows Python control flow to work correctly in the loss function when
running eagerly.
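
For example, a hypothetical custom Loss whose call branches on a concrete tensor
value now runs its Python `if` natively under eager execution, because `call` is no
longer routed through autograph in that case:

import tensorflow as tf

class ThresholdedError(tf.keras.losses.Loss):
  """Hypothetical loss: squares the error only once it grows large."""

  def call(self, y_true, y_pred):
    error = tf.abs(y_true - y_pred)
    # Plain Python control flow. Under eager execution `error` is a concrete
    # value, so this branch is evaluated directly instead of being converted
    # by autograph.
    if tf.reduce_mean(error) > 1.0:
      return tf.square(error)
    return error

loss_fn = ThresholdedError()
print(loss_fn(tf.constant([[0., 0.]]), tf.constant([[2., 3.]])))  # squared branch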

PiperOrigin-RevId: 333367543
Change-Id: I64454b9acb6a77247122d72e1302aa5110415d57
Author: Thomas O'Malley (2020-09-23 14:00:30 -07:00), committed by TensorFlower Gardener
Parent: 7f5b8ad370
Commit: 40ccaf67ca
4 changed files with 96 additions and 14 deletions

tensorflow/python/keras/activations.py

@@ -71,17 +71,21 @@ def softmax(x, axis=-1):
   Raises:
       ValueError: In case `dim(x) == 1`.
   """
-  ndim = K.ndim(x)
-  if ndim == 2:
-    return nn.softmax(x)
-  elif ndim > 2:
+  rank = x.shape.rank
+  if rank == 2:
+    output = nn.softmax(x)
+  elif rank > 2:
     e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
     s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
-    return e / s
+    output = e / s
   else:
     raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                      'Received input: %s' % (x,))
+  # Cache the logits to use for crossentropy loss.
+  output._keras_logits = x  # pylint: disable=protected-access
+  return output


 @keras_export('keras.activations.elu')
 @dispatch.add_dispatch_support
@@ -391,7 +395,10 @@ def sigmoid(x):
   Returns:
       Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
   """
-  return nn.sigmoid(x)
+  output = nn.sigmoid(x)
+  # Cache the logits to use for crossentropy loss.
+  output._keras_logits = x  # pylint: disable=protected-access
+  return output


 @keras_export('keras.activations.exponential')

tensorflow/python/keras/backend.py

@@ -4801,8 +4801,14 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   """
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
   target.shape.assert_is_compatible_with(output.shape)
+
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+
   if from_logits:
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
@@ -4852,9 +4858,14 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
-  if (not from_logits and
-      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
+
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+  elif (not from_logits and
+        not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+        output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4862,8 +4873,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
     assert len(output.op.inputs) == 1
     output = output.op.inputs[0]
     from_logits = True
-
-  if not from_logits:
+  elif not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
     output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
     output = math_ops.log(output)
@@ -4930,6 +4940,12 @@ def binary_crossentropy(target, output, from_logits=False):
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
+
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)

tensorflow/python/keras/losses.py

@@ -25,6 +25,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import tensor_util
@@ -144,8 +145,11 @@ class Loss(object):
     graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
         y_true, y_pred, sample_weight)
     with K.name_scope(self._name_scope), graph_ctx:
-      ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
-      losses = ag_call(y_true, y_pred)
+      if context.executing_eagerly():
+        call_fn = self.call
+      else:
+        call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
+      losses = call_fn(y_true, y_pred)
       return losses_utils.compute_weighted_loss(
           losses, sample_weight, reduction=self._get_reduction())

tensorflow/python/keras/losses_test.py

@@ -21,11 +21,13 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import losses
@@ -245,6 +247,59 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegex(ValueError, 'Could not interpret loss'):
       losses.get(0)

+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_binary_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-30., 30.]])
+    y_pred = activations.sigmoid(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([[0., 1.]])
+    loss = losses.binary_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_categorical_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-5., 0., 5.]])
+    y_pred = activations.softmax(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([[0., 0., 1.]])
+    loss = losses.categorical_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_sparse_categorical_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-5., 0., 5.]])
+    y_pred = activations.softmax(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([2])
+    loss = losses.sparse_categorical_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_loss_not_autographed_in_eager(self):
+
+    class MyLoss(losses.Loss):
+
+      def call(self, y_true, y_pred):
+        return y_true - y_pred
+
+    loss = MyLoss()
+    y_true = constant_op.constant([[0., 0., 0.]])
+    y_pred = constant_op.constant([[1., 1., 1.]])
+
+    def tf_convert(fn, _):
+      assert False, 'Function should not be autographed.'
+      return fn
+
+    with test.mock.patch.object(autograph, 'tf_convert', tf_convert):
+      loss(y_true, y_pred)
+

 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanSquaredErrorTest(test.TestCase):