Fix ValueError with tf.keras.metrics.Recall and float64 keras backend
This PR fixes the issue raised in 36790 where tf.keras.metrics.Recall causes ValueError when the backend of the keras is float64: This PR cast the value to the dtype of var as var.assign_add is being called. This PR fixes 36790. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
ae76544efc
commit
a744818876
@ -422,9 +422,9 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||
|
||||
def weighted_assign_add(label, pred, weights, var):
|
||||
label_and_pred = math_ops.cast(
|
||||
math_ops.logical_and(label, pred), dtype=dtypes.float32)
|
||||
math_ops.logical_and(label, pred), dtype=var.dtype)
|
||||
if weights is not None:
|
||||
label_and_pred *= weights
|
||||
label_and_pred *= math_ops.cast(weights, dtype=var.dtype)
|
||||
return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))
|
||||
|
||||
loop_vars = {
|
||||
|
Loading…
Reference in New Issue
Block a user