Fix SparseTopKCategoricalAccuracy for predictions with rank > 2.
PiperOrigin-RevId: 281769284 Change-Id: Icf5ee1addcee83d2b77ee6de644f5fb7e83b2971
This commit is contained in:
parent
82e32167b6
commit
ddca4b92b4
@ -2997,12 +2997,25 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
|
||||
|
||||
@keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
|
||||
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
|
||||
"""Computes how often integer targets are in the top `K` predictions.
|
||||
|
||||
Args:
|
||||
y_true: tensor of true targets.
|
||||
y_pred: tensor of predicted targets.
|
||||
k: (Optional) Number of top elements to look at for computing accuracy.
|
||||
Defaults to 5.
|
||||
|
||||
Returns:
|
||||
Sparse top K categorical accuracy value.
|
||||
"""
|
||||
y_pred_rank = ops.convert_to_tensor(y_pred).shape.ndims
|
||||
y_true_rank = ops.convert_to_tensor(y_true).shape.ndims
|
||||
# If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
|
||||
if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
|
||||
K.int_shape(y_true)) == len(K.int_shape(y_pred))):
|
||||
y_true = array_ops.squeeze(y_true, [-1])
|
||||
# Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
|
||||
if (y_true_rank is not None) and (y_pred_rank is not None):
|
||||
if y_pred_rank > 2:
|
||||
y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]])
|
||||
if y_true_rank > 1:
|
||||
y_true = array_ops.reshape(y_true, [-1])
|
||||
|
||||
return math_ops.cast(
|
||||
nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), K.floatx())
|
||||
|
@ -46,15 +46,21 @@ class KerasFunctionalMetricsTest(test.TestCase):
|
||||
# Test correctness if the shape of y_true is (num_samples,)
|
||||
y_true = K.variable([1., 0., 0., 0.])
|
||||
y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
|
||||
print(K.eval(metric(y_true, y_pred)))
|
||||
self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
|
||||
|
||||
# Test correctness if the shape of y_true is (num_samples, 1)
|
||||
y_true = K.variable([[1.], [0.], [0.], [0.]])
|
||||
y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
|
||||
print(K.eval(metric(y_true, y_pred)))
|
||||
self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
|
||||
|
||||
# Test correctness if the shape of y_true is (batch_size, seq_length) and
|
||||
# y_pred is (batch_size, seq_length, num_classes)
|
||||
y_pred = K.variable(
|
||||
np.array([[[0.2, 0.3, 0.1], [0.1, 0.2, 0.7]],
|
||||
[[0.3, 0.2, 0.1], [0.7, 0.2, 0.1]]]))
|
||||
y_true = K.variable(np.array([[1, 0], [1, 0]]))
|
||||
self.assertAllEqual(K.eval(metric(y_true, y_pred)), [[1., 0.], [0., 1.]])
|
||||
|
||||
def test_sparse_categorical_accuracy_float(self):
|
||||
with self.cached_session():
|
||||
metric = metrics.sparse_categorical_accuracy
|
||||
@ -106,6 +112,22 @@ class KerasFunctionalMetricsTest(test.TestCase):
|
||||
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
|
||||
self.assertEqual(np.mean(result), 0.)
|
||||
|
||||
# Test correctness if the shape of y_true is (batch_size, seq_length) and
|
||||
# y_pred is (batch_size, seq_length, num_classes)
|
||||
y_pred = K.variable(
|
||||
np.array([[[0.3, 0.2, 0.1], [0.1, 0.2, 0.7], [0.1, 0.2, 0.7]],
|
||||
[[0.3, 0.2, 0.1], [0.1, 0.2, 0.7], [0.3, 0.2, 0.1]]]))
|
||||
y_true = K.variable(np.array([[1, 0, 0], [1, 0, 1]]))
|
||||
result = K.eval(
|
||||
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
|
||||
self.assertEqual(np.mean(result), 1)
|
||||
result = K.eval(
|
||||
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
|
||||
self.assertEqual(np.mean(result), 0.5)
|
||||
result = K.eval(
|
||||
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
|
||||
self.assertEqual(np.mean(result), 0.)
|
||||
|
||||
def test_top_k_categorical_accuracy(self):
|
||||
with self.cached_session():
|
||||
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
|
||||
|
Loading…
Reference in New Issue
Block a user