Factor a "capture_dependencies" scope out of Template.
I don't intend for this to get used much directly, but it's handy for Template-like frameworks (e.g. Sonnet), to let them re-enter the dependency-capturing part of Templates. PiperOrigin-RevId: 200595624
This commit is contained in:
parent
91ec6cc494
commit
7ccf1937b8
@ -20,6 +20,7 @@ Visualization and inspection:
|
||||
@@object_metadata
|
||||
|
||||
Managing dependencies:
|
||||
@@capture_dependencies
|
||||
@@Checkpointable
|
||||
@@CheckpointableObjectGraph
|
||||
@@NoDependency
|
||||
@ -43,9 +44,11 @@ from tensorflow.python.training.checkpointable.base import Checkpointable
|
||||
from tensorflow.python.training.checkpointable.base import NoDependency
|
||||
from tensorflow.python.training.checkpointable.data_structures import List
|
||||
from tensorflow.python.training.checkpointable.data_structures import Mapping
|
||||
from tensorflow.python.training.checkpointable.util import capture_dependencies
|
||||
from tensorflow.python.training.checkpointable.util import list_objects
|
||||
from tensorflow.python.training.checkpointable.util import object_metadata
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
remove_undocumented(module_name=__name__)
|
||||
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||
from tensorflow.python.training.checkpointable import util as checkpointable_util
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
@ -295,66 +296,6 @@ class Template(checkpointable.CheckpointableBase):
|
||||
# which is not the same as whether the scope has been created.
|
||||
self._variables_created = False
|
||||
|
||||
def _checkpointable_custom_creator(self, next_creator, name, initial_value,
|
||||
checkpointable_parent=None, **kwargs):
|
||||
"""A variable creation hook which adds Checkpointable dependencies.
|
||||
|
||||
Set during the `Template`'s first wrapped function execution. Ensures that
|
||||
(a) `Template` objects depend on `Template`s created inside them which
|
||||
create variables, and (b) that any variables not in a more deeply nested
|
||||
`Template` are added as dependencies directly.
|
||||
|
||||
The `checkpointable_parent` argument is passed between `Template` custom
|
||||
creators but ignored when the variable object itself is created. This
|
||||
argument indicates (if not `None`) that a more deeply nested `Template` has
|
||||
already added the variable as a dependency, and that parent `Template`s
|
||||
should add a dependency on that `Template` rather than on the variable
|
||||
directly.
|
||||
|
||||
Args:
|
||||
next_creator: See `variable_scope.variable_creator_scope`; the next
|
||||
creator in the chain.
|
||||
name: The (full, scope-influenced) name of the variable. The scope name
|
||||
for the Template itself is stripped for the purposes of object-based
|
||||
dependency tracking, but scopes within Templates are respected.
|
||||
initial_value: See `variable_scope.variable_creator_scope`. Taken
|
||||
explicitly so the argument can be re-named and used with
|
||||
`Checkpointable._add_variable_with_custom_getter`.
|
||||
checkpointable_parent: If not None, a more deeply nested Template object
|
||||
to add a dependency on (rather than depending on the variable directly).
|
||||
**kwargs: Passed through to the next creator.
|
||||
Returns:
|
||||
The output of `next_creator`: the fetched/created variable object.
|
||||
"""
|
||||
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
|
||||
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
|
||||
# we don't want to propagate.
|
||||
return next_creator(
|
||||
initial_value=initializer,
|
||||
name=name,
|
||||
**inner_kwargs)
|
||||
if name.startswith(self._variable_scope.name):
|
||||
scope_stripped_name = name[len(self._variable_scope.name) + 1:]
|
||||
if not checkpointable_parent:
|
||||
return self._add_variable_with_custom_getter(
|
||||
initializer=initial_value,
|
||||
name=scope_stripped_name,
|
||||
getter=_call_next_creator_renaming_initializer,
|
||||
# Disable error checking for Checkpointable. Exceptions are instead
|
||||
# raised if necessary when the object-based saver tries to
|
||||
# save/restore the object.
|
||||
overwrite=True,
|
||||
checkpointable_parent=self,
|
||||
**kwargs)
|
||||
else:
|
||||
self._track_checkpointable(
|
||||
checkpointable_parent,
|
||||
name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access
|
||||
len(self._variable_scope.name) + 1:],
|
||||
overwrite=True)
|
||||
return next_creator(name=name, initial_value=initial_value,
|
||||
checkpointable_parent=self, **kwargs)
|
||||
|
||||
def _call_func(self, args, kwargs):
|
||||
try:
|
||||
vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||
@ -365,8 +306,7 @@ class Template(checkpointable.CheckpointableBase):
|
||||
else:
|
||||
# The first time we run, restore variables if necessary (via
|
||||
# Checkpointable).
|
||||
with variable_scope.variable_creator_scope(
|
||||
self._checkpointable_custom_creator):
|
||||
with checkpointable_util.capture_dependencies(template=self):
|
||||
result = self._func(*args, **kwargs)
|
||||
|
||||
if self._variables_created:
|
||||
@ -634,8 +574,7 @@ class EagerTemplate(Template):
|
||||
else:
|
||||
# The first time we run, restore variables if necessary (via
|
||||
# Checkpointable).
|
||||
with variable_scope.variable_creator_scope(
|
||||
self._checkpointable_custom_creator):
|
||||
with checkpointable_util.capture_dependencies(template=self):
|
||||
result = self._func(*args, **kwargs)
|
||||
|
||||
if self._variables_created:
|
||||
|
@ -41,6 +41,7 @@ from tensorflow.python.training import saveable_object as saveable_object_lib
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training.checkpointable import base as checkpointable_lib
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -564,6 +565,93 @@ def gather_initializers(root_checkpointable):
|
||||
if hasattr(c, "initializer") and c.initializer is not None]
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def capture_dependencies(template):
|
||||
"""Capture variables created within this scope as `Template` dependencies.
|
||||
|
||||
Requires that `template.variable_scope` is active.
|
||||
|
||||
This scope is intended as a compatibility measure, allowing a checkpointable
|
||||
object to add dependencies on variables created in a block of code which is
|
||||
not aware of object-based saving (and instead uses variable names
|
||||
heavily). This is how `Template` objects add dependencies on variables and
|
||||
sub-`Template`s. Where possible, use `tf.make_template` directly.
|
||||
|
||||
Args:
|
||||
template: The `Template` object to register dependencies with.
|
||||
|
||||
Yields:
|
||||
None (when used as a context manager).
|
||||
"""
|
||||
name_prefix = template.variable_scope.name
|
||||
|
||||
def _checkpointable_custom_creator(next_creator, name, initial_value,
|
||||
checkpointable_parent=None, **kwargs):
|
||||
"""A variable creation hook which adds Checkpointable dependencies.
|
||||
|
||||
Set for example during a `Template`'s first wrapped function
|
||||
execution. Ensures that (a) `template` depends on any checkpointable
|
||||
objects using their own `capture_dependencies` scope inside this scope which
|
||||
create variables, and (b) that any variables not in a more deeply nested
|
||||
scope are added as dependencies directly.
|
||||
|
||||
The `checkpointable_parent` argument is passed between custom creators but
|
||||
ignored when the variable object itself is created. This argument indicates
|
||||
(if not `None`) that a more deeply nested scope has already added the
|
||||
variable as a dependency, and that parent scopes should add a dependency on
|
||||
that object rather than on the variable directly.
|
||||
|
||||
Args:
|
||||
next_creator: See `variable_scope.variable_creator_scope`; the next
|
||||
creator in the chain.
|
||||
name: The (full, scope-influenced) name of the variable. The `name_prefix`
|
||||
itself is stripped for the purposes of object-based dependency tracking,
|
||||
but scopes opened within this scope are respected.
|
||||
initial_value: See `variable_scope.variable_creator_scope`. Taken
|
||||
explicitly so the argument can be re-named and used with
|
||||
`Checkpointable._add_variable_with_custom_getter`.
|
||||
checkpointable_parent: If not None, a more deeply nested checkpointable
|
||||
object and its name prefix which were passed to `capture_dependencies`
|
||||
to add a dependency on (rather than depending on the variable directly).
|
||||
**kwargs: Passed through to the next creator.
|
||||
|
||||
Returns:
|
||||
The output of `next_creator`: the fetched/created variable object.
|
||||
"""
|
||||
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
|
||||
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
|
||||
# we don't want to propagate.
|
||||
return next_creator(
|
||||
initial_value=initializer,
|
||||
name=name,
|
||||
**inner_kwargs)
|
||||
if name.startswith(name_prefix):
|
||||
scope_stripped_name = name[len(name_prefix) + 1:]
|
||||
if not checkpointable_parent:
|
||||
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
|
||||
initializer=initial_value,
|
||||
name=scope_stripped_name,
|
||||
getter=_call_next_creator_renaming_initializer,
|
||||
# Disable error checking for Checkpointable. Exceptions are instead
|
||||
# raised if necessary when the object-based saver tries to
|
||||
# save/restore the object.
|
||||
overwrite=True,
|
||||
checkpointable_parent=(template, name_prefix),
|
||||
**kwargs)
|
||||
else:
|
||||
parent_object, parent_name_prefix = checkpointable_parent
|
||||
template._track_checkpointable( # pylint: disable=protected-access
|
||||
parent_object,
|
||||
name=parent_name_prefix[len(name_prefix) + 1:],
|
||||
overwrite=True)
|
||||
return next_creator(
|
||||
name=name, initial_value=initial_value,
|
||||
checkpointable_parent=(template, name_prefix), **kwargs)
|
||||
|
||||
with variable_scope.variable_creator_scope(_checkpointable_custom_creator):
|
||||
yield
|
||||
|
||||
|
||||
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
||||
|
||||
def __init__(self, tensor, name):
|
||||
|
@ -1243,6 +1243,18 @@ class CheckpointingTests(test.TestCase):
|
||||
self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
|
||||
|
||||
|
||||
class _ManualScope(checkpointable.Checkpointable):
|
||||
|
||||
def __call__(self):
|
||||
with variable_scope.variable_scope("ManualScope") as vs:
|
||||
self.variable_scope = vs
|
||||
with checkpointable_utils.capture_dependencies(template=self):
|
||||
return self._build()
|
||||
|
||||
def _build(self):
|
||||
return variable_scope.get_variable(name="in_manual_scope", shape=[])
|
||||
|
||||
|
||||
class TemplateTests(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@ -1255,14 +1267,23 @@ class TemplateTests(test.TestCase):
|
||||
v2 = variable_scope.get_variable(
|
||||
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
|
||||
use_resource=True)
|
||||
return v, v + 1., v2
|
||||
manual = _ManualScope()
|
||||
return v, v + 1., v2, manual, manual()
|
||||
|
||||
save_template = template.make_template("s1", _templated)
|
||||
v1_save, _, v2_save = save_template()
|
||||
v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
[v1_save, v2_save, manual_scope, manual_scope_v, save_template],
|
||||
checkpointable_utils.list_objects(save_template))
|
||||
manual_dep, = manual_scope._checkpoint_dependencies
|
||||
self.assertEqual("in_manual_scope", manual_dep.name)
|
||||
self.assertIs(manual_scope_v, manual_dep.ref)
|
||||
optimizer = adam.AdamOptimizer(0.0)
|
||||
save_root = checkpointable_utils.Checkpoint(
|
||||
my_template=save_template, optimizer=optimizer)
|
||||
optimizer.minimize(v1_save.read_value)
|
||||
self.evaluate([v.initializer for v in save_template.variables])
|
||||
self.evaluate([v.initializer for v in optimizer.variables()])
|
||||
self.evaluate(v1_save.assign([12.]))
|
||||
self.evaluate(v2_save.assign([14.]))
|
||||
@ -1275,11 +1296,13 @@ class TemplateTests(test.TestCase):
|
||||
load_root = checkpointable_utils.Checkpoint(
|
||||
my_template=load_template, optimizer=load_optimizer)
|
||||
status = load_root.restore(save_path)
|
||||
var, var_plus_one, var2 = load_template()
|
||||
var, var_plus_one, var2, _, _ = load_template()
|
||||
load_optimizer.minimize(var.read_value)
|
||||
self.assertEqual(2, len(load_template._checkpoint_dependencies))
|
||||
self.assertEqual(3, len(load_template._checkpoint_dependencies))
|
||||
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
|
||||
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
|
||||
self.assertEqual("ManualScope",
|
||||
load_template._checkpoint_dependencies[2].name)
|
||||
status.assert_consumed().run_restore_ops()
|
||||
self.assertAllEqual([12.], self.evaluate(var))
|
||||
self.assertAllEqual([13.], self.evaluate(var_plus_one))
|
||||
|
Loading…
Reference in New Issue
Block a user