diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 9a2023a26fe..64be9eafeb1 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -537,16 +537,17 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type): Raises: ValueError: In case of invalid user-provided argument. """ - if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test + if x_weight is None or (isinstance(x_weight, (list, tuple)) and + len(x_weight) == 0): # pylint: disable=g-explicit-length-test return [None for _ in output_names] if len(output_names) == 1: - if isinstance(x_weight, list) and len(x_weight) == 1: + if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1: return x_weight if isinstance(x_weight, dict) and output_names[0] in x_weight: return [x_weight[output_names[0]]] else: return [x_weight] - if isinstance(x_weight, list): + if isinstance(x_weight, (list, tuple)): if len(x_weight) != len(output_names): raise ValueError('Provided `' + weight_type + '` was a list of ' + str(len(x_weight)) + ' elements, but the model has ' +