diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 1a1507b3882..16931b0dd00 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4406,7 +4406,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): target = cast(target, 'int64') - # Try to adjust the shape so that rank of labels = 1 - rank of logits. + # Try to adjust the shape so that rank of labels = rank of logits - 1. output_shape = array_ops.shape_v2(output) target_rank = target.shape.ndims