Track trackables in graph networks, and removing automatic tracking of Keras-internal attributes.
Also resolves loading bug from #31893. PiperOrigin-RevId: 266440936
This commit is contained in:
parent
e4d8a30f03
commit
93b86cc5ee
tensorflow/python
@ -273,8 +273,8 @@ class _StateManagerImpl(StateManager):
|
||||
"""
|
||||
self._trainable = trainable
|
||||
self._layer = layer
|
||||
if self._layer is not None:
|
||||
self._layer._maybe_create_attribute('_resources', []) # pylint: disable=protected-access
|
||||
if self._layer is not None and not hasattr(self._layer, '_resources'):
|
||||
self._layer._resources = [] # pylint: disable=protected-access
|
||||
self._cols_to_vars_map = collections.defaultdict(lambda: {})
|
||||
# TODO(vbardiovsky): Make sure the resources are tracked by moving them to
|
||||
# the layer (inheriting from AutoTrackable), e.g.:
|
||||
|
@ -1119,7 +1119,7 @@ class Layer(module.Module):
|
||||
elif tensor_util.is_tensor(loss):
|
||||
eager_losses.append(_tag_unconditional(loss))
|
||||
|
||||
self._callable_losses += callable_losses
|
||||
self._callable_losses.extend(callable_losses)
|
||||
|
||||
in_call_context = base_layer_utils.call_context().in_call
|
||||
if eager_losses and not in_call_context:
|
||||
@ -1127,7 +1127,7 @@ class Layer(module.Module):
|
||||
'Expected a symbolic Tensors or a callable for the loss value. '
|
||||
'Please wrap your loss computation in a zero argument `lambda`.')
|
||||
|
||||
self._eager_losses += eager_losses
|
||||
self._eager_losses.extend(eager_losses)
|
||||
|
||||
if in_call_context:
|
||||
for symbolic_loss in symbolic_losses:
|
||||
@ -1308,7 +1308,7 @@ class Layer(module.Module):
|
||||
# they do not need to be tracked later.
|
||||
if ops.executing_eagerly_outside_functions() and call_context.in_call:
|
||||
updates = [u for u in updates if callable(u)]
|
||||
self._updates += updates
|
||||
self._updates.extend(updates)
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Sets the weights of the layer, from Numpy arrays.
|
||||
@ -2194,6 +2194,7 @@ class Layer(module.Module):
|
||||
object_identity.ObjectIdentityDictionary())
|
||||
return self._obj_reference_counts_dict
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def _maybe_create_attribute(self, name, default_value):
|
||||
"""Create the attribute with the default value if it hasn't been created.
|
||||
|
||||
@ -2255,7 +2256,6 @@ class Layer(module.Module):
|
||||
def __setattr__(self, name, value):
|
||||
if (name == '_self_setattr_tracking' or
|
||||
not getattr(self, '_self_setattr_tracking', True) or
|
||||
getattr(self, '_is_graph_network', False) or
|
||||
# Exclude @property.setters from tracking
|
||||
hasattr(self.__class__, name)):
|
||||
try:
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking.util import Checkpoint
|
||||
|
||||
try:
|
||||
import yaml # pylint:disable=g-import-not-at-top
|
||||
@ -1163,6 +1164,13 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
self.assertEqual('a', net2.layers[0].name)
|
||||
self.assertEqual('b', net2.layers[1].name)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_dependency_tracking(self):
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.trackable = Checkpoint()
|
||||
self.assertIn('trackable', model._unconditional_dependency_names)
|
||||
self.assertEqual(model.trackable, model._lookup_dependency('trackable'))
|
||||
|
||||
|
||||
class DeferredModeTest(test.TestCase):
|
||||
|
||||
|
@ -1775,6 +1775,7 @@ class Model(network.Network):
|
||||
return self.callback_model
|
||||
return self
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def _make_callback_model(self, grouped_model):
|
||||
first_replicated_model = self._distribution_strategy.unwrap(
|
||||
grouped_model)[0]
|
||||
|
Loading…
Reference in New Issue
Block a user