Avoid using equality for adding weights

PiperOrigin-RevId: 259071753
This commit is contained in:
Gaurav Jain 2019-07-19 18:03:08 -07:00 committed by TensorFlower Gardener
parent 229dae116a
commit ba1654087a

View File

@ -2163,18 +2163,25 @@ class Layer(module.Module):
for val in nest.flatten(value):
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
# no longer return True for isinstance Variable checks.
if (isinstance(val, tf_variables.Variable) and
not isinstance(val, resource_variable_ops._UnreadVariable)): # pylint: disable=protected-access
# Users may add extra weights/variables
# simply by assigning them to attributes (invalid for graph networks)
self._maybe_create_attribute('_trainable_weights', [])
self._maybe_create_attribute('_non_trainable_weights', [])
if val not in self._trainable_weights + self._non_trainable_weights:
if val.trainable:
self._trainable_weights.append(val)
else:
self._non_trainable_weights.append(val)
backend.track_variable(val)
if not isinstance(val, tf_variables.Variable):
continue
if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access
continue
# Users may add extra weights/variables
# simply by assigning them to attributes (invalid for graph networks)
self._maybe_create_attribute('_trainable_weights', [])
self._maybe_create_attribute('_non_trainable_weights', [])
if val.trainable:
if any(val is w for w in self._trainable_weights):
continue
self._trainable_weights.append(val)
else:
if any(val is w for w in self._non_trainable_weights):
continue
self._non_trainable_weights.append(val)
backend.track_variable(val)
# Skip the auto trackable from tf.Module to keep status quo. See the comment
# at __delattr__.