Add test case for tf.keras.metrics.Recall() and float64 keras backend.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
a744818876
commit
8dadd8304f
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras import combinations
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import layers
|
from tensorflow.python.keras import layers
|
||||||
|
@ -2174,6 +2175,23 @@ class ResetStatesTest(keras_parameterized.TestCase):
|
||||||
self.assertArrayNear(self.evaluate(m_obj.total_cm)[0], [1, 0], 1e-1)
|
self.assertArrayNear(self.evaluate(m_obj.total_cm)[0], [1, 0], 1e-1)
|
||||||
self.assertArrayNear(self.evaluate(m_obj.total_cm)[1], [3, 0], 1e-1)
|
self.assertArrayNear(self.evaluate(m_obj.total_cm)[1], [3, 0], 1e-1)
|
||||||
|
|
||||||
|
def test_reset_states_recall_float64(self):
|
||||||
|
# Test case for GitHub issue 36790.
|
||||||
|
try:
|
||||||
|
backend.set_floatx('float64')
|
||||||
|
r_obj = metrics.Recall()
|
||||||
|
model = _get_model([r_obj])
|
||||||
|
x = np.concatenate((np.ones((50, 4)), np.zeros((50, 4))))
|
||||||
|
y = np.concatenate((np.ones((50, 1)), np.ones((50, 1))))
|
||||||
|
model.evaluate(x, y)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.true_positives), 50.)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.false_negatives), 50.)
|
||||||
|
model.evaluate(x, y)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.true_positives), 50.)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.false_negatives), 50.)
|
||||||
|
finally:
|
||||||
|
backend.set_floatx('float32')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue