Merge pull request #47075 from pedro-r-marques:rt-bce
PiperOrigin-RevId: 358064569 Change-Id: I278260acc2b6dd1fedcba263136234d04ec8e66f
This commit is contained in:
commit
8a37ad24a2
@ -1729,6 +1729,28 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
|
||||
K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
|
||||
|
||||
|
||||
@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
def _ragged_tensor_binary_crossentropy(y_true,
                                       y_pred,
                                       from_logits=False,
                                       label_smoothing=0):
  """Implements support for handling RaggedTensors.

  Expected shape: (batch, sequence_len) with sequence_len being variable
  per batch.
  Return shape: (batch,); returns the per batch mean of the loss values.

  When used by BinaryCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the per batch losses over
  the number of batches.
  """
  # Bind the caller's options into the dense-tensor loss, then let the
  # generic ragged adapter map it over each variable-length row.
  partial_bce = functools.partial(
      binary_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing)
  return _ragged_tensor_apply_loss(partial_bce, y_true, y_pred)
|
||||
|
||||
|
||||
@keras_export('keras.metrics.kl_divergence',
|
||||
'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
|
||||
'keras.metrics.KLD', 'keras.losses.kl_divergence',
|
||||
|
||||
@ -894,6 +894,35 @@ class BinaryCrossentropyTest(test.TestCase):
|
||||
expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
|
||||
self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
|
||||
|
||||
def test_ragged_tensors(self):
  """Binary crossentropy on ragged inputs, with and without logits."""
  # Case 1: probabilities, per-sample weights of shape (2, 1).
  loss_obj = losses.BinaryCrossentropy()
  labels = ragged_factory_ops.constant([[1, 0, 1], [0]])
  preds = ragged_factory_ops.constant([[1, 1, 1], [0]], dtype=dtypes.float32)
  weights = constant_op.constant([1.2, 3.4], shape=(2, 1))
  result = loss_obj(labels, preds, sample_weight=weights)

  # per batch loss = [ sum([0, 15.33, 0]) / 3, 0. ] = [ 5.11, 0 ]
  # reduced loss   = 5.11 * 1.2 / 2
  self.assertAlmostEqual(self.evaluate(result), 3.0666, 3)

  # Case 2: raw logits, scalar-per-sample weights.
  labels = ragged_factory_ops.constant([[1, 0, 1], [0, 1]])
  logits = ragged_factory_ops.constant([[100.0, -100.0, 100.0],
                                        [100.0, 100.0]])
  sample_weights = constant_op.constant([4, 3])
  loss_obj = losses.BinaryCrossentropy(from_logits=True)
  result = loss_obj(labels, logits, sample_weight=sample_weights)

  # loss = max(x, 0) - x * z + log(1 + exp(-abs(x)))
  # (where x = logits and z = labels)
  # per batch loss = [(0 + 0 + 0)/3, 100 / 2]
  # weighted loss  = [0 * 4, 50 * 3]
  # reduced loss   = (0 + 50 * 3) / 2
  self.assertAlmostEqual(self.evaluate(result), 75., 3)
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class CategoricalCrossentropyTest(test.TestCase):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user