Add an arg num_labels
to class AUC, thus the callers that do not play well with lazy variable creation can avoid lazy variable creation.
PiperOrigin-RevId: 343144219 Change-Id: I0284a91ef366447698f06521b3f38888eaa827eb
This commit is contained in:
parent
2e077fc356
commit
79f58c38ed
tensorflow
python/keras
tools/api/golden
@ -1899,7 +1899,10 @@ class AUC(Metric):
|
||||
case, when multilabel data is passed to AUC, each label-prediction pair
|
||||
is treated as an individual data point. Should be set to False for
|
||||
multi-class data.
|
||||
label_weights: (optional) list, array, or tensor of non-negative weights
|
||||
num_labels: (Optional) The number of labels, used when `multi_label' is
|
||||
True. If `num_labels` is not specified, then state variables get created
|
||||
on the first call to `update_state`.
|
||||
label_weights: (Optional) list, array, or tensor of non-negative weights
|
||||
used to compute AUCs for multilabel data. When `multi_label` is True,
|
||||
the weights are applied to the individual label AUCs when they are
|
||||
averaged to produce the multi-label AUC. When it's False, they are used
|
||||
@ -1942,6 +1945,7 @@ class AUC(Metric):
|
||||
dtype=None,
|
||||
thresholds=None,
|
||||
multi_label=False,
|
||||
num_labels=None,
|
||||
label_weights=None):
|
||||
# Validate configurations.
|
||||
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
|
||||
@ -2004,8 +2008,13 @@ class AUC(Metric):
|
||||
|
||||
self._built = False
|
||||
if self.multi_label:
|
||||
self._num_labels = None
|
||||
if num_labels:
|
||||
shape = tensor_shape.TensorShape([None, num_labels])
|
||||
self._build(shape)
|
||||
else:
|
||||
if num_labels:
|
||||
raise ValueError(
|
||||
'`num_labels` is needed only when `multi_label` is True.')
|
||||
self._build(None)
|
||||
|
||||
@property
|
||||
|
@ -134,7 +134,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -134,7 +134,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -134,7 +134,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'num_thresholds\', \'curve\', \'summation_method\', \'name\', \'dtype\', \'thresholds\', \'multi_label\', \'num_labels\', \'label_weights\'], varargs=None, keywords=None, defaults=[\'200\', \'ROC\', \'interpolation\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
Loading…
Reference in New Issue
Block a user