From 40ccaf67cafe41d1035a24b844254a00f13c39de Mon Sep 17 00:00:00 2001
From: Thomas O'Malley
Date: Wed, 23 Sep 2020 14:00:30 -0700
Subject: [PATCH] Make softmax and sigmoid activation backtracking more robust.

Logits are now cached on the post-activation Tensors and used whenever
available. This allows logits to be used in eager mode as well as in the
middle of a tf.function, whereas previously they could only be used in the
middle of a tf.function. The only case this mechanism does not cover (where
the existing fallback mechanism is used instead) is when the activation and
the loss function are separated by a tf.function boundary.

Also fixes the Loss class to only autograph the loss function when not
executing eagerly. This allows Python control flow to work correctly in the
loss function.

Two short usage sketches illustrating both behaviors follow the patch below.

PiperOrigin-RevId: 333367543
Change-Id: I64454b9acb6a77247122d72e1302aa5110415d57
---
 tensorflow/python/keras/activations.py | 19 ++++++---
 tensorflow/python/keras/backend.py     | 28 ++++++++++---
 tensorflow/python/keras/losses.py      |  8 +++-
 tensorflow/python/keras/losses_test.py | 55 ++++++++++++++++++++++++++
 4 files changed, 96 insertions(+), 14 deletions(-)

diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 119851f4e13..4f1ef96c8ef 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -71,17 +71,21 @@ def softmax(x, axis=-1):
   Raises:
       ValueError: In case `dim(x) == 1`.
   """
-  ndim = K.ndim(x)
-  if ndim == 2:
-    return nn.softmax(x)
-  elif ndim > 2:
+  rank = x.shape.rank
+  if rank == 2:
+    output = nn.softmax(x)
+  elif rank > 2:
     e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
     s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
-    return e / s
+    output = e / s
   else:
     raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                      'Received input: %s' % (x,))
+  # Cache the logits to use for crossentropy loss.
+  output._keras_logits = x  # pylint: disable=protected-access
+  return output
+
+
 @keras_export('keras.activations.elu')
 @dispatch.add_dispatch_support
@@ -391,7 +395,10 @@ def sigmoid(x):
   Returns:
       Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
   """
-  return nn.sigmoid(x)
+  output = nn.sigmoid(x)
+  # Cache the logits to use for crossentropy loss.
+  output._keras_logits = x  # pylint: disable=protected-access
+  return output
 
 
 @keras_export('keras.activations.exponential')
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7766a735fe6..7bab18084dd 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -4801,8 +4801,14 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
   """
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
-  target.shape.assert_is_compatible_with(output.shape)
+
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+
   if from_logits:
     return nn.softmax_cross_entropy_with_logits_v2(
         labels=target, logits=output, axis=axis)
@@ -4852,9 +4858,14 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
 
-  if (not from_logits and
-      not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
-      output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+  elif (not from_logits and
+        not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
+        output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
     # When softmax activation function is used for output operation, we
     # use logits from the softmax function directly to compute loss in order
     # to prevent collapsing zero when training.
@@ -4862,8 +4873,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
     assert len(output.op.inputs) == 1
     output = output.op.inputs[0]
     from_logits = True
-
-  if not from_logits:
+  elif not from_logits:
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
     output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
     output = math_ops.log(output)
@@ -4930,6 +4940,12 @@ def binary_crossentropy(target, output, from_logits=False):
   target = ops.convert_to_tensor_v2_with_dispatch(target)
   output = ops.convert_to_tensor_v2_with_dispatch(output)
+
+  # Use logits whenever they are available. `softmax` and `sigmoid`
+  # activations cache logits on the `output` Tensor.
+  if hasattr(output, '_keras_logits'):
+    output = output._keras_logits  # pylint: disable=protected-access
+    from_logits = True
+
   if from_logits:
     return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 6b74121cf80..c66bc55a9a2 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -25,6 +25,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import tensor_util
@@ -144,8 +145,11 @@ class Loss(object):
     graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
         y_true, y_pred, sample_weight)
     with K.name_scope(self._name_scope), graph_ctx:
-      ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
-      losses = ag_call(y_true, y_pred)
+      if context.executing_eagerly():
+        call_fn = self.call
+      else:
+        call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
+      losses = call_fn(y_true, y_pred)
       return losses_utils.compute_weighted_loss(
           losses, sample_weight, reduction=self._get_reduction())
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 4de49e69829..7cbd7b18f70 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -21,11 +21,13 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import losses
@@ -245,6 +247,59 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegex(ValueError, 'Could not interpret loss'):
       losses.get(0)
 
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_binary_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-30., 30.]])
+    y_pred = activations.sigmoid(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([[0., 1.]])
+    loss = losses.binary_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_categorical_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-5., 0., 5.]])
+    y_pred = activations.softmax(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([[0., 0., 1.]])
+    loss = losses.categorical_crossentropy(y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def test_sparse_categorical_crossentropy_uses_cached_logits(self):
+    logits = constant_op.constant([[-5., 0., 5.]])
+    y_pred = activations.softmax(logits)
+    self.assertTrue(hasattr(y_pred, '_keras_logits'))
+    y_true = constant_op.constant([2])
+    loss = losses.sparse_categorical_crossentropy(
+        y_true, y_pred)[0]
+    # Check that logits are used. If y_pred is used directly, loss will
+    # collapse to 0 from underflow.
+    self.assertNotEqual(self.evaluate(loss), 0.)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_loss_not_autographed_in_eager(self):
+
+    class MyLoss(losses.Loss):
+
+      def call(self, y_true, y_pred):
+        return y_true - y_pred
+
+    loss = MyLoss()
+    y_true = constant_op.constant([[0., 0., 0.]])
+    y_pred = constant_op.constant([[1., 1., 1.]])
+
+    def tf_convert(fn, _):
+      assert False, 'Function should not be autographed.'
+      return fn
+
+    with test.mock.patch.object(autograph, 'tf_convert', tf_convert):
+      loss(y_true, y_pred)
+
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class MeanSquaredErrorTest(test.TestCase):
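
Usage sketch 1 (not part of the diff above): the logits-caching behavior, mirroring the
new losses_test.py cases but through public `tf.keras` entry points. This assumes
TensorFlow 2.x with the patch applied; the constants are illustrative only.

    import tensorflow as tf

    logits = tf.constant([[-30., 30.]])
    y_pred = tf.keras.activations.sigmoid(logits)  # probabilities saturate toward [[0., 1.]]
    print(hasattr(y_pred, '_keras_logits'))        # True: the activation cached the logits

    y_true = tf.constant([[0., 1.]])
    loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    # binary_crossentropy finds `_keras_logits` and computes the loss from the
    # logits rather than the saturated probabilities, so it does not collapse
    # to exactly 0 even though y_pred has underflowed. This works in eager mode,
    # not only inside a tf.function, which is the point of the change.
    print(float(loss[0]))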
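
Usage sketch 2 (also not part of the diff): the Loss autograph change. When executing
eagerly, `Loss.__call__` now invokes `call` directly rather than an autograph-converted
wrapper, so plain Python control flow inside `call` behaves like any other eager code.
The `ThresholdedMAE` class and its 10.0 threshold are hypothetical, used only for
illustration.

    import tensorflow as tf

    class ThresholdedMAE(tf.keras.losses.Loss):
      """Hypothetical loss illustrating Python control flow in `call`."""

      def call(self, y_true, y_pred):
        err = tf.abs(y_true - y_pred)
        # In eager mode this `if` runs as ordinary Python; `call` is not autographed.
        if float(tf.reduce_max(err)) > 10.0:
          err = tf.minimum(err, 10.0)
        return err

    loss_fn = ThresholdedMAE()
    # Default reduction averages the clipped per-element error: prints 10.0.
    print(float(loss_fn(tf.constant([[0.]]), tf.constant([[25.]]))))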