From ba1654087a4966bd85328f399ed6f288da4b84db Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Fri, 19 Jul 2019 18:03:08 -0700 Subject: [PATCH] Avoid using equality for adding weights PiperOrigin-RevId: 259071753 --- tensorflow/python/keras/engine/base_layer.py | 31 ++++++++++++-------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 4cd6fa74819..5663ff16745 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2163,18 +2163,25 @@ class Layer(module.Module): for val in nest.flatten(value): # TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops # no longer return True for isinstance Variable checks. - if (isinstance(val, tf_variables.Variable) and - not isinstance(val, resource_variable_ops._UnreadVariable)): # pylint: disable=protected-access - # Users may add extra weights/variables - # simply by assigning them to attributes (invalid for graph networks) - self._maybe_create_attribute('_trainable_weights', []) - self._maybe_create_attribute('_non_trainable_weights', []) - if val not in self._trainable_weights + self._non_trainable_weights: - if val.trainable: - self._trainable_weights.append(val) - else: - self._non_trainable_weights.append(val) - backend.track_variable(val) + if not isinstance(val, tf_variables.Variable): + continue + if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access + continue + + # Users may add extra weights/variables + # simply by assigning them to attributes (invalid for graph networks) + self._maybe_create_attribute('_trainable_weights', []) + self._maybe_create_attribute('_non_trainable_weights', []) + if val.trainable: + if any(val is w for w in self._trainable_weights): + continue + self._trainable_weights.append(val) + else: + if any(val is w for w in self._non_trainable_weights): + continue + self._non_trainable_weights.append(val) + + backend.track_variable(val) # Skip the auto trackable from tf.Module to keep status quo. See the comment # at __delattr__.