Checkpointable: A small utility for exempting objects from __setattr__ tracking

Exposes it as tf.contrib.checkpoint.NoDependency. Objects wrapped in a
NoDependency object get unwrapped in __setattr__ and not tracked.

Removes the _save_counter dependency from tf.train.Checkpoint (the save counter
is still tracked as "save_counter" and always has been, so this is a
backwards-compatible dependency removal).

PiperOrigin-RevId: 195502562
This commit is contained in:
Allen Lavoie 2018-05-04 18:25:18 -07:00 committed by TensorFlower Gardener
parent 59f0618ced
commit dd5ef1b9fc
6 changed files with 69 additions and 4 deletions

View File

@ -19,6 +19,7 @@ For creating and managing dependencies:
@@CheckpointableObjectGraph
@@dot_graph_from_checkpoint
@@object_metadata
@@NoDependency
@@split_dependency
"""
@ -29,6 +30,7 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpointable import NoDependency
from tensorflow.python.training.checkpointable_utils import object_metadata
from tensorflow.python.util.all_util import remove_undocumented

View File

@ -318,6 +318,9 @@ class Network(base_layer.Layer):
layer, name='layer-%d' % layer_index, overwrite=True)
def __setattr__(self, name, value):
no_dependency = isinstance(value, checkpointable.NoDependency)
if no_dependency:
value = value.value
if isinstance(value, (base_layer.Layer, Network)):
try:
is_graph_network = self._is_graph_network
@ -332,7 +335,8 @@ class Network(base_layer.Layer):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
value._use_resource_variables = True
if isinstance(value, checkpointable.CheckpointableBase):
if (not no_dependency
and isinstance(value, checkpointable.CheckpointableBase)):
# Layer (and therefore Network/Model) inherit from CheckpointableBase
# rather than Checkpointable, which means there is no Checkpointable
# __setattr__ override (it would be a performance issue for functional

View File

@ -28,7 +28,9 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training import checkpointable
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@ -583,6 +585,22 @@ class ModelSubclassingTest(test.TestCase):
loss = model.train_on_batch(x, y)
self.assertGreater(loss, 0.1)
def test_no_dependency(self):
class Foo(keras.Model):
def __init__(self):
super(Foo, self).__init__()
self.isdep = keras.layers.Dense(1)
self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
self.notdep_var = checkpointable.NoDependency(
resource_variable_ops.ResourceVariable(1., name='notdep_var'))
m = Foo()
self.assertEqual([m.isdep, m.notdep], m.layers)
self.assertEqual(1, len(m._checkpoint_dependencies))
self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
self.assertEqual('notdep_var:0', m.notdep_var.name)
class CustomCallModel(keras.Model):

View File

@ -659,6 +659,31 @@ class CheckpointableBase(object):
return {}
class NoDependency(object):
"""Allows attribute assignment to `Checkpointable` objects with no dependency.
Example usage:
```python
obj = Checkpointable()
obj.has_dependency = tf.Variable(0., name="dep")
obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
assert obj.no_dependency.name == "nodep:0"
```
`obj` in this example has a dependency on the variable "dep", and both
attributes contain un-wrapped `Variable` objects.
`NoDependency` also works with `tf.keras.Model`, but only for checkpoint
dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
`Layer` to the attribute without a checkpoint dependency, but the `Model` will
still track the `Layer` (so it will appear in `Model.layers`, and its
variables will appear in `Model.variables`).
"""
def __init__(self, value):
self.value = value
class Checkpointable(CheckpointableBase):
"""Manages dependencies on other objects.
@ -691,8 +716,11 @@ class Checkpointable(CheckpointableBase):
"""Support self.foo = checkpointable syntax."""
# Perform the attribute assignment, and potentially call other __setattr__
# overrides such as that for tf.keras.Model.
no_dependency = isinstance(value, NoDependency)
if no_dependency:
value = value.value
super(Checkpointable, self).__setattr__(name, value)
if isinstance(value, CheckpointableBase):
if not no_dependency and isinstance(value, CheckpointableBase):
self._track_checkpointable(
value, name=name,
# Allow the user to switch the Checkpointable which is tracked by this

View File

@ -34,6 +34,16 @@ class InterfaceTests(test.TestCase):
root.leaf = duplicate_name_dep
root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
def testNoDependency(self):
root = checkpointable.Checkpointable()
hasdep = checkpointable.Checkpointable()
root.hasdep = hasdep
nodep = checkpointable.Checkpointable()
root.nodep = checkpointable.NoDependency(nodep)
self.assertEqual(1, len(root._checkpoint_dependencies))
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
self.assertIs(root.hasdep, hasdep)
self.assertIs(root.nodep, nodep)
if __name__ == "__main__":
test.main()

View File

@ -1044,8 +1044,11 @@ class Checkpoint(checkpointable_lib.Checkpointable):
if self._save_counter is None:
# Initialized to 0 and incremented before saving.
with ops.device("/cpu:0"):
self._save_counter = add_variable(
self, name="save_counter", initializer=0, dtype=dtypes.int64)
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
self._save_counter = checkpointable_lib.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64))
@property
def save_counter(self):