Small refactor of thresholds
default value and validation steps. The number of thresholds is used instead of the user specified type of the thresholds
kwarg to determine the output of the result method:
`thresholds` is a scalar or single element list/tuple -> return scalar `thresholds` is a multi element list/tuple -> return list This is functionally equivalent to the previous code except for cases where the user passes in a single element list for the thresholds kwarg. In the previous code, this would cause the result method to return a list whereas now it returns a scalar. PiperOrigin-RevId: 225079221
This commit is contained in:
parent
3b94c63e1b
commit
3dfe44784d
@ -177,6 +177,12 @@ def _assert_thresholds_range(thresholds):
|
||||
.format(invalid_thresholds))
|
||||
|
||||
|
||||
def _parse_init_thresholds(thresholds, default_threshold=0.5):
|
||||
thresholds = to_list(default_threshold if thresholds is None else thresholds)
|
||||
_assert_thresholds_range(thresholds)
|
||||
return thresholds
|
||||
|
||||
|
||||
def _update_confusion_matrix_variables(variables_to_update,
|
||||
y_true,
|
||||
y_pred,
|
||||
@ -869,12 +875,11 @@ class _ConfusionMatrixConditionCount(Metric):
|
||||
"""
|
||||
super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
|
||||
self._confusion_matrix_cond = confusion_matrix_cond
|
||||
self.thresholds = 0.5 if thresholds is None else thresholds
|
||||
thresholds_list = to_list(self.thresholds)
|
||||
_assert_thresholds_range(thresholds_list)
|
||||
self.thresholds = _parse_init_thresholds(
|
||||
thresholds, default_threshold=0.5)
|
||||
self.accumulator = self.add_weight(
|
||||
'accumulator',
|
||||
shape=(len(thresholds_list),),
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=init_ops.zeros_initializer)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
@ -895,10 +900,10 @@ class _ConfusionMatrixConditionCount(Metric):
|
||||
}, y_true, y_pred, self.thresholds, sample_weight)
|
||||
|
||||
def result(self):
|
||||
if isinstance(self.thresholds, (list, tuple)):
|
||||
result = self.accumulator
|
||||
else:
|
||||
if len(self.thresholds) == 1:
|
||||
result = self.accumulator[0]
|
||||
else:
|
||||
result = self.accumulator
|
||||
return ops.convert_to_tensor(result)
|
||||
|
||||
def reset_states(self):
|
||||
@ -1152,16 +1157,15 @@ class Precision(Metric):
|
||||
dtype: (Optional) data type of the metric result.
|
||||
"""
|
||||
super(Precision, self).__init__(name=name, dtype=dtype)
|
||||
self.thresholds = 0.5 if thresholds is None else thresholds
|
||||
thresholds_list = to_list(self.thresholds)
|
||||
_assert_thresholds_range(thresholds_list)
|
||||
self.thresholds = _parse_init_thresholds(
|
||||
thresholds, default_threshold=0.5)
|
||||
self.tp = self.add_weight(
|
||||
'true_positives',
|
||||
shape=(len(thresholds_list),),
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=init_ops.zeros_initializer)
|
||||
self.fp = self.add_weight(
|
||||
'false_positives',
|
||||
shape=(len(thresholds_list),),
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=init_ops.zeros_initializer)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
@ -1184,7 +1188,7 @@ class Precision(Metric):
|
||||
|
||||
def result(self):
|
||||
result = math_ops.div_no_nan(self.tp, self.tp + self.fp)
|
||||
return result if isinstance(self.thresholds, (list, tuple)) else result[0]
|
||||
return result[0] if len(self.thresholds) == 1 else result
|
||||
|
||||
def reset_states(self):
|
||||
num_thresholds = len(to_list(self.thresholds))
|
||||
@ -1237,16 +1241,15 @@ class Recall(Metric):
|
||||
dtype: (Optional) data type of the metric result.
|
||||
"""
|
||||
super(Recall, self).__init__(name=name, dtype=dtype)
|
||||
self.thresholds = 0.5 if thresholds is None else thresholds
|
||||
thresholds_list = to_list(self.thresholds)
|
||||
_assert_thresholds_range(thresholds_list)
|
||||
self.thresholds = _parse_init_thresholds(
|
||||
thresholds, default_threshold=0.5)
|
||||
self.tp = self.add_weight(
|
||||
'true_positives',
|
||||
shape=(len(thresholds_list),),
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=init_ops.zeros_initializer)
|
||||
self.fn = self.add_weight(
|
||||
'false_negatives',
|
||||
shape=(len(thresholds_list),),
|
||||
shape=(len(self.thresholds),),
|
||||
initializer=init_ops.zeros_initializer)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
@ -1269,7 +1272,7 @@ class Recall(Metric):
|
||||
|
||||
def result(self):
|
||||
result = math_ops.div_no_nan(self.tp, self.tp + self.fn)
|
||||
return result if isinstance(self.thresholds, (list, tuple)) else result[0]
|
||||
return result[0] if len(self.thresholds) == 1 else result
|
||||
|
||||
def reset_states(self):
|
||||
num_thresholds = len(to_list(self.thresholds))
|
||||
|
Loading…
Reference in New Issue
Block a user