diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index 4d329b4b316..0bcae051264 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -363,10 +363,8 @@ def _get_grouped_variables(vars_to_warm_start): # out the list. grouped_variables = {} for v in list_of_vars: - if not isinstance(v, list): - var_name = _infer_var_name([v]) - else: - var_name = _infer_var_name(v) + v = [v] if not isinstance(v, list) else v + var_name = _infer_var_name(v) grouped_variables.setdefault(var_name, []).append(v) return grouped_variables