Merge pull request #47045 from pedro-r-marques:rt-cce

PiperOrigin-RevId: 356851151
Change-Id: I94d20b27f9efd4af561f7562e2c14d5ee4efb92d
This commit is contained in:
TensorFlower Gardener 2021-02-10 16:20:51 -08:00
commit d99168f592
2 changed files with 72 additions and 0 deletions

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import abc
import functools
import six
@ -25,6 +26,7 @@ from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_spec
@ -35,6 +37,7 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
@ -1231,7 +1234,31 @@ def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
(per-batch loss value); a ragged tensor otherwise.
"""
def rt_is_equiv_dense(rt):
  """Checks whether a RaggedTensor can be densified without losing data.

  A RaggedTensor is equivalent to a dense tensor exactly when every
  ragged dimension has uniform row lengths; in that case `to_tensor()`
  needs no padding.

  Args:
    rt: a RaggedTensor.

  Returns:
    A scalar boolean tensor, true iff all ragged dimensions are uniform.
  """
  uniformity_checks = []
  for row_lens in rt.nested_row_lengths():
    # Zero variance of the row lengths <=> every row has the same length.
    lens_var = math_ops.reduce_variance(math_ops.cast(row_lens, K.floatx()))
    uniformity_checks.append(
        math_ops.equal(lens_var, constant_op.constant([0.])))
  return math_ops.reduce_all(uniformity_checks)
def _convert_to_dense(inputs):
return tuple(rt.to_tensor() for rt in inputs)
def _wrapper(inputs):
  """Applies `loss_fn`, densifying ragged predictions when lossless.

  If `y_pred` is ragged but all of its ragged dimensions are uniform,
  the inputs are converted to dense tensors before calling `loss_fn`
  (decided at graph-execution time via a cond); otherwise `loss_fn`
  receives the inputs unchanged.
  """
  _, y_pred = inputs
  if not isinstance(y_pred, ragged_tensor.RaggedTensor):
    return loss_fn(*inputs)
  return control_flow_ops.cond(
      rt_is_equiv_dense(y_pred),
      lambda: loss_fn(*_convert_to_dense(inputs)),
      lambda: loss_fn(*inputs))
lshape = y_pred.shape.as_list()[1:-1]
@ -1604,6 +1631,31 @@ def categorical_crossentropy(y_true,
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
@dispatch.dispatch_for_types(categorical_crossentropy,
                             ragged_tensor.RaggedTensor)
def _ragged_tensor_categorical_crossentropy(y_true,
                                            y_pred,
                                            from_logits=False,
                                            label_smoothing=0):
  """Implements support for handling RaggedTensors.

  Expected shape: (batch, sequence_len, n_classes), with sequence_len
  varying per batch entry.
  Return shape: (batch, sequence_len).

  When used by CategoricalCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
  number of elements independent of the batch. E.g. if the RaggedTensor
  has 2 batches with [2, 1] values respectively, the resulting loss is
  the sum of the individual loss values divided by 3.
  """
  # Bind the keyword options once; the ragged wrapper only forwards
  # (y_true, y_pred) pairs.
  partial_cce = functools.partial(
      categorical_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing)
  return _ragged_tensor_apply_loss(partial_cce, y_true, y_pred)
@keras_export('keras.metrics.sparse_categorical_crossentropy',
'keras.losses.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support

View File

@ -1001,6 +1001,26 @@ class CategoricalCrossentropyTest(test.TestCase):
with self.assertRaisesRegex(ValueError, 'Shapes .+ are incompatible'):
cce_obj(y_true, y_pred)
def test_ragged_tensors(self):
  """Categorical crossentropy on ragged inputs with per-sample weights."""
  cce_obj = losses.CategoricalCrossentropy()
  # Ragged labels/predictions: batch 0 has 2 timesteps, batch 1 has 1,
  # each timestep a 3-class distribution.
  y_true = ragged_factory_ops.constant([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]])
  y_pred = ragged_factory_ops.constant(
      [[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]],
      dtype=dtypes.float32)
  # batch losses [[0.1054, 0.8047], [0.0619]]
  # One weight per batch entry, broadcast across that entry's timesteps.
  sample_weight = constant_op.constant([[1.2], [3.4]], shape=(2, 1))
  loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
  # Default SUM_OVER_BATCH_SIZE reduction divides by the total number of
  # ragged elements (3), not the batch size (2):
  # sum([0.1054, 0.8047, 0.0619]) / 3
  self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3)
  # Test with logits.
  logits = ragged_factory_ops.constant([[[8., 1., 1.], [0., 9., 1.]],
                                        [[2., 3., 5.]]])
  cce_obj = losses.CategoricalCrossentropy(from_logits=True)
  # batch losses [[0.0018, 0.0004], [0.1698]]
  loss = cce_obj(y_true, logits, sample_weight=sample_weight)
  self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SparseCategoricalCrossentropyTest(test.TestCase):