diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index eac18d63137..c63682541f6 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -273,8 +273,8 @@ class _StateManagerImpl(StateManager): """ self._trainable = trainable self._layer = layer - if self._layer is not None: - self._layer._maybe_create_attribute('_resources', []) # pylint: disable=protected-access + if self._layer is not None and not hasattr(self._layer, '_resources'): + self._layer._resources = [] # pylint: disable=protected-access self._cols_to_vars_map = collections.defaultdict(lambda: {}) # TODO(vbardiovsky): Make sure the resources are tracked by moving them to # the layer (inheriting from AutoTrackable), e.g.: diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 3d02e85a78e..903e55e78cf 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1119,7 +1119,7 @@ class Layer(module.Module): elif tensor_util.is_tensor(loss): eager_losses.append(_tag_unconditional(loss)) - self._callable_losses += callable_losses + self._callable_losses.extend(callable_losses) in_call_context = base_layer_utils.call_context().in_call if eager_losses and not in_call_context: @@ -1127,7 +1127,7 @@ class Layer(module.Module): 'Expected a symbolic Tensors or a callable for the loss value. ' 'Please wrap your loss computation in a zero argument `lambda`.') - self._eager_losses += eager_losses + self._eager_losses.extend(eager_losses) if in_call_context: for symbolic_loss in symbolic_losses: @@ -1308,7 +1308,7 @@ class Layer(module.Module): # they do not need to be tracked later. if ops.executing_eagerly_outside_functions() and call_context.in_call: updates = [u for u in updates if callable(u)] - self._updates += updates + self._updates.extend(updates) def set_weights(self, weights): """Sets the weights of the layer, from Numpy arrays. @@ -2194,6 +2194,7 @@ class Layer(module.Module): object_identity.ObjectIdentityDictionary()) return self._obj_reference_counts_dict + @trackable.no_automatic_dependency_tracking def _maybe_create_attribute(self, name, default_value): """Create the attribute with the default value if it hasn't been created. @@ -2255,7 +2256,6 @@ class Layer(module.Module): def __setattr__(self, name, value): if (name == '_self_setattr_tracking' or not getattr(self, '_self_setattr_tracking', True) or - getattr(self, '_is_graph_network', False) or # Exclude @property.setters from tracking hasattr(self.__class__, name)): try: diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index 78621c0245d..5726204cd17 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import test +from tensorflow.python.training.tracking.util import Checkpoint try: import yaml # pylint:disable=g-import-not-at-top @@ -1163,6 +1164,13 @@ class NetworkConstructionTest(keras_parameterized.TestCase): self.assertEqual('a', net2.layers[0].name) self.assertEqual('b', net2.layers[1].name) + @keras_parameterized.run_with_all_model_types + def test_dependency_tracking(self): + model = testing_utils.get_small_mlp(1, 4, input_dim=3) + model.trackable = Checkpoint() + self.assertIn('trackable', model._unconditional_dependency_names) + self.assertEqual(model.trackable, model._lookup_dependency('trackable')) + class DeferredModeTest(test.TestCase): diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 3d13c569f89..8810fae0308 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -1775,6 +1775,7 @@ class Model(network.Network): return self.callback_model return self + @trackable.no_automatic_dependency_tracking def _make_callback_model(self, grouped_model): first_replicated_model = self._distribution_strategy.unwrap( grouped_model)[0]