Avoid using equality for adding weights

PiperOrigin-RevId: 259071753
This commit is contained in:
Gaurav Jain 2019-07-19 18:03:08 -07:00 committed by TensorFlower Gardener
parent 229dae116a
commit ba1654087a

View File

@ -2163,18 +2163,25 @@ class Layer(module.Module):
for val in nest.flatten(value): for val in nest.flatten(value):
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops # TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
# no longer return True for isinstance Variable checks. # no longer return True for isinstance Variable checks.
if (isinstance(val, tf_variables.Variable) and if not isinstance(val, tf_variables.Variable):
not isinstance(val, resource_variable_ops._UnreadVariable)): # pylint: disable=protected-access continue
# Users may add extra weights/variables if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access
# simply by assigning them to attributes (invalid for graph networks) continue
self._maybe_create_attribute('_trainable_weights', [])
self._maybe_create_attribute('_non_trainable_weights', []) # Users may add extra weights/variables
if val not in self._trainable_weights + self._non_trainable_weights: # simply by assigning them to attributes (invalid for graph networks)
if val.trainable: self._maybe_create_attribute('_trainable_weights', [])
self._trainable_weights.append(val) self._maybe_create_attribute('_non_trainable_weights', [])
else: if val.trainable:
self._non_trainable_weights.append(val) if any(val is w for w in self._trainable_weights):
backend.track_variable(val) continue
self._trainable_weights.append(val)
else:
if any(val is w for w in self._non_trainable_weights):
continue
self._non_trainable_weights.append(val)
backend.track_variable(val)
# Skip the auto trackable from tf.Module to keep status quo. See the comment # Skip the auto trackable from tf.Module to keep status quo. See the comment
# at __delattr__. # at __delattr__.