From ea63690ffa3d0a90eee25f4458d2b7f5588d07cc Mon Sep 17 00:00:00 2001
From: Pedro Marques
Date: Wed, 10 Feb 2021 21:50:57 +0100
Subject: [PATCH] Implement support for RaggedTensors in binary_crossentropy
 loss.

---
 tensorflow/python/keras/losses.py      | 18 ++++++++++++++++
 tensorflow/python/keras/losses_test.py | 30 ++++++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 4b2105f7ab0..1ad806d653b 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -1728,6 +1728,24 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
   return K.mean(
       K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
 
+
+@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
+def _ragged_tensor_binary_crossentropy(y_true, y_pred,
+                                       from_logits=False, label_smoothing=0):
+  """Implements support for handling RaggedTensors.
+
+  Expected shape: (batch, sequence_len) with sequence_len being variable
+  per batch.
+  Return shape: (batch,); returns the per-batch mean of the loss values.
+
+  When used by BinaryCrossentropy() with the default reduction
+  (SUM_OVER_BATCH_SIZE), the reduction averages the per-batch losses over
+  the number of elements in the batch.
+  """
+  fn = functools.partial(binary_crossentropy, from_logits=from_logits,
+                         label_smoothing=label_smoothing)
+  return _ragged_tensor_apply_loss(fn, y_true, y_pred)
+
 @keras_export('keras.metrics.kl_divergence',
               'keras.metrics.kullback_leibler_divergence',
               'keras.metrics.kld',
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index bf474af966d..0933673d4fe 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -894,6 +894,36 @@ class BinaryCrossentropyTest(test.TestCase):
     expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
     self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
 
+  def test_ragged_tensors(self):
+    bce_obj = losses.BinaryCrossentropy()
+    y_true = ragged_factory_ops.constant([[1, 0, 1], [0]])
+    y_pred = ragged_factory_ops.constant([[1, 1, 1], [0]],
+                                         dtype=dtypes.float32)
+    sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
+    loss = bce_obj(y_true, y_pred, sample_weight=sample_weight)
+
+    # Per-batch loss = [sum([0, 15.33, 0]) / 3, 0.]
+    #                = [5.11, 0.]
+    # Reduced loss = 5.11 * 1.2 / 2
+
+    self.assertAlmostEqual(self.evaluate(loss), 3.0666, 3)
+
+    # Test with logits.
+    y_true = ragged_factory_ops.constant([[1, 0, 1], [0, 1]])
+    logits = ragged_factory_ops.constant([[100.0, -100.0, 100.0],
+                                          [100.0, 100.0]])
+    weights = constant_op.constant([4, 3])
+    bce_obj = losses.BinaryCrossentropy(from_logits=True)
+    loss = bce_obj(y_true, logits, sample_weight=weights)
+
+    # Loss = max(x, 0) - x * z + log(1 + exp(-abs(x)))
+    #   (where x = logits and z = y_true)
+    # Loss = [(0 + 0 + 0) / 3, 100 / 2]
+    # Weighted loss = [0 * 4, 50 * 3]
+    # Reduced loss = (0 + 50 * 3) / 2
+
+    self.assertAlmostEqual(self.evaluate(loss), 75., 3)
+
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class CategoricalCrossentropyTest(test.TestCase):
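
Note for reviewers: a minimal usage sketch of the behaviour this patch enables, written against the public Keras API (tf.keras.losses) rather than the internal module paths used in the diff. It assumes a TensorFlow build that includes this patch; variable names are illustrative.

import tensorflow as tf

# Two variable-length rows of binary labels and predicted probabilities.
y_true = tf.ragged.constant([[1., 0., 1.], [0.]])
y_pred = tf.ragged.constant([[0.9, 0.1, 0.8], [0.2]])

# Functional form: the new ragged dispatch returns one value per row, the
# mean binary cross-entropy over that row's variable-length sequence.
per_row = tf.keras.losses.binary_crossentropy(y_true, y_pred)  # shape (2,)

# Loss-object form: the default SUM_OVER_BATCH_SIZE reduction then averages
# the per-row values into a single scalar.
bce = tf.keras.losses.BinaryCrossentropy()
scalar_loss = bce(y_true, y_pred)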
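
For reviewers unfamiliar with _ragged_tensor_apply_loss: conceptually the dispatch boils down to applying the dense loss function row by row. A rough eager-mode equivalent is sketched below; it is illustrative only, not how the helper is actually implemented, and rowwise_loss is a made-up name.

import tensorflow as tf

def rowwise_loss(fn, y_true, y_pred):
  # In eager mode, iterating a RaggedTensor yields its rows as dense
  # tensors of varying length; fn returns one scalar (row mean) per row.
  return tf.stack([fn(t, p) for t, p in zip(y_true, y_pred)])

y_true = tf.ragged.constant([[1., 0., 1.], [0.]])
y_pred = tf.ragged.constant([[0.9, 0.1, 0.8], [0.2]])
per_row = rowwise_loss(tf.keras.losses.binary_crossentropy, y_true, y_pred)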