Add an arg num_labels to class AUC, thus the callers that do not play well with lazy variable creation can avoid lazy variable creation.

PiperOrigin-RevId: 343144219
Change-Id: I0284a91ef366447698f06521b3f38888eaa827eb
This commit is contained in:
A. Unique TensorFlower 2020-11-18 13:27:50 -08:00 committed by TensorFlower Gardener
parent 2e077fc356
commit 79f58c38ed
4 changed files with 14 additions and 5 deletions

View File

@ -1899,7 +1899,10 @@ class AUC(Metric):
case, when multilabel data is passed to AUC, each label-prediction pair
is treated as an individual data point. Should be set to False for
multi-class data.
label_weights: (optional) list, array, or tensor of non-negative weights
num_labels: (Optional) The number of labels, used when `multi_label' is
True. If `num_labels` is not specified, then state variables get created
on the first call to `update_state`.
label_weights: (Optional) list, array, or tensor of non-negative weights
used to compute AUCs for multilabel data. When `multi_label` is True,
the weights are applied to the individual label AUCs when they are
averaged to produce the multi-label AUC. When it's False, they are used
@ -1942,6 +1945,7 @@ class AUC(Metric):
dtype=None,
thresholds=None,
multi_label=False,
num_labels=None,
label_weights=None):
# Validate configurations.
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
@ -2004,8 +2008,13 @@ class AUC(Metric):
self._built = False
if self.multi_label:
self._num_labels = None
if num_labels:
shape = tensor_shape.TensorShape([None, num_labels])
self._build(shape)
else:
if num_labels:
raise ValueError(
'`num_labels` is needed only when `multi_label` is True.')
self._build(None)
@property

View File

@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"

View File

@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"

View File

@ -134,7 +134,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"