diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 99eadaec4c8..fee65be18c9 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers @@ -2247,6 +2248,23 @@ class ResetStatesTest(keras_parameterized.TestCase): self.assertArrayNear(self.evaluate(m_obj.total_cm)[0], [1, 0], 1e-1) self.assertArrayNear(self.evaluate(m_obj.total_cm)[1], [3, 0], 1e-1) + def test_reset_states_recall_float64(self): + # Test case for GitHub issue 36790. + try: + backend.set_floatx('float64') + r_obj = metrics.Recall() + model = _get_model([r_obj]) + x = np.concatenate((np.ones((50, 4)), np.zeros((50, 4)))) + y = np.concatenate((np.ones((50, 1)), np.ones((50, 1)))) + model.evaluate(x, y) + self.assertEqual(self.evaluate(r_obj.true_positives), 50.) + self.assertEqual(self.evaluate(r_obj.false_negatives), 50.) + model.evaluate(x, y) + self.assertEqual(self.evaluate(r_obj.true_positives), 50.) + self.assertEqual(self.evaluate(r_obj.false_negatives), 50.) + finally: + backend.set_floatx('float32') + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/utils/metrics_utils.py b/tensorflow/python/keras/utils/metrics_utils.py index 58fff40564d..5f9b57c095e 100644 --- a/tensorflow/python/keras/utils/metrics_utils.py +++ b/tensorflow/python/keras/utils/metrics_utils.py @@ -299,9 +299,19 @@ def update_confusion_matrix_variables(variables_to_update, '`multi_label` is True.') if variables_to_update is None: return - y_true = math_ops.cast(y_true, dtype=dtypes.float32) - y_pred = math_ops.cast(y_pred, dtype=dtypes.float32) - thresholds = ops.convert_to_tensor_v2(thresholds, dtype=dtypes.float32) + if not any( + key for key in variables_to_update if key in list(ConfusionMatrix)): + raise ValueError( + 'Please provide at least one valid confusion matrix ' + 'variable to update. Valid variable key options are: "{}". ' + 'Received: "{}"'.format( + list(ConfusionMatrix), variables_to_update.keys())) + + variable_dtype = list(variables_to_update.values())[0].dtype + + y_true = math_ops.cast(y_true, dtype=variable_dtype) + y_pred = math_ops.cast(y_pred, dtype=variable_dtype) + thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype) num_thresholds = thresholds.shape[0] if multi_label: one_thresh = math_ops.equal( @@ -314,14 +324,6 @@ def update_confusion_matrix_variables(variables_to_update, sample_weight) one_thresh = math_ops.cast(True, dtype=dtypes.bool) - if not any( - key for key in variables_to_update if key in list(ConfusionMatrix)): - raise ValueError( - 'Please provide at least one valid confusion matrix ' - 'variable to update. Valid variable key options are: "{}". ' - 'Received: "{}"'.format( - list(ConfusionMatrix), variables_to_update.keys())) - invalid_keys = [ key for key in variables_to_update if key not in list(ConfusionMatrix) ] @@ -401,7 +403,7 @@ def update_confusion_matrix_variables(variables_to_update, if sample_weight is not None: sample_weight = weights_broadcast_ops.broadcast_weights( - math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred) + math_ops.cast(sample_weight, dtype=variable_dtype), y_pred) weights_tiled = array_ops.tile( array_ops.reshape(sample_weight, thresh_tiles), data_tiles) else: @@ -422,9 +424,9 @@ def update_confusion_matrix_variables(variables_to_update, def weighted_assign_add(label, pred, weights, var): label_and_pred = math_ops.cast( - math_ops.logical_and(label, pred), dtype=dtypes.float32) + math_ops.logical_and(label, pred), dtype=var.dtype) if weights is not None: - label_and_pred *= weights + label_and_pred *= math_ops.cast(weights, dtype=var.dtype) return var.assign_add(math_ops.reduce_sum(label_and_pred, 1)) loop_vars = {