Merge pull request #47045 from pedro-r-marques:rt-cce
PiperOrigin-RevId: 356851151 Change-Id: I94d20b27f9efd4af561f7562e2c14d5ee4efb92d
This commit is contained in:
commit
d99168f592
@ -18,6 +18,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import functools
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ from tensorflow.python.autograph.core import ag_ctx
|
|||||||
from tensorflow.python.autograph.impl import api as autograph
|
from tensorflow.python.autograph.impl import api as autograph
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import smart_cond
|
from tensorflow.python.framework import smart_cond
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
@ -35,6 +37,7 @@ from tensorflow.python.keras.utils import tf_utils
|
|||||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops.losses import losses_impl
|
from tensorflow.python.ops.losses import losses_impl
|
||||||
@ -1231,7 +1234,31 @@ def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
|
|||||||
(per-batch loss value); a ragged tensor otherwise.
|
(per-batch loss value); a ragged tensor otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def rt_is_equiv_dense(rt):
|
||||||
|
"""Returns true if this RaggedTensor has the same row_lenghts across
|
||||||
|
|
||||||
|
all ragged dimensions and thus can be converted to a dense tensor
|
||||||
|
without loss of information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rt: RaggedTensor
|
||||||
|
"""
|
||||||
|
return math_ops.reduce_all([
|
||||||
|
math_ops.equal(
|
||||||
|
math_ops.reduce_variance(math_ops.cast(row_lens, K.floatx())),
|
||||||
|
constant_op.constant([0.])) for row_lens in rt.nested_row_lengths()
|
||||||
|
])
|
||||||
|
|
||||||
|
def _convert_to_dense(inputs):
|
||||||
|
return tuple(rt.to_tensor() for rt in inputs)
|
||||||
|
|
||||||
def _wrapper(inputs):
|
def _wrapper(inputs):
|
||||||
|
_, y_pred = inputs
|
||||||
|
if isinstance(y_pred, ragged_tensor.RaggedTensor):
|
||||||
|
return control_flow_ops.cond(
|
||||||
|
rt_is_equiv_dense(y_pred),
|
||||||
|
lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs))
|
||||||
|
|
||||||
return loss_fn(*inputs)
|
return loss_fn(*inputs)
|
||||||
|
|
||||||
lshape = y_pred.shape.as_list()[1:-1]
|
lshape = y_pred.shape.as_list()[1:-1]
|
||||||
@ -1604,6 +1631,31 @@ def categorical_crossentropy(y_true,
|
|||||||
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
|
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
|
||||||
|
|
||||||
|
|
||||||
|
@dispatch.dispatch_for_types(categorical_crossentropy,
|
||||||
|
ragged_tensor.RaggedTensor)
|
||||||
|
def _ragged_tensor_categorical_crossentropy(y_true,
|
||||||
|
y_pred,
|
||||||
|
from_logits=False,
|
||||||
|
label_smoothing=0):
|
||||||
|
""" Implements support for handling RaggedTensors.
|
||||||
|
|
||||||
|
Expected shape: (batch, sequence_len, n_classes) with sequence_len
|
||||||
|
being variable per batch.
|
||||||
|
Return shape: (batch, sequence_len).
|
||||||
|
|
||||||
|
When used by CategoricalCrossentropy() with the default reduction
|
||||||
|
(SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
|
||||||
|
number of elements independent of the batch. E.g. if the RaggedTensor
|
||||||
|
has 2 batches with [2, 1] values respectivly the resulting loss is
|
||||||
|
the sum of the individual loss values divided by 3.
|
||||||
|
"""
|
||||||
|
fn = functools.partial(
|
||||||
|
categorical_crossentropy,
|
||||||
|
from_logits=from_logits,
|
||||||
|
label_smoothing=label_smoothing)
|
||||||
|
return _ragged_tensor_apply_loss(fn, y_true, y_pred)
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.metrics.sparse_categorical_crossentropy',
|
@keras_export('keras.metrics.sparse_categorical_crossentropy',
|
||||||
'keras.losses.sparse_categorical_crossentropy')
|
'keras.losses.sparse_categorical_crossentropy')
|
||||||
@dispatch.add_dispatch_support
|
@dispatch.add_dispatch_support
|
||||||
|
@ -1001,6 +1001,26 @@ class CategoricalCrossentropyTest(test.TestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, 'Shapes .+ are incompatible'):
|
with self.assertRaisesRegex(ValueError, 'Shapes .+ are incompatible'):
|
||||||
cce_obj(y_true, y_pred)
|
cce_obj(y_true, y_pred)
|
||||||
|
|
||||||
|
def test_ragged_tensors(self):
|
||||||
|
cce_obj = losses.CategoricalCrossentropy()
|
||||||
|
y_true = ragged_factory_ops.constant([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]])
|
||||||
|
y_pred = ragged_factory_ops.constant(
|
||||||
|
[[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]],
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
# batch losses [[0.1054, 0.8047], [0.0619]]
|
||||||
|
sample_weight = constant_op.constant([[1.2], [3.4]], shape=(2, 1))
|
||||||
|
loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
|
||||||
|
# sum([0.1054, 0.8047, 0.0619]) / 3
|
||||||
|
self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3)
|
||||||
|
|
||||||
|
# Test with logits.
|
||||||
|
logits = ragged_factory_ops.constant([[[8., 1., 1.], [0., 9., 1.]],
|
||||||
|
[[2., 3., 5.]]])
|
||||||
|
cce_obj = losses.CategoricalCrossentropy(from_logits=True)
|
||||||
|
# batch losses [[0.0018, 0.0004], [0.1698]]
|
||||||
|
loss = cce_obj(y_true, logits, sample_weight=sample_weight)
|
||||||
|
self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3)
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
class SparseCategoricalCrossentropyTest(test.TestCase):
|
class SparseCategoricalCrossentropyTest(test.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user