Avoid using equality for adding weights
PiperOrigin-RevId: 259071753
This commit is contained in:
parent
229dae116a
commit
ba1654087a
@ -2163,17 +2163,24 @@ class Layer(module.Module):
|
|||||||
for val in nest.flatten(value):
|
for val in nest.flatten(value):
|
||||||
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
|
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
|
||||||
# no longer return True for isinstance Variable checks.
|
# no longer return True for isinstance Variable checks.
|
||||||
if (isinstance(val, tf_variables.Variable) and
|
if not isinstance(val, tf_variables.Variable):
|
||||||
not isinstance(val, resource_variable_ops._UnreadVariable)): # pylint: disable=protected-access
|
continue
|
||||||
|
if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access
|
||||||
|
continue
|
||||||
|
|
||||||
# Users may add extra weights/variables
|
# Users may add extra weights/variables
|
||||||
# simply by assigning them to attributes (invalid for graph networks)
|
# simply by assigning them to attributes (invalid for graph networks)
|
||||||
self._maybe_create_attribute('_trainable_weights', [])
|
self._maybe_create_attribute('_trainable_weights', [])
|
||||||
self._maybe_create_attribute('_non_trainable_weights', [])
|
self._maybe_create_attribute('_non_trainable_weights', [])
|
||||||
if val not in self._trainable_weights + self._non_trainable_weights:
|
|
||||||
if val.trainable:
|
if val.trainable:
|
||||||
|
if any(val is w for w in self._trainable_weights):
|
||||||
|
continue
|
||||||
self._trainable_weights.append(val)
|
self._trainable_weights.append(val)
|
||||||
else:
|
else:
|
||||||
|
if any(val is w for w in self._non_trainable_weights):
|
||||||
|
continue
|
||||||
self._non_trainable_weights.append(val)
|
self._non_trainable_weights.append(val)
|
||||||
|
|
||||||
backend.track_variable(val)
|
backend.track_variable(val)
|
||||||
|
|
||||||
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
||||||
|
Loading…
Reference in New Issue
Block a user