Track trackables in graph networks, and removing automatic tracking of Keras-internal attributes.

Also resolves loading bug from .

PiperOrigin-RevId: 266440936
This commit is contained in:
Katherine Wu 2019-08-30 13:09:17 -07:00 committed by TensorFlower Gardener
parent e4d8a30f03
commit 93b86cc5ee
4 changed files with 15 additions and 6 deletions
tensorflow/python

View File

@ -273,8 +273,8 @@ class _StateManagerImpl(StateManager):
"""
self._trainable = trainable
self._layer = layer
if self._layer is not None:
self._layer._maybe_create_attribute('_resources', []) # pylint: disable=protected-access
if self._layer is not None and not hasattr(self._layer, '_resources'):
self._layer._resources = [] # pylint: disable=protected-access
self._cols_to_vars_map = collections.defaultdict(lambda: {})
# TODO(vbardiovsky): Make sure the resources are tracked by moving them to
# the layer (inheriting from AutoTrackable), e.g.:

View File

@ -1119,7 +1119,7 @@ class Layer(module.Module):
elif tensor_util.is_tensor(loss):
eager_losses.append(_tag_unconditional(loss))
self._callable_losses += callable_losses
self._callable_losses.extend(callable_losses)
in_call_context = base_layer_utils.call_context().in_call
if eager_losses and not in_call_context:
@ -1127,7 +1127,7 @@ class Layer(module.Module):
'Expected a symbolic Tensors or a callable for the loss value. '
'Please wrap your loss computation in a zero argument `lambda`.')
self._eager_losses += eager_losses
self._eager_losses.extend(eager_losses)
if in_call_context:
for symbolic_loss in symbolic_losses:
@ -1308,7 +1308,7 @@ class Layer(module.Module):
# they do not need to be tracked later.
if ops.executing_eagerly_outside_functions() and call_context.in_call:
updates = [u for u in updates if callable(u)]
self._updates += updates
self._updates.extend(updates)
def set_weights(self, weights):
"""Sets the weights of the layer, from Numpy arrays.
@ -2194,6 +2194,7 @@ class Layer(module.Module):
object_identity.ObjectIdentityDictionary())
return self._obj_reference_counts_dict
@trackable.no_automatic_dependency_tracking
def _maybe_create_attribute(self, name, default_value):
"""Create the attribute with the default value if it hasn't been created.
@ -2255,7 +2256,6 @@ class Layer(module.Module):
def __setattr__(self, name, value):
if (name == '_self_setattr_tracking' or
not getattr(self, '_self_setattr_tracking', True) or
getattr(self, '_is_graph_network', False) or
# Exclude @property.setters from tracking
hasattr(self.__class__, name)):
try:

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
from tensorflow.python.training.tracking.util import Checkpoint
try:
import yaml # pylint:disable=g-import-not-at-top
@ -1163,6 +1164,13 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
self.assertEqual('a', net2.layers[0].name)
self.assertEqual('b', net2.layers[1].name)
@keras_parameterized.run_with_all_model_types
def test_dependency_tracking(self):
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.trackable = Checkpoint()
self.assertIn('trackable', model._unconditional_dependency_names)
self.assertEqual(model.trackable, model._lookup_dependency('trackable'))
class DeferredModeTest(test.TestCase):

View File

@ -1775,6 +1775,7 @@ class Model(network.Network):
return self.callback_model
return self
@trackable.no_automatic_dependency_tracking
def _make_callback_model(self, grouped_model):
first_replicated_model = self._distribution_strategy.unwrap(
grouped_model)[0]