Merge pull request #39134 from yongtang:36790-tf.metrics.Recall-backend-float64
PiperOrigin-RevId: 313606803 Change-Id: Idbb4244c62fe797815ea40ae1297c0346c65983e
This commit is contained in:
commit
eded0b5744
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras import combinations
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import layers
|
from tensorflow.python.keras import layers
|
||||||
|
@ -2247,6 +2248,23 @@ class ResetStatesTest(keras_parameterized.TestCase):
|
||||||
self.assertArrayNear(self.evaluate(m_obj.total_cm)[0], [1, 0], 1e-1)
|
self.assertArrayNear(self.evaluate(m_obj.total_cm)[0], [1, 0], 1e-1)
|
||||||
self.assertArrayNear(self.evaluate(m_obj.total_cm)[1], [3, 0], 1e-1)
|
self.assertArrayNear(self.evaluate(m_obj.total_cm)[1], [3, 0], 1e-1)
|
||||||
|
|
||||||
|
def test_reset_states_recall_float64(self):
|
||||||
|
# Test case for GitHub issue 36790.
|
||||||
|
try:
|
||||||
|
backend.set_floatx('float64')
|
||||||
|
r_obj = metrics.Recall()
|
||||||
|
model = _get_model([r_obj])
|
||||||
|
x = np.concatenate((np.ones((50, 4)), np.zeros((50, 4))))
|
||||||
|
y = np.concatenate((np.ones((50, 1)), np.ones((50, 1))))
|
||||||
|
model.evaluate(x, y)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.true_positives), 50.)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.false_negatives), 50.)
|
||||||
|
model.evaluate(x, y)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.true_positives), 50.)
|
||||||
|
self.assertEqual(self.evaluate(r_obj.false_negatives), 50.)
|
||||||
|
finally:
|
||||||
|
backend.set_floatx('float32')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
@ -299,9 +299,19 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
'`multi_label` is True.')
|
'`multi_label` is True.')
|
||||||
if variables_to_update is None:
|
if variables_to_update is None:
|
||||||
return
|
return
|
||||||
y_true = math_ops.cast(y_true, dtype=dtypes.float32)
|
if not any(
|
||||||
y_pred = math_ops.cast(y_pred, dtype=dtypes.float32)
|
key for key in variables_to_update if key in list(ConfusionMatrix)):
|
||||||
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=dtypes.float32)
|
raise ValueError(
|
||||||
|
'Please provide at least one valid confusion matrix '
|
||||||
|
'variable to update. Valid variable key options are: "{}". '
|
||||||
|
'Received: "{}"'.format(
|
||||||
|
list(ConfusionMatrix), variables_to_update.keys()))
|
||||||
|
|
||||||
|
variable_dtype = list(variables_to_update.values())[0].dtype
|
||||||
|
|
||||||
|
y_true = math_ops.cast(y_true, dtype=variable_dtype)
|
||||||
|
y_pred = math_ops.cast(y_pred, dtype=variable_dtype)
|
||||||
|
thresholds = ops.convert_to_tensor_v2(thresholds, dtype=variable_dtype)
|
||||||
num_thresholds = thresholds.shape[0]
|
num_thresholds = thresholds.shape[0]
|
||||||
if multi_label:
|
if multi_label:
|
||||||
one_thresh = math_ops.equal(
|
one_thresh = math_ops.equal(
|
||||||
|
@ -314,14 +324,6 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
sample_weight)
|
sample_weight)
|
||||||
one_thresh = math_ops.cast(True, dtype=dtypes.bool)
|
one_thresh = math_ops.cast(True, dtype=dtypes.bool)
|
||||||
|
|
||||||
if not any(
|
|
||||||
key for key in variables_to_update if key in list(ConfusionMatrix)):
|
|
||||||
raise ValueError(
|
|
||||||
'Please provide at least one valid confusion matrix '
|
|
||||||
'variable to update. Valid variable key options are: "{}". '
|
|
||||||
'Received: "{}"'.format(
|
|
||||||
list(ConfusionMatrix), variables_to_update.keys()))
|
|
||||||
|
|
||||||
invalid_keys = [
|
invalid_keys = [
|
||||||
key for key in variables_to_update if key not in list(ConfusionMatrix)
|
key for key in variables_to_update if key not in list(ConfusionMatrix)
|
||||||
]
|
]
|
||||||
|
@ -401,7 +403,7 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
|
|
||||||
if sample_weight is not None:
|
if sample_weight is not None:
|
||||||
sample_weight = weights_broadcast_ops.broadcast_weights(
|
sample_weight = weights_broadcast_ops.broadcast_weights(
|
||||||
math_ops.cast(sample_weight, dtype=dtypes.float32), y_pred)
|
math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
|
||||||
weights_tiled = array_ops.tile(
|
weights_tiled = array_ops.tile(
|
||||||
array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
|
array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
|
||||||
else:
|
else:
|
||||||
|
@ -422,9 +424,9 @@ def update_confusion_matrix_variables(variables_to_update,
|
||||||
|
|
||||||
def weighted_assign_add(label, pred, weights, var):
|
def weighted_assign_add(label, pred, weights, var):
|
||||||
label_and_pred = math_ops.cast(
|
label_and_pred = math_ops.cast(
|
||||||
math_ops.logical_and(label, pred), dtype=dtypes.float32)
|
math_ops.logical_and(label, pred), dtype=var.dtype)
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
label_and_pred *= weights
|
label_and_pred *= math_ops.cast(weights, dtype=var.dtype)
|
||||||
return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))
|
return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))
|
||||||
|
|
||||||
loop_vars = {
|
loop_vars = {
|
||||||
|
|
Loading…
Reference in New Issue