Merge pull request #36990 from foxik:fix_categorical_label_smoothing
PiperOrigin-RevId: 313703968 Change-Id: Ib0e7e98d496055d0574cda343c43d9f5225686eb
This commit is contained in:
commit
0620d5912b
@ -1526,7 +1526,7 @@ def categorical_crossentropy(y_true,
|
||||
label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx())
|
||||
|
||||
def _smooth_labels():
|
||||
num_classes = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
|
||||
num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
|
||||
return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
|
||||
|
||||
y_true = smart_cond.smart_cond(label_smoothing,
|
||||
|
Loading…
Reference in New Issue
Block a user