diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 1d1f3b45864..c8ccb7f6242 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -177,6 +177,12 @@ def _assert_thresholds_range(thresholds): .format(invalid_thresholds)) +def _parse_init_thresholds(thresholds, default_threshold=0.5): + thresholds = to_list(default_threshold if thresholds is None else thresholds) + _assert_thresholds_range(thresholds) + return thresholds + + def _update_confusion_matrix_variables(variables_to_update, y_true, y_pred, @@ -869,12 +875,11 @@ class _ConfusionMatrixConditionCount(Metric): """ super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype) self._confusion_matrix_cond = confusion_matrix_cond - self.thresholds = 0.5 if thresholds is None else thresholds - thresholds_list = to_list(self.thresholds) - _assert_thresholds_range(thresholds_list) + self.thresholds = _parse_init_thresholds( + thresholds, default_threshold=0.5) self.accumulator = self.add_weight( 'accumulator', - shape=(len(thresholds_list),), + shape=(len(self.thresholds),), initializer=init_ops.zeros_initializer) def update_state(self, y_true, y_pred, sample_weight=None): @@ -895,10 +900,10 @@ class _ConfusionMatrixConditionCount(Metric): }, y_true, y_pred, self.thresholds, sample_weight) def result(self): - if isinstance(self.thresholds, (list, tuple)): - result = self.accumulator - else: + if len(self.thresholds) == 1: result = self.accumulator[0] + else: + result = self.accumulator return ops.convert_to_tensor(result) def reset_states(self): @@ -1152,16 +1157,15 @@ class Precision(Metric): dtype: (Optional) data type of the metric result. """ super(Precision, self).__init__(name=name, dtype=dtype) - self.thresholds = 0.5 if thresholds is None else thresholds - thresholds_list = to_list(self.thresholds) - _assert_thresholds_range(thresholds_list) + self.thresholds = _parse_init_thresholds( + thresholds, default_threshold=0.5) self.tp = self.add_weight( 'true_positives', - shape=(len(thresholds_list),), + shape=(len(self.thresholds),), initializer=init_ops.zeros_initializer) self.fp = self.add_weight( 'false_positives', - shape=(len(thresholds_list),), + shape=(len(self.thresholds),), initializer=init_ops.zeros_initializer) def update_state(self, y_true, y_pred, sample_weight=None): @@ -1184,7 +1188,7 @@ class Precision(Metric): def result(self): result = math_ops.div_no_nan(self.tp, self.tp + self.fp) - return result if isinstance(self.thresholds, (list, tuple)) else result[0] + return result[0] if len(self.thresholds) == 1 else result def reset_states(self): num_thresholds = len(to_list(self.thresholds)) @@ -1237,16 +1241,15 @@ class Recall(Metric): dtype: (Optional) data type of the metric result. """ super(Recall, self).__init__(name=name, dtype=dtype) - self.thresholds = 0.5 if thresholds is None else thresholds - thresholds_list = to_list(self.thresholds) - _assert_thresholds_range(thresholds_list) + self.thresholds = _parse_init_thresholds( + thresholds, default_threshold=0.5) self.tp = self.add_weight( 'true_positives', - shape=(len(thresholds_list),), + shape=(len(self.thresholds),), initializer=init_ops.zeros_initializer) self.fn = self.add_weight( 'false_negatives', - shape=(len(thresholds_list),), + shape=(len(self.thresholds),), initializer=init_ops.zeros_initializer) def update_state(self, y_true, y_pred, sample_weight=None): @@ -1269,7 +1272,7 @@ class Recall(Metric): def result(self): result = math_ops.div_no_nan(self.tp, self.tp + self.fn) - return result if isinstance(self.thresholds, (list, tuple)) else result[0] + return result[0] if len(self.thresholds) == 1 else result def reset_states(self): num_thresholds = len(to_list(self.thresholds))