Avoid using equality for adding weights
PiperOrigin-RevId: 259071753
This commit is contained in:
parent
229dae116a
commit
ba1654087a
@ -2163,18 +2163,25 @@ class Layer(module.Module):
|
||||
for val in nest.flatten(value):
|
||||
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
|
||||
# no longer return True for isinstance Variable checks.
|
||||
if (isinstance(val, tf_variables.Variable) and
|
||||
not isinstance(val, resource_variable_ops._UnreadVariable)): # pylint: disable=protected-access
|
||||
# Users may add extra weights/variables
|
||||
# simply by assigning them to attributes (invalid for graph networks)
|
||||
self._maybe_create_attribute('_trainable_weights', [])
|
||||
self._maybe_create_attribute('_non_trainable_weights', [])
|
||||
if val not in self._trainable_weights + self._non_trainable_weights:
|
||||
if val.trainable:
|
||||
self._trainable_weights.append(val)
|
||||
else:
|
||||
self._non_trainable_weights.append(val)
|
||||
backend.track_variable(val)
|
||||
if not isinstance(val, tf_variables.Variable):
|
||||
continue
|
||||
if isinstance(val, resource_variable_ops._UnreadVariable): # pylint: disable=protected-access
|
||||
continue
|
||||
|
||||
# Users may add extra weights/variables
|
||||
# simply by assigning them to attributes (invalid for graph networks)
|
||||
self._maybe_create_attribute('_trainable_weights', [])
|
||||
self._maybe_create_attribute('_non_trainable_weights', [])
|
||||
if val.trainable:
|
||||
if any(val is w for w in self._trainable_weights):
|
||||
continue
|
||||
self._trainable_weights.append(val)
|
||||
else:
|
||||
if any(val is w for w in self._non_trainable_weights):
|
||||
continue
|
||||
self._non_trainable_weights.append(val)
|
||||
|
||||
backend.track_variable(val)
|
||||
|
||||
# Skip the auto trackable from tf.Module to keep status quo. See the comment
|
||||
# at __delattr__.
|
||||
|
Loading…
Reference in New Issue
Block a user