Allow sample weights to be passed in as a Tuple.
PiperOrigin-RevId: 251353592
This commit is contained in:
parent
9e27dbff00
commit
2720abdb86
@ -537,16 +537,17 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
|
||||
Raises:
|
||||
ValueError: In case of invalid user-provided argument.
|
||||
"""
|
||||
if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
|
||||
if x_weight is None or (isinstance(x_weight, (list, tuple)) and
|
||||
len(x_weight) == 0): # pylint: disable=g-explicit-length-test
|
||||
return [None for _ in output_names]
|
||||
if len(output_names) == 1:
|
||||
if isinstance(x_weight, list) and len(x_weight) == 1:
|
||||
if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
|
||||
return x_weight
|
||||
if isinstance(x_weight, dict) and output_names[0] in x_weight:
|
||||
return [x_weight[output_names[0]]]
|
||||
else:
|
||||
return [x_weight]
|
||||
if isinstance(x_weight, list):
|
||||
if isinstance(x_weight, (list, tuple)):
|
||||
if len(x_weight) != len(output_names):
|
||||
raise ValueError('Provided `' + weight_type + '` was a list of ' +
|
||||
str(len(x_weight)) + ' elements, but the model has ' +
|
||||
|
Loading…
Reference in New Issue
Block a user