Factor a "capture_dependencies" scope out of Template.

I don't intend for this to get used much directly, but it's handy for Template-like frameworks (e.g. Sonnet), to let them re-enter the dependency-capturing part of Templates.

PiperOrigin-RevId: 200595624
This commit is contained in:
Allen Lavoie 2018-06-14 12:02:32 -07:00 committed by TensorFlower Gardener
parent 91ec6cc494
commit 7ccf1937b8
4 changed files with 121 additions and 68 deletions

View File

@ -20,6 +20,7 @@ Visualization and inspection:
@@object_metadata
Managing dependencies:
@@capture_dependencies
@@Checkpointable
@@CheckpointableObjectGraph
@@NoDependency
@ -43,9 +44,11 @@ from tensorflow.python.training.checkpointable.base import Checkpointable
from tensorflow.python.training.checkpointable.base import NoDependency
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
from tensorflow.python.training.checkpointable.util import capture_dependencies
from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.deprecation import deprecated
@ -295,66 +296,6 @@ class Template(checkpointable.CheckpointableBase):
# which is not the same as whether the scope has been created.
self._variables_created = False
def _checkpointable_custom_creator(self, next_creator, name, initial_value,
checkpointable_parent=None, **kwargs):
"""A variable creation hook which adds Checkpointable dependencies.
Set during the `Template`'s first wrapped function execution. Ensures that
(a) `Template` objects depend on `Template`s created inside them which
create variables, and (b) that any variables not in a more deeply nested
`Template` are added as dependencies directly.
The `checkpointable_parent` argument is passed between `Template` custom
creators but ignored when the variable object itself is created. This
argument indicates (if not `None`) that a more deeply nested `Template` has
already added the variable as a dependency, and that parent `Template`s
should add a dependency on that `Template` rather than on the variable
directly.
Args:
next_creator: See `variable_scope.variable_creator_scope`; the next
creator in the chain.
name: The (full, scope-influenced) name of the variable. The scope name
for the Template itself is stripped for the purposes of object-based
dependency tracking, but scopes within Templates are respected.
initial_value: See `variable_scope.variable_creator_scope`. Taken
explicitly so the argument can be re-named and used with
`Checkpointable._add_variable_with_custom_getter`.
checkpointable_parent: If not None, a more deeply nested Template object
to add a dependency on (rather than depending on the variable directly).
**kwargs: Passed through to the next creator.
Returns:
The output of `next_creator`: the fetched/created variable object.
"""
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
# we don't want to propagate.
return next_creator(
initial_value=initializer,
name=name,
**inner_kwargs)
if name.startswith(self._variable_scope.name):
scope_stripped_name = name[len(self._variable_scope.name) + 1:]
if not checkpointable_parent:
return self._add_variable_with_custom_getter(
initializer=initial_value,
name=scope_stripped_name,
getter=_call_next_creator_renaming_initializer,
# Disable error checking for Checkpointable. Exceptions are instead
# raised if necessary when the object-based saver tries to
# save/restore the object.
overwrite=True,
checkpointable_parent=self,
**kwargs)
else:
self._track_checkpointable(
checkpointable_parent,
name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access
len(self._variable_scope.name) + 1:],
overwrite=True)
return next_creator(name=name, initial_value=initial_value,
checkpointable_parent=self, **kwargs)
def _call_func(self, args, kwargs):
try:
vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
@ -365,8 +306,7 @@ class Template(checkpointable.CheckpointableBase):
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
with variable_scope.variable_creator_scope(
self._checkpointable_custom_creator):
with checkpointable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:
@ -634,8 +574,7 @@ class EagerTemplate(Template):
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
with variable_scope.variable_creator_scope(
self._checkpointable_custom_creator):
with checkpointable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:

View File

@ -41,6 +41,7 @@ from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@ -564,6 +565,93 @@ def gather_initializers(root_checkpointable):
if hasattr(c, "initializer") and c.initializer is not None]
@tf_contextlib.contextmanager
def capture_dependencies(template):
"""Capture variables created within this scope as `Template` dependencies.
Requires that `template.variable_scope` is active.
This scope is intended as a compatibility measure, allowing a checkpointable
object to add dependencies on variables created in a block of code which is
not aware of object-based saving (and instead uses variable names
heavily). This is how `Template` objects add dependencies on variables and
sub-`Template`s. Where possible, use `tf.make_template` directly.
Args:
template: The `Template` object to register dependencies with.
Yields:
None (when used as a context manager).
"""
name_prefix = template.variable_scope.name
def _checkpointable_custom_creator(next_creator, name, initial_value,
checkpointable_parent=None, **kwargs):
"""A variable creation hook which adds Checkpointable dependencies.
Set for example during a `Template`'s first wrapped function
execution. Ensures that (a) `template` depends on any checkpointable
objects using their own `capture_dependencies` scope inside this scope which
create variables, and (b) that any variables not in a more deeply nested
scope are added as dependencies directly.
The `checkpointable_parent` argument is passed between custom creators but
ignored when the variable object itself is created. This argument indicates
(if not `None`) that a more deeply nested scope has already added the
variable as a dependency, and that parent scopes should add a dependency on
that object rather than on the variable directly.
Args:
next_creator: See `variable_scope.variable_creator_scope`; the next
creator in the chain.
name: The (full, scope-influenced) name of the variable. The `name_prefix`
itself is stripped for the purposes of object-based dependency tracking,
but scopes opened within this scope are respected.
initial_value: See `variable_scope.variable_creator_scope`. Taken
explicitly so the argument can be re-named and used with
`Checkpointable._add_variable_with_custom_getter`.
checkpointable_parent: If not None, a more deeply nested checkpointable
object and its name prefix which were passed to `capture_dependencies`
to add a dependency on (rather than depending on the variable directly).
**kwargs: Passed through to the next creator.
Returns:
The output of `next_creator`: the fetched/created variable object.
"""
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
# we don't want to propagate.
return next_creator(
initial_value=initializer,
name=name,
**inner_kwargs)
if name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:]
if not checkpointable_parent:
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
initializer=initial_value,
name=scope_stripped_name,
getter=_call_next_creator_renaming_initializer,
# Disable error checking for Checkpointable. Exceptions are instead
# raised if necessary when the object-based saver tries to
# save/restore the object.
overwrite=True,
checkpointable_parent=(template, name_prefix),
**kwargs)
else:
parent_object, parent_name_prefix = checkpointable_parent
template._track_checkpointable( # pylint: disable=protected-access
parent_object,
name=parent_name_prefix[len(name_prefix) + 1:],
overwrite=True)
return next_creator(
name=name, initial_value=initial_value,
checkpointable_parent=(template, name_prefix), **kwargs)
with variable_scope.variable_creator_scope(_checkpointable_custom_creator):
yield
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
def __init__(self, tensor, name):

View File

@ -1243,6 +1243,18 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
class _ManualScope(checkpointable.Checkpointable):
def __call__(self):
with variable_scope.variable_scope("ManualScope") as vs:
self.variable_scope = vs
with checkpointable_utils.capture_dependencies(template=self):
return self._build()
def _build(self):
return variable_scope.get_variable(name="in_manual_scope", shape=[])
class TemplateTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@ -1255,14 +1267,23 @@ class TemplateTests(test.TestCase):
v2 = variable_scope.get_variable(
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
use_resource=True)
return v, v + 1., v2
manual = _ManualScope()
return v, v + 1., v2, manual, manual()
save_template = template.make_template("s1", _templated)
v1_save, _, v2_save = save_template()
v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
six.assertCountEqual(
self,
[v1_save, v2_save, manual_scope, manual_scope_v, save_template],
checkpointable_utils.list_objects(save_template))
manual_dep, = manual_scope._checkpoint_dependencies
self.assertEqual("in_manual_scope", manual_dep.name)
self.assertIs(manual_scope_v, manual_dep.ref)
optimizer = adam.AdamOptimizer(0.0)
save_root = checkpointable_utils.Checkpoint(
my_template=save_template, optimizer=optimizer)
optimizer.minimize(v1_save.read_value)
self.evaluate([v.initializer for v in save_template.variables])
self.evaluate([v.initializer for v in optimizer.variables()])
self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.]))
@ -1275,11 +1296,13 @@ class TemplateTests(test.TestCase):
load_root = checkpointable_utils.Checkpoint(
my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
var, var_plus_one, var2 = load_template()
var, var_plus_one, var2, _, _ = load_template()
load_optimizer.minimize(var.read_value)
self.assertEqual(2, len(load_template._checkpoint_dependencies))
self.assertEqual(3, len(load_template._checkpoint_dependencies))
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
self.assertEqual("ManualScope",
load_template._checkpoint_dependencies[2].name)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([12.], self.evaluate(var))
self.assertAllEqual([13.], self.evaluate(var_plus_one))