From a744818876ab362ad4112b625f40b2a0dbdafb12 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 3 May 2020 22:36:08 +0000 Subject: [PATCH] Fix ValueError with tf.keras.metrics.Recall and float64 keras backend This PR fixes the issue raised in 36790 where tf.keras.metrics.Recall causes ValueError when the backend of the keras is float64: This PR cast the value to the dtype of var as var.assign_add is being called. This PR fixes 36790. Signed-off-by: Yong Tang --- tensorflow/python/keras/utils/metrics_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py index 58fff40564d..5cb6fc5f9f8 100644 --- a/tensorflow/python/keras/utils/metrics_utils.py +++ b/tensorflow/python/keras/utils/metrics_utils.py @@ -422,9 +422,9 @@ def update_confusion_matrix_variables(variables_to_update, def weighted_assign_add(label, pred, weights, var): label_and_pred = math_ops.cast( - math_ops.logical_and(label, pred), dtype=dtypes.float32) + math_ops.logical_and(label, pred), dtype=var.dtype) if weights is not None: - label_and_pred *= weights + label_and_pred *= math_ops.cast(weights, dtype=var.dtype) return var.assign_add(math_ops.reduce_sum(label_and_pred, 1)) loop_vars = {