Fix ValueError with tf.keras.metrics.Recall and float64 keras backend

This PR fixes the issue raised in 36790 where tf.keras.metrics.Recall
causes ValueError when the backend of the keras is float64:

This PR cast the value to the dtype of var as var.assign_add
is being called.

This PR fixes 36790.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-05-03 22:36:08 +00:00
parent ae76544efc
commit a744818876

View File

@ -422,9 +422,9 @@ def update_confusion_matrix_variables(variables_to_update,
def weighted_assign_add(label, pred, weights, var):
label_and_pred = math_ops.cast(
math_ops.logical_and(label, pred), dtype=dtypes.float32)
math_ops.logical_and(label, pred), dtype=var.dtype)
if weights is not None:
label_and_pred *= weights
label_and_pred *= math_ops.cast(weights, dtype=var.dtype)
return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))
loop_vars = {