Merge pull request #36990 from foxik:fix_categorical_label_smoothing
PiperOrigin-RevId: 313703968 Change-Id: Ib0e7e98d496055d0574cda343c43d9f5225686eb
This commit is contained in:
commit
0620d5912b
@ -1526,7 +1526,7 @@ def categorical_crossentropy(y_true,
|
|||||||
label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx())
|
label_smoothing = ops.convert_to_tensor_v2(label_smoothing, dtype=K.floatx())
|
||||||
|
|
||||||
def _smooth_labels():
|
def _smooth_labels():
|
||||||
num_classes = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
|
num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
|
||||||
return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
|
return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)
|
||||||
|
|
||||||
y_true = smart_cond.smart_cond(label_smoothing,
|
y_true = smart_cond.smart_cond(label_smoothing,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user