Allow sample weights to be passed in as a Tuple.

PiperOrigin-RevId: 251353592
This commit is contained in:
Katherine Wu 2019-06-03 18:46:19 -07:00 committed by TensorFlower Gardener
parent 9e27dbff00
commit 2720abdb86

View File

@ -537,16 +537,17 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
Raises:
ValueError: In case of invalid user-provided argument.
"""
if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
if x_weight is None or (isinstance(x_weight, (list, tuple)) and
len(x_weight) == 0): # pylint: disable=g-explicit-length-test
return [None for _ in output_names]
if len(output_names) == 1:
if isinstance(x_weight, list) and len(x_weight) == 1:
if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
return x_weight
if isinstance(x_weight, dict) and output_names[0] in x_weight:
return [x_weight[output_names[0]]]
else:
return [x_weight]
if isinstance(x_weight, list):
if isinstance(x_weight, (list, tuple)):
if len(x_weight) != len(output_names):
raise ValueError('Provided `' + weight_type + '` was a list of ' +
str(len(x_weight)) + ' elements, but the model has ' +