From 79f58c38edae028e7c4d897fdcb026264bf77330 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 18 Nov 2020 13:27:50 -0800 Subject: [PATCH] Add an arg `num_labels` to class AUC, thus the callers that do not play well with lazy variable creation can avoid lazy variable creation. PiperOrigin-RevId: 343144219 Change-Id: I0284a91ef366447698f06521b3f38888eaa827eb --- tensorflow/python/keras/metrics.py | 13 +++++++++++-- .../golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt | 2 +- .../golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt | 2 +- .../api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index d88792d70aa..6a36ac447f3 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -1899,7 +1899,10 @@ class AUC(Metric): case, when multilabel data is passed to AUC, each label-prediction pair is treated as an individual data point. Should be set to False for multi-class data. - label_weights: (optional) list, array, or tensor of non-negative weights + num_labels: (Optional) The number of labels, used when `multi_label' is + True. If `num_labels` is not specified, then state variables get created + on the first call to `update_state`. + label_weights: (Optional) list, array, or tensor of non-negative weights used to compute AUCs for multilabel data. When `multi_label` is True, the weights are applied to the individual label AUCs when they are averaged to produce the multi-label AUC. When it's False, they are used @@ -1942,6 +1945,7 @@ class AUC(Metric): dtype=None, thresholds=None, multi_label=False, + num_labels=None, label_weights=None): # Validate configurations. if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( @@ -2004,8 +2008,13 @@ class AUC(Metric): self._built = False if self.multi_label: - self._num_labels = None + if num_labels: + shape = tensor_shape.TensorShape([None, num_labels]) + self._build(shape) else: + if num_labels: + raise ValueError( + '`num_labels` is needed only when `multi_label` is True.') self._build(None) @property diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt index b0d2b891dbc..294eaaab14d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -134,7 +134,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt index b0d2b891dbc..294eaaab14d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-a-u-c.pbtxt @@ -134,7 +134,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt index b2bc7a0a061..d7bd6b20982 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-a-u-c.pbtxt @@ -134,7 +134,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " } member_method { name: "add_loss"