From 6f5999038c8b35afa165ae7e19294b1ae494a2d1 Mon Sep 17 00:00:00 2001 From: Pedro Marques Date: Tue, 9 Feb 2021 11:58:11 +0100 Subject: [PATCH] Support RaggedTensors in categorical_crossentropy. When applying a loss, convert ragged tensors to dense if the ragged tensor has the same dimensions across all axis. This avoids unnecessary loss wrapper and map_fn invocations. --- tensorflow/python/keras/losses.py | 50 ++++++++++++++++++++++++++ tensorflow/python/keras/losses_test.py | 20 +++++++++++ 2 files changed, 70 insertions(+) diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 2eafc21812a..1f2b4effd69 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import abc +import functools import six @@ -25,6 +26,7 @@ from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_spec @@ -35,6 +37,7 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.losses import losses_impl @@ -1231,7 +1234,32 @@ def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred): (per-batch loss value); a ragged tensor otherwise. """ + def rt_is_equiv_dense(rt): + """Returns true if this RaggedTensor has the same row_lenghts across + all ragged dimensions and thus can be converted to a dense tensor + without loss of information. + + Args: + rt: RaggedTensor + """ + return math_ops.reduce_all([ + math_ops.equal( + math_ops.reduce_variance( + math_ops.cast(row_lens, K.floatx())), + constant_op.constant([0.])) + for row_lens in rt.nested_row_lengths()]) + + def _convert_to_dense(inputs): + return tuple(rt.to_tensor() for rt in inputs) + def _wrapper(inputs): + _, y_pred = inputs + if isinstance(y_pred, ragged_tensor.RaggedTensor): + return control_flow_ops.cond( + rt_is_equiv_dense(y_pred), + lambda: loss_fn(*_convert_to_dense(inputs)), + lambda: loss_fn(*inputs)) + return loss_fn(*inputs) lshape = y_pred.shape.as_list()[1:-1] @@ -1603,6 +1631,28 @@ def categorical_crossentropy(y_true, lambda: y_true) return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits) +@dispatch.dispatch_for_types( + categorical_crossentropy, ragged_tensor.RaggedTensor) +def _ragged_tensor_categorical_crossentropy(y_true, + y_pred, + from_logits=False, + label_smoothing=0): + """ Implements support for handling RaggedTensors. + + Expected shape: (batch, sequence_len, n_classes) with sequence_len + being variable per batch. + Return shape: (batch, sequence_len). + + When used by CategoricalCrossentropy() with the default reduction + (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the + number of elements independent of the batch. E.g. if the RaggedTensor + has 2 batches with [2, 1] values respectivly the resulting loss is + the sum of the individual loss values divided by 3. + """ + fn = functools.partial(categorical_crossentropy, from_logits=from_logits, + label_smoothing=label_smoothing) + return _ragged_tensor_apply_loss(fn, y_true, y_pred) + @keras_export('keras.metrics.sparse_categorical_crossentropy', 'keras.losses.sparse_categorical_crossentropy') diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index 309b21ee207..c71665a63bd 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -1001,6 +1001,26 @@ class CategoricalCrossentropyTest(test.TestCase): with self.assertRaisesRegex(ValueError, 'Shapes .+ are incompatible'): cce_obj(y_true, y_pred) + def test_ragged_tensors(self): + cce_obj = losses.CategoricalCrossentropy() + y_true = ragged_factory_ops.constant([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]]) + y_pred = ragged_factory_ops.constant( + [[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]], + dtype=dtypes.float32) + # batch losses [[0.1054, 0.8047], [0.0619]] + sample_weight = constant_op.constant([[1.2], [3.4]], shape=(2, 1)) + loss = cce_obj(y_true, y_pred, sample_weight=sample_weight) + # sum([0.1054, 0.8047, 0.0619]) / 3 + self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3) + + # Test with logits. + logits = ragged_factory_ops.constant( + [[[8., 1., 1.], [0., 9., 1.]], [[2., 3., 5.]]]) + cce_obj = losses.CategoricalCrossentropy(from_logits=True) + # batch losses [[0.0018, 0.0004], [0.1698]] + loss = cce_obj(y_true, logits, sample_weight=sample_weight) + self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) class SparseCategoricalCrossentropyTest(test.TestCase):