diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index ba1b76bab32..29fc36b5fc6 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers @@ -2174,6 +2175,23 @@ class ResetStatesTest(keras_parameterized.TestCase): self.assertArrayNear(self.evaluate(m_obj.total_cm)[0], [1, 0], 1e-1) self.assertArrayNear(self.evaluate(m_obj.total_cm)[1], [3, 0], 1e-1) + def test_reset_states_recall_float64(self): + # Test case for GitHub issue 36790. + try: + backend.set_floatx('float64') + r_obj = metrics.Recall() + model = _get_model([r_obj]) + x = np.concatenate((np.ones((50, 4)), np.zeros((50, 4)))) + y = np.concatenate((np.ones((50, 1)), np.ones((50, 1)))) + model.evaluate(x, y) + self.assertEqual(self.evaluate(r_obj.true_positives), 50.) + self.assertEqual(self.evaluate(r_obj.false_negatives), 50.) + model.evaluate(x, y) + self.assertEqual(self.evaluate(r_obj.true_positives), 50.) + self.assertEqual(self.evaluate(r_obj.false_negatives), 50.) + finally: + backend.set_floatx('float32') + if __name__ == '__main__': test.main()