Merge pull request #36990 from foxik:fix_categorical_label_smoothing

PiperOrigin-RevId: 313703968
Change-Id: Ib0e7e98d496055d0574cda343c43d9f5225686eb
This commit is contained in:
TensorFlower Gardener 2020-05-28 19:19:21 -07:00
commit 0620d5912b

View File

@ -1526,7 +1526,7 @@ def categorical_crossentropy(y_true,
label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx())
def _smooth_labels():
num_classes = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
y_true = smart_cond.smart_cond(label_smoothing,