Checkpointable: A small utility for exempting objects from __setattr__ tracking
Exposes it as tf.contrib.checkpoint.NoDependency. Objects wrapped in a NoDependency object get unwrapped in __setattr__ and not tracked. Removes the _save_counter dependency from tf.train.Checkpoint (the save counter is still tracked as "save_counter" and always has been, so this is a backwards-compatible dependency removal). PiperOrigin-RevId: 195502562
This commit is contained in:
parent
59f0618ced
commit
dd5ef1b9fc
@ -19,6 +19,7 @@ For creating and managing dependencies:
|
||||
@@CheckpointableObjectGraph
|
||||
@@dot_graph_from_checkpoint
|
||||
@@object_metadata
|
||||
@@NoDependency
|
||||
@@split_dependency
|
||||
"""
|
||||
|
||||
@ -29,6 +30,7 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
|
||||
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
|
||||
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
|
||||
from tensorflow.python.training.checkpointable import NoDependency
|
||||
from tensorflow.python.training.checkpointable_utils import object_metadata
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
@ -318,6 +318,9 @@ class Network(base_layer.Layer):
|
||||
layer, name='layer-%d' % layer_index, overwrite=True)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
no_dependency = isinstance(value, checkpointable.NoDependency)
|
||||
if no_dependency:
|
||||
value = value.value
|
||||
if isinstance(value, (base_layer.Layer, Network)):
|
||||
try:
|
||||
is_graph_network = self._is_graph_network
|
||||
@ -332,7 +335,8 @@ class Network(base_layer.Layer):
|
||||
# In subclassed models, legacy layers (tf.layers) must always use
|
||||
# resource variables.
|
||||
value._use_resource_variables = True
|
||||
if isinstance(value, checkpointable.CheckpointableBase):
|
||||
if (not no_dependency
|
||||
and isinstance(value, checkpointable.CheckpointableBase)):
|
||||
# Layer (and therefore Network/Model) inherit from CheckpointableBase
|
||||
# rather than Checkpointable, which means there is no Checkpointable
|
||||
# __setattr__ override (it would be a performance issue for functional
|
||||
|
@ -28,7 +28,9 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras._impl import keras
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import checkpointable
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
|
||||
try:
|
||||
@ -583,6 +585,22 @@ class ModelSubclassingTest(test.TestCase):
|
||||
loss = model.train_on_batch(x, y)
|
||||
self.assertGreater(loss, 0.1)
|
||||
|
||||
def test_no_dependency(self):
|
||||
class Foo(keras.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self.isdep = keras.layers.Dense(1)
|
||||
self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
|
||||
self.notdep_var = checkpointable.NoDependency(
|
||||
resource_variable_ops.ResourceVariable(1., name='notdep_var'))
|
||||
|
||||
m = Foo()
|
||||
self.assertEqual([m.isdep, m.notdep], m.layers)
|
||||
self.assertEqual(1, len(m._checkpoint_dependencies))
|
||||
self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
|
||||
self.assertEqual('notdep_var:0', m.notdep_var.name)
|
||||
|
||||
|
||||
class CustomCallModel(keras.Model):
|
||||
|
||||
|
@ -659,6 +659,31 @@ class CheckpointableBase(object):
|
||||
return {}
|
||||
|
||||
|
||||
class NoDependency(object):
|
||||
"""Allows attribute assignment to `Checkpointable` objects with no dependency.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
obj = Checkpointable()
|
||||
obj.has_dependency = tf.Variable(0., name="dep")
|
||||
obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
|
||||
assert obj.no_dependency.name == "nodep:0"
|
||||
```
|
||||
|
||||
`obj` in this example has a dependency on the variable "dep", and both
|
||||
attributes contain un-wrapped `Variable` objects.
|
||||
|
||||
`NoDependency` also works with `tf.keras.Model`, but only for checkpoint
|
||||
dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
|
||||
`Layer` to the attribute without a checkpoint dependency, but the `Model` will
|
||||
still track the `Layer` (so it will appear in `Model.layers`, and its
|
||||
variables will appear in `Model.variables`).
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class Checkpointable(CheckpointableBase):
|
||||
"""Manages dependencies on other objects.
|
||||
|
||||
@ -691,8 +716,11 @@ class Checkpointable(CheckpointableBase):
|
||||
"""Support self.foo = checkpointable syntax."""
|
||||
# Perform the attribute assignment, and potentially call other __setattr__
|
||||
# overrides such as that for tf.keras.Model.
|
||||
no_dependency = isinstance(value, NoDependency)
|
||||
if no_dependency:
|
||||
value = value.value
|
||||
super(Checkpointable, self).__setattr__(name, value)
|
||||
if isinstance(value, CheckpointableBase):
|
||||
if not no_dependency and isinstance(value, CheckpointableBase):
|
||||
self._track_checkpointable(
|
||||
value, name=name,
|
||||
# Allow the user to switch the Checkpointable which is tracked by this
|
||||
|
@ -34,6 +34,16 @@ class InterfaceTests(test.TestCase):
|
||||
root.leaf = duplicate_name_dep
|
||||
root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
|
||||
|
||||
def testNoDependency(self):
|
||||
root = checkpointable.Checkpointable()
|
||||
hasdep = checkpointable.Checkpointable()
|
||||
root.hasdep = hasdep
|
||||
nodep = checkpointable.Checkpointable()
|
||||
root.nodep = checkpointable.NoDependency(nodep)
|
||||
self.assertEqual(1, len(root._checkpoint_dependencies))
|
||||
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
|
||||
self.assertIs(root.hasdep, hasdep)
|
||||
self.assertIs(root.nodep, nodep)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1044,8 +1044,11 @@ class Checkpoint(checkpointable_lib.Checkpointable):
|
||||
if self._save_counter is None:
|
||||
# Initialized to 0 and incremented before saving.
|
||||
with ops.device("/cpu:0"):
|
||||
self._save_counter = add_variable(
|
||||
self, name="save_counter", initializer=0, dtype=dtypes.int64)
|
||||
# add_variable creates a dependency named "save_counter"; NoDependency
|
||||
# prevents creating a second dependency named "_save_counter".
|
||||
self._save_counter = checkpointable_lib.NoDependency(
|
||||
add_variable(self, name="save_counter", initializer=0,
|
||||
dtype=dtypes.int64))
|
||||
|
||||
@property
|
||||
def save_counter(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user