Make softmax and sigmoid activation backtracking more robust.
Logits are now cached on the post-activation Tensors and used whenever they are available. This lets the crossentropy losses recover the logits in eager mode as well as in the middle of a tf.function; previously they could only be recovered in the middle of a tf.function. The only case this mechanism does not cover (where the existing fallback mechanism is used instead) is when the activation and the loss function are separated by a tf.function boundary.

Also fixes the Loss class to autograph the loss function only when not executing eagerly, which allows Python control flow to work correctly in the loss function in eager mode.

PiperOrigin-RevId: 333367543
Change-Id: I64454b9acb6a77247122d72e1302aa5110415d57
commit 40ccaf67ca (parent 7f5b8ad370)
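For context on why the losses should see logits rather than saturated probabilities, here is a minimal standalone sketch (not part of this change; it only assumes the public `tf.sigmoid` and `tf.nn.sigmoid_cross_entropy_with_logits` APIs):

```python
import tensorflow as tf

logit = tf.constant(30.)

# Via the saturated probability: sigmoid(30.) rounds to exactly 1.0 in
# float32, so the "correct" crossentropy term collapses to 0 and the
# "wrong" term overflows to inf.
p = tf.sigmoid(logit)
loss_correct = -tf.math.log(p)       # 0.0  (true value is ~9e-14)
loss_wrong = -tf.math.log(1. - p)    # inf  (true value is ~30)

# Via the logit: the fused op stays finite and keeps the tiny residual.
stable_correct = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.constant(1.), logits=logit)
stable_wrong = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.constant(0.), logits=logit)
```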
@@ -71,17 +71,21 @@ def softmax(x, axis=-1):
   Raises:
       ValueError: In case `dim(x) == 1`.
   """
-  ndim = K.ndim(x)
-  if ndim == 2:
-    return nn.softmax(x)
-  elif ndim > 2:
+  rank = x.shape.rank
+  if rank == 2:
+    output = nn.softmax(x)
+  elif rank > 2:
     e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
     s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
-    return e / s
+    output = e / s
   else:
     raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                      'Received input: %s' % (x,))
+
+  # Cache the logits to use for crossentropy loss.
+  output._keras_logits = x  # pylint: disable=protected-access
+  return output
 
 
 @keras_export('keras.activations.elu')
 @dispatch.add_dispatch_support
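As a hedged illustration of what the cached attribute looks like from user code (mirroring the tests added at the bottom of this change; `_keras_logits` is a private name and may change):

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.keras import activations

logits = constant_op.constant([[-5., 0., 5.]])
y_pred = activations.softmax(logits)

# The pre-activation tensor is now reachable from the activation output,
# so the crossentropy backends can transparently switch to the stable
# logits-based computation.
assert hasattr(y_pred, '_keras_logits')
```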
@@ -391,7 +395,10 @@ def sigmoid(x):
   Returns:
       Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
   """
-  return nn.sigmoid(x)
+  output = nn.sigmoid(x)
+  # Cache the logits to use for crossentropy loss.
+  output._keras_logits = x  # pylint: disable=protected-access
+  return output
 
 
 @keras_export('keras.activations.exponential')
@@ -4801,8 +4801,14 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   """
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
   target.shape.assert_is_compatible_with(output.shape)
 
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+
   if from_logits:
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
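A sketch of the eager-mode path this enables (hedged: it reuses the internal modules imported by the tests below; before this change, the eager path would have clipped and logged the probabilities instead):

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend

logits = constant_op.constant([[-5., 0., 5.]])
y_true = constant_op.constant([[0., 0., 1.]])

# The softmax output carries `_keras_logits`, so `categorical_crossentropy`
# silently switches to the logits-based form even though `from_logits`
# was left at its default of False.
y_pred = activations.softmax(logits)
loss = backend.categorical_crossentropy(y_true, y_pred)
```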
@@ -4852,9 +4858,14 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
 
-  if (not from_logits and
-      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+  elif (not from_logits and
+        not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+        output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4862,8 +4873,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
     assert len(output.op.inputs) == 1
     output = output.op.inputs[0]
     from_logits = True
-
-  if not from_logits:
+  elif not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
     output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
     output = math_ops.log(output)
@@ -4930,6 +4940,12 @@ def binary_crossentropy(target, output, from_logits=False):
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
 
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
@@ -25,6 +25,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import tensor_util
@@ -144,8 +145,11 @@ class Loss(object):
     graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
         y_true, y_pred, sample_weight)
     with K.name_scope(self._name_scope), graph_ctx:
-      ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
-      losses = ag_call(y_true, y_pred)
+      if context.executing_eagerly():
+        call_fn = self.call
+      else:
+        call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
+      losses = call_fn(y_true, y_pred)
       return losses_utils.compute_weighted_loss(
           losses, sample_weight, reduction=self._get_reduction())
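A hedged, standalone sketch (not part of the commit; class name is made up) of what skipping autograph in eager mode allows: data-dependent Python control flow in a custom `Loss.call`.

```python
import tensorflow as tf


# A toy custom loss that uses a plain Python `if` on a concrete tensor
# value. With this change, `call` is no longer autograph-converted when
# executing eagerly, so it runs as ordinary Python; in graph mode it is
# still converted as before.
class ClippedMAE(tf.keras.losses.Loss):

  def call(self, y_true, y_pred):
    error = tf.abs(y_true - y_pred)
    if tf.reduce_mean(error) > 1.0:  # evaluated eagerly as a Python branch
      return tf.minimum(error, 1.0)
    return error


loss_fn = ClippedMAE()
value = loss_fn(tf.constant([[0., 0.]]), tf.constant([[3., 3.]]))  # -> 1.0
```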
@@ -21,11 +21,13 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import losses
@@ -245,6 +247,59 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegex(ValueError, 'Could not interpret loss'):
       losses.get(0)
 
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_binary_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-30., 30.]])
+    y_pred = activations.sigmoid(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([[0., 1.]])
+    loss = losses.binary_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_categorical_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-5., 0., 5.]])
+    y_pred = activations.softmax(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([[0., 0., 1.]])
+    loss = losses.categorical_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_sparse_categorical_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-5., 0., 5.]])
+    y_pred = activations.softmax(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([2])
+    loss = losses.sparse_categorical_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_loss_not_autographed_in_eager(self):
+
+    class MyLoss(losses.Loss):
+
+      def call(self, y_true, y_pred):
+        return y_true - y_pred
+
+    loss = MyLoss()
+    y_true = constant_op.constant([[0., 0., 0.]])
+    y_pred = constant_op.constant([[1., 1., 1.]])
+
+    def tf_convert(fn, _):
+      assert False, 'Function should not be autographed.'
+      return fn
+
+    with test.mock.patch.object(autograph, 'tf_convert', tf_convert):
+      loss(y_true, y_pred)
+
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanSquaredErrorTest(test.TestCase):