Merge pull request #47075 from pedro-r-marques:rt-bce

PiperOrigin-RevId: 358064569
Change-Id: I278260acc2b6dd1fedcba263136234d04ec8e66f
This commit is contained in:
TensorFlower Gardener 2021-02-17 17:09:16 -08:00
commit 8a37ad24a2
2 changed files with 51 additions and 0 deletions

View File

@ -1729,6 +1729,28 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
def _ragged_tensor_binary_crossentropy(y_true,
y_pred,
from_logits=False,
label_smoothing=0):
""" Implements support for handling RaggedTensors.
Expected shape: (batch, sequence_len) with sequence_len being variable
per batch.
Return shape: (batch,); returns the per batch mean of the loss values.
When used by BinaryCrossentropy() with the default reduction
(SUM_OVER_BATCH_SIZE), the reduction averages the per batch losses over
the number of batches.
"""
fn = functools.partial(
binary_crossentropy,
from_logits=from_logits,
label_smoothing=label_smoothing)
return _ragged_tensor_apply_loss(fn, y_true, y_pred)
@keras_export('keras.metrics.kl_divergence',
'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
'keras.metrics.KLD', 'keras.losses.kl_divergence',

View File

@ -894,6 +894,35 @@ class BinaryCrossentropyTest(test.TestCase):
expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
def test_ragged_tensors(self):
bce_obj = losses.BinaryCrossentropy()
y_true = ragged_factory_ops.constant([[1, 0, 1], [0]])
y_pred = ragged_factory_ops.constant([[1, 1, 1], [0]], dtype=dtypes.float32)
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = bce_obj(y_true, y_pred, sample_weight=sample_weight)
# per batch loss = [ sum([0, 15.33, 0]) / 3, 0. ]
# = [ 5.11, 0]
# Reduced loss = 5.11 * 1.2 / 2
self.assertAlmostEqual(self.evaluate(loss), 3.0666, 3)
# Test with logits.
y_true = ragged_factory_ops.constant([[1, 0, 1], [0, 1]])
logits = ragged_factory_ops.constant([[100.0, -100.0, 100.0],
[100.0, 100.0]])
weights = constant_op.constant([4, 3])
bce_obj = losses.BinaryCrossentropy(from_logits=True)
loss = bce_obj(y_true, logits, sample_weight=weights)
# Loss = max(x, 0) - x * z + log(1 + exp(-abs(x)))
# (where x = logits and z = y_true)
# Loss = [(0 + 0 + 0)/3, 100 / 2]
# Weighted loss = [0 * 4, 50 * 3]
# Reduced loss = (0 + 50 * 3) / 2
self.assertAlmostEqual(self.evaluate(loss), 75., 3)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class CategoricalCrossentropyTest(test.TestCase):