Merge pull request #47045 from pedro-r-marques:rt-cce
PiperOrigin-RevId: 356851151 Change-Id: I94d20b27f9efd4af561f7562e2c14d5ee4efb92d
commit d99168f592
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function

import abc
import functools

import six
@@ -25,6 +26,7 @@ from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_spec
@@ -35,6 +37,7 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
@@ -1231,7 +1234,31 @@ def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
      (per-batch loss value); a ragged tensor otherwise.
  """

  def rt_is_equiv_dense(rt):
    """Returns true if this RaggedTensor has the same row_lengths across
       all ragged dimensions and thus can be converted to a dense tensor
       without loss of information.

    Args:
      rt: RaggedTensor.
    """
    return math_ops.reduce_all([
        math_ops.equal(
            math_ops.reduce_variance(math_ops.cast(row_lens, K.floatx())),
            constant_op.constant([0.])) for row_lens in rt.nested_row_lengths()
    ])

  def _convert_to_dense(inputs):
    return tuple(rt.to_tensor() for rt in inputs)

  def _wrapper(inputs):
    _, y_pred = inputs
    if isinstance(y_pred, ragged_tensor.RaggedTensor):
      return control_flow_ops.cond(
          rt_is_equiv_dense(y_pred),
          lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs))

    return loss_fn(*inputs)

  lshape = y_pred.shape.as_list()[1:-1]
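For context, a minimal sketch (my own illustration, not part of the diff, using the public tf API instead of the internal modules above) of what the uniform-row-length check does: the row-length variance is zero only when every ragged dimension is uniform, so the tensor can be densified with to_tensor() without padding changing its contents, and the dense loss path can be taken inside the cond above.

import tensorflow as tf

uniform = tf.ragged.constant([[1., 2.], [3., 4.]])  # row_lengths [2, 2]
ragged = tf.ragged.constant([[1., 2.], [3.]])       # row_lengths [2, 1]

def rt_is_equiv_dense(rt):
  # Variance of the row lengths is 0 iff all rows have the same length.
  return tf.reduce_all([
      tf.equal(tf.math.reduce_variance(tf.cast(lens, tf.float32)),
               tf.constant([0.])) for lens in rt.nested_row_lengths()
  ])

print(rt_is_equiv_dense(uniform))  # True  -> convert to dense, then apply loss_fn
print(rt_is_equiv_dense(ragged))   # False -> apply loss_fn to the ragged inputs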
@@ -1604,6 +1631,31 @@ def categorical_crossentropy(y_true,
  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)


@dispatch.dispatch_for_types(categorical_crossentropy,
                             ragged_tensor.RaggedTensor)
def _ragged_tensor_categorical_crossentropy(y_true,
                                            y_pred,
                                            from_logits=False,
                                            label_smoothing=0):
  """Implements support for handling RaggedTensors.

  Expected shape: (batch, sequence_len, n_classes) with sequence_len
  being variable per batch.
  Return shape: (batch, sequence_len).

  When used by CategoricalCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
  number of elements independent of the batch. E.g. if the RaggedTensor
  has 2 batches with [2, 1] values respectively, the resulting loss is
  the sum of the individual loss values divided by 3.
  """
  fn = functools.partial(
      categorical_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing)
  return _ragged_tensor_apply_loss(fn, y_true, y_pred)


@keras_export('keras.metrics.sparse_categorical_crossentropy',
              'keras.losses.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
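A hedged usage sketch (my own example, assuming a TensorFlow build that includes this change): with the dispatch registered, the public tf.keras.losses.categorical_crossentropy accepts RaggedTensor inputs directly and returns a ragged per-timestep loss, while the CategoricalCrossentropy class reduces over the total element count rather than the batch size.

import tensorflow as tf

# Batch of 2 sequences with lengths [2, 1]; one-hot labels over 3 classes.
y_true = tf.ragged.constant([[[1., 0., 0.], [0., 1., 0.]], [[0., 0., 1.]]])
y_pred = tf.ragged.constant([[[.9, .05, .05], [.05, .9, .05]], [[.1, .1, .8]]])

# Per-element losses: a RaggedTensor of shape (batch, sequence_len).
elementwise = tf.keras.losses.categorical_crossentropy(y_true, y_pred)

# Default SUM_OVER_BATCH_SIZE reduction: the scalar loss is the sum of all
# element losses divided by 3 (total elements), not by 2 (batch size).
scalar = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)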
@@ -1001,6 +1001,26 @@ class CategoricalCrossentropyTest(test.TestCase):
    with self.assertRaisesRegex(ValueError, 'Shapes .+ are incompatible'):
      cce_obj(y_true, y_pred)

  def test_ragged_tensors(self):
    cce_obj = losses.CategoricalCrossentropy()
    y_true = ragged_factory_ops.constant([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]])
    y_pred = ragged_factory_ops.constant(
        [[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]],
        dtype=dtypes.float32)
    # batch losses [[0.1054, 0.8047], [0.0619]]
    sample_weight = constant_op.constant([[1.2], [3.4]], shape=(2, 1))
    loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
    # weighted sum over all elements, divided by the 3 elements:
    # (1.2 * (0.1054 + 0.8047) + 3.4 * 0.0619) / 3
    self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3)

    # Test with logits.
    logits = ragged_factory_ops.constant([[[8., 1., 1.], [0., 9., 1.]],
                                          [[2., 3., 5.]]])
    cce_obj = losses.CategoricalCrossentropy(from_logits=True)
    # batch losses [[0.0018, 0.0004], [0.1698]]
    loss = cce_obj(y_true, logits, sample_weight=sample_weight)
    self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class SparseCategoricalCrossentropyTest(test.TestCase):
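For reference, the expected values in the test fold in the sample weights before dividing by the element count; a quick plain-Python check of the first assertion (my own sketch, using the per-element losses quoted in the test comments):

# Per-element categorical crossentropy for the probability case above.
batch_losses = [[0.1054, 0.8047], [0.0619]]
weights = [1.2, 3.4]

weighted_sum = sum(w * sum(row) for w, row in zip(weights, batch_losses))
n_elements = sum(len(row) for row in batch_losses)  # 3, not the batch size 2
print(weighted_sum / n_elements)  # ~0.4341, the value asserted in the test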