Small refactor of thresholds default value and validation steps. The number of thresholds is used instead of the user specified type of the thresholds kwarg to determine the output of the result method:

`thresholds` is a scalar or single element list/tuple -> return scalar
`thresholds` is a multi element list/tuple -> return list

This is functionally equivalent to the previous code except for cases where the user passes in a single element list for the thresholds kwarg. In the previous code, this would cause the result method to return a list whereas now it returns a scalar.

PiperOrigin-RevId: 225079221
This commit is contained in:
A. Unique TensorFlower 2018-12-11 15:12:10 -08:00 committed by TensorFlower Gardener
parent 3b94c63e1b
commit 3dfe44784d

View File

@ -177,6 +177,12 @@ def _assert_thresholds_range(thresholds):
.format(invalid_thresholds))
def _parse_init_thresholds(thresholds, default_threshold=0.5):
thresholds = to_list(default_threshold if thresholds is None else thresholds)
_assert_thresholds_range(thresholds)
return thresholds
def _update_confusion_matrix_variables(variables_to_update,
y_true,
y_pred,
@ -869,12 +875,11 @@ class _ConfusionMatrixConditionCount(Metric):
"""
super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
self._confusion_matrix_cond = confusion_matrix_cond
self.thresholds = 0.5 if thresholds is None else thresholds
thresholds_list = to_list(self.thresholds)
_assert_thresholds_range(thresholds_list)
self.thresholds = _parse_init_thresholds(
thresholds, default_threshold=0.5)
self.accumulator = self.add_weight(
'accumulator',
shape=(len(thresholds_list),),
shape=(len(self.thresholds),),
initializer=init_ops.zeros_initializer)
def update_state(self, y_true, y_pred, sample_weight=None):
@ -895,10 +900,10 @@ class _ConfusionMatrixConditionCount(Metric):
}, y_true, y_pred, self.thresholds, sample_weight)
def result(self):
if isinstance(self.thresholds, (list, tuple)):
result = self.accumulator
else:
if len(self.thresholds) == 1:
result = self.accumulator[0]
else:
result = self.accumulator
return ops.convert_to_tensor(result)
def reset_states(self):
@ -1152,16 +1157,15 @@ class Precision(Metric):
dtype: (Optional) data type of the metric result.
"""
super(Precision, self).__init__(name=name, dtype=dtype)
self.thresholds = 0.5 if thresholds is None else thresholds
thresholds_list = to_list(self.thresholds)
_assert_thresholds_range(thresholds_list)
self.thresholds = _parse_init_thresholds(
thresholds, default_threshold=0.5)
self.tp = self.add_weight(
'true_positives',
shape=(len(thresholds_list),),
shape=(len(self.thresholds),),
initializer=init_ops.zeros_initializer)
self.fp = self.add_weight(
'false_positives',
shape=(len(thresholds_list),),
shape=(len(self.thresholds),),
initializer=init_ops.zeros_initializer)
def update_state(self, y_true, y_pred, sample_weight=None):
@ -1184,7 +1188,7 @@ class Precision(Metric):
def result(self):
result = math_ops.div_no_nan(self.tp, self.tp + self.fp)
return result if isinstance(self.thresholds, (list, tuple)) else result[0]
return result[0] if len(self.thresholds) == 1 else result
def reset_states(self):
num_thresholds = len(to_list(self.thresholds))
@ -1237,16 +1241,15 @@ class Recall(Metric):
dtype: (Optional) data type of the metric result.
"""
super(Recall, self).__init__(name=name, dtype=dtype)
self.thresholds = 0.5 if thresholds is None else thresholds
thresholds_list = to_list(self.thresholds)
_assert_thresholds_range(thresholds_list)
self.thresholds = _parse_init_thresholds(
thresholds, default_threshold=0.5)
self.tp = self.add_weight(
'true_positives',
shape=(len(thresholds_list),),
shape=(len(self.thresholds),),
initializer=init_ops.zeros_initializer)
self.fn = self.add_weight(
'false_negatives',
shape=(len(thresholds_list),),
shape=(len(self.thresholds),),
initializer=init_ops.zeros_initializer)
def update_state(self, y_true, y_pred, sample_weight=None):
@ -1269,7 +1272,7 @@ class Recall(Metric):
def result(self):
result = math_ops.div_no_nan(self.tp, self.tp + self.fn)
return result if isinstance(self.thresholds, (list, tuple)) else result[0]
return result[0] if len(self.thresholds) == 1 else result
def reset_states(self):
num_thresholds = len(to_list(self.thresholds))