Factor a "capture_dependencies" scope out of Template.
I don't intend for this to get used much directly, but it's handy for Template-like frameworks (e.g. Sonnet), to let them re-enter the dependency-capturing part of Templates. PiperOrigin-RevId: 200595624
This commit is contained in:
parent
91ec6cc494
commit
7ccf1937b8
@ -20,6 +20,7 @@ Visualization and inspection:
|
|||||||
@@object_metadata
|
@@object_metadata
|
||||||
|
|
||||||
Managing dependencies:
|
Managing dependencies:
|
||||||
|
@@capture_dependencies
|
||||||
@@Checkpointable
|
@@Checkpointable
|
||||||
@@CheckpointableObjectGraph
|
@@CheckpointableObjectGraph
|
||||||
@@NoDependency
|
@@NoDependency
|
||||||
@ -43,9 +44,11 @@ from tensorflow.python.training.checkpointable.base import Checkpointable
|
|||||||
from tensorflow.python.training.checkpointable.base import NoDependency
|
from tensorflow.python.training.checkpointable.base import NoDependency
|
||||||
from tensorflow.python.training.checkpointable.data_structures import List
|
from tensorflow.python.training.checkpointable.data_structures import List
|
||||||
from tensorflow.python.training.checkpointable.data_structures import Mapping
|
from tensorflow.python.training.checkpointable.data_structures import Mapping
|
||||||
|
from tensorflow.python.training.checkpointable.util import capture_dependencies
|
||||||
from tensorflow.python.training.checkpointable.util import list_objects
|
from tensorflow.python.training.checkpointable.util import list_objects
|
||||||
from tensorflow.python.training.checkpointable.util import object_metadata
|
from tensorflow.python.training.checkpointable.util import object_metadata
|
||||||
|
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
remove_undocumented(module_name=__name__)
|
remove_undocumented(module_name=__name__)
|
||||||
|
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.checkpointable import base as checkpointable
|
from tensorflow.python.training.checkpointable import base as checkpointable
|
||||||
|
from tensorflow.python.training.checkpointable import util as checkpointable_util
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util.deprecation import deprecated
|
from tensorflow.python.util.deprecation import deprecated
|
||||||
@ -295,66 +296,6 @@ class Template(checkpointable.CheckpointableBase):
|
|||||||
# which is not the same as whether the scope has been created.
|
# which is not the same as whether the scope has been created.
|
||||||
self._variables_created = False
|
self._variables_created = False
|
||||||
|
|
||||||
def _checkpointable_custom_creator(self, next_creator, name, initial_value,
|
|
||||||
checkpointable_parent=None, **kwargs):
|
|
||||||
"""A variable creation hook which adds Checkpointable dependencies.
|
|
||||||
|
|
||||||
Set during the `Template`'s first wrapped function execution. Ensures that
|
|
||||||
(a) `Template` objects depend on `Template`s created inside them which
|
|
||||||
create variables, and (b) that any variables not in a more deeply nested
|
|
||||||
`Template` are added as dependencies directly.
|
|
||||||
|
|
||||||
The `checkpointable_parent` argument is passed between `Template` custom
|
|
||||||
creators but ignored when the variable object itself is created. This
|
|
||||||
argument indicates (if not `None`) that a more deeply nested `Template` has
|
|
||||||
already added the variable as a dependency, and that parent `Template`s
|
|
||||||
should add a dependency on that `Template` rather than on the variable
|
|
||||||
directly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
next_creator: See `variable_scope.variable_creator_scope`; the next
|
|
||||||
creator in the chain.
|
|
||||||
name: The (full, scope-influenced) name of the variable. The scope name
|
|
||||||
for the Template itself is stripped for the purposes of object-based
|
|
||||||
dependency tracking, but scopes within Templates are respected.
|
|
||||||
initial_value: See `variable_scope.variable_creator_scope`. Taken
|
|
||||||
explicitly so the argument can be re-named and used with
|
|
||||||
`Checkpointable._add_variable_with_custom_getter`.
|
|
||||||
checkpointable_parent: If not None, a more deeply nested Template object
|
|
||||||
to add a dependency on (rather than depending on the variable directly).
|
|
||||||
**kwargs: Passed through to the next creator.
|
|
||||||
Returns:
|
|
||||||
The output of `next_creator`: the fetched/created variable object.
|
|
||||||
"""
|
|
||||||
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
|
|
||||||
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
|
|
||||||
# we don't want to propagate.
|
|
||||||
return next_creator(
|
|
||||||
initial_value=initializer,
|
|
||||||
name=name,
|
|
||||||
**inner_kwargs)
|
|
||||||
if name.startswith(self._variable_scope.name):
|
|
||||||
scope_stripped_name = name[len(self._variable_scope.name) + 1:]
|
|
||||||
if not checkpointable_parent:
|
|
||||||
return self._add_variable_with_custom_getter(
|
|
||||||
initializer=initial_value,
|
|
||||||
name=scope_stripped_name,
|
|
||||||
getter=_call_next_creator_renaming_initializer,
|
|
||||||
# Disable error checking for Checkpointable. Exceptions are instead
|
|
||||||
# raised if necessary when the object-based saver tries to
|
|
||||||
# save/restore the object.
|
|
||||||
overwrite=True,
|
|
||||||
checkpointable_parent=self,
|
|
||||||
**kwargs)
|
|
||||||
else:
|
|
||||||
self._track_checkpointable(
|
|
||||||
checkpointable_parent,
|
|
||||||
name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access
|
|
||||||
len(self._variable_scope.name) + 1:],
|
|
||||||
overwrite=True)
|
|
||||||
return next_creator(name=name, initial_value=initial_value,
|
|
||||||
checkpointable_parent=self, **kwargs)
|
|
||||||
|
|
||||||
def _call_func(self, args, kwargs):
|
def _call_func(self, args, kwargs):
|
||||||
try:
|
try:
|
||||||
vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
@ -365,8 +306,7 @@ class Template(checkpointable.CheckpointableBase):
|
|||||||
else:
|
else:
|
||||||
# The first time we run, restore variables if necessary (via
|
# The first time we run, restore variables if necessary (via
|
||||||
# Checkpointable).
|
# Checkpointable).
|
||||||
with variable_scope.variable_creator_scope(
|
with checkpointable_util.capture_dependencies(template=self):
|
||||||
self._checkpointable_custom_creator):
|
|
||||||
result = self._func(*args, **kwargs)
|
result = self._func(*args, **kwargs)
|
||||||
|
|
||||||
if self._variables_created:
|
if self._variables_created:
|
||||||
@ -634,8 +574,7 @@ class EagerTemplate(Template):
|
|||||||
else:
|
else:
|
||||||
# The first time we run, restore variables if necessary (via
|
# The first time we run, restore variables if necessary (via
|
||||||
# Checkpointable).
|
# Checkpointable).
|
||||||
with variable_scope.variable_creator_scope(
|
with checkpointable_util.capture_dependencies(template=self):
|
||||||
self._checkpointable_custom_creator):
|
|
||||||
result = self._func(*args, **kwargs)
|
result = self._func(*args, **kwargs)
|
||||||
|
|
||||||
if self._variables_created:
|
if self._variables_created:
|
||||||
|
@ -41,6 +41,7 @@ from tensorflow.python.training import saveable_object as saveable_object_lib
|
|||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
from tensorflow.python.training.checkpointable import base as checkpointable_lib
|
from tensorflow.python.training.checkpointable import base as checkpointable_lib
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
|
from tensorflow.python.util import tf_contextlib
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -564,6 +565,93 @@ def gather_initializers(root_checkpointable):
|
|||||||
if hasattr(c, "initializer") and c.initializer is not None]
|
if hasattr(c, "initializer") and c.initializer is not None]
|
||||||
|
|
||||||
|
|
||||||
|
@tf_contextlib.contextmanager
|
||||||
|
def capture_dependencies(template):
|
||||||
|
"""Capture variables created within this scope as `Template` dependencies.
|
||||||
|
|
||||||
|
Requires that `template.variable_scope` is active.
|
||||||
|
|
||||||
|
This scope is intended as a compatibility measure, allowing a checkpointable
|
||||||
|
object to add dependencies on variables created in a block of code which is
|
||||||
|
not aware of object-based saving (and instead uses variable names
|
||||||
|
heavily). This is how `Template` objects add dependencies on variables and
|
||||||
|
sub-`Template`s. Where possible, use `tf.make_template` directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: The `Template` object to register dependencies with.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None (when used as a context manager).
|
||||||
|
"""
|
||||||
|
name_prefix = template.variable_scope.name
|
||||||
|
|
||||||
|
def _checkpointable_custom_creator(next_creator, name, initial_value,
|
||||||
|
checkpointable_parent=None, **kwargs):
|
||||||
|
"""A variable creation hook which adds Checkpointable dependencies.
|
||||||
|
|
||||||
|
Set for example during a `Template`'s first wrapped function
|
||||||
|
execution. Ensures that (a) `template` depends on any checkpointable
|
||||||
|
objects using their own `capture_dependencies` scope inside this scope which
|
||||||
|
create variables, and (b) that any variables not in a more deeply nested
|
||||||
|
scope are added as dependencies directly.
|
||||||
|
|
||||||
|
The `checkpointable_parent` argument is passed between custom creators but
|
||||||
|
ignored when the variable object itself is created. This argument indicates
|
||||||
|
(if not `None`) that a more deeply nested scope has already added the
|
||||||
|
variable as a dependency, and that parent scopes should add a dependency on
|
||||||
|
that object rather than on the variable directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
next_creator: See `variable_scope.variable_creator_scope`; the next
|
||||||
|
creator in the chain.
|
||||||
|
name: The (full, scope-influenced) name of the variable. The `name_prefix`
|
||||||
|
itself is stripped for the purposes of object-based dependency tracking,
|
||||||
|
but scopes opened within this scope are respected.
|
||||||
|
initial_value: See `variable_scope.variable_creator_scope`. Taken
|
||||||
|
explicitly so the argument can be re-named and used with
|
||||||
|
`Checkpointable._add_variable_with_custom_getter`.
|
||||||
|
checkpointable_parent: If not None, a more deeply nested checkpointable
|
||||||
|
object and its name prefix which were passed to `capture_dependencies`
|
||||||
|
to add a dependency on (rather than depending on the variable directly).
|
||||||
|
**kwargs: Passed through to the next creator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of `next_creator`: the fetched/created variable object.
|
||||||
|
"""
|
||||||
|
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
|
||||||
|
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
|
||||||
|
# we don't want to propagate.
|
||||||
|
return next_creator(
|
||||||
|
initial_value=initializer,
|
||||||
|
name=name,
|
||||||
|
**inner_kwargs)
|
||||||
|
if name.startswith(name_prefix):
|
||||||
|
scope_stripped_name = name[len(name_prefix) + 1:]
|
||||||
|
if not checkpointable_parent:
|
||||||
|
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
|
||||||
|
initializer=initial_value,
|
||||||
|
name=scope_stripped_name,
|
||||||
|
getter=_call_next_creator_renaming_initializer,
|
||||||
|
# Disable error checking for Checkpointable. Exceptions are instead
|
||||||
|
# raised if necessary when the object-based saver tries to
|
||||||
|
# save/restore the object.
|
||||||
|
overwrite=True,
|
||||||
|
checkpointable_parent=(template, name_prefix),
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
parent_object, parent_name_prefix = checkpointable_parent
|
||||||
|
template._track_checkpointable( # pylint: disable=protected-access
|
||||||
|
parent_object,
|
||||||
|
name=parent_name_prefix[len(name_prefix) + 1:],
|
||||||
|
overwrite=True)
|
||||||
|
return next_creator(
|
||||||
|
name=name, initial_value=initial_value,
|
||||||
|
checkpointable_parent=(template, name_prefix), **kwargs)
|
||||||
|
|
||||||
|
with variable_scope.variable_creator_scope(_checkpointable_custom_creator):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
||||||
|
|
||||||
def __init__(self, tensor, name):
|
def __init__(self, tensor, name):
|
||||||
|
@ -1243,6 +1243,18 @@ class CheckpointingTests(test.TestCase):
|
|||||||
self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
|
self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
|
||||||
|
|
||||||
|
|
||||||
|
class _ManualScope(checkpointable.Checkpointable):
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
with variable_scope.variable_scope("ManualScope") as vs:
|
||||||
|
self.variable_scope = vs
|
||||||
|
with checkpointable_utils.capture_dependencies(template=self):
|
||||||
|
return self._build()
|
||||||
|
|
||||||
|
def _build(self):
|
||||||
|
return variable_scope.get_variable(name="in_manual_scope", shape=[])
|
||||||
|
|
||||||
|
|
||||||
class TemplateTests(test.TestCase):
|
class TemplateTests(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
@ -1255,14 +1267,23 @@ class TemplateTests(test.TestCase):
|
|||||||
v2 = variable_scope.get_variable(
|
v2 = variable_scope.get_variable(
|
||||||
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
|
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
|
||||||
use_resource=True)
|
use_resource=True)
|
||||||
return v, v + 1., v2
|
manual = _ManualScope()
|
||||||
|
return v, v + 1., v2, manual, manual()
|
||||||
|
|
||||||
save_template = template.make_template("s1", _templated)
|
save_template = template.make_template("s1", _templated)
|
||||||
v1_save, _, v2_save = save_template()
|
v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
|
||||||
|
six.assertCountEqual(
|
||||||
|
self,
|
||||||
|
[v1_save, v2_save, manual_scope, manual_scope_v, save_template],
|
||||||
|
checkpointable_utils.list_objects(save_template))
|
||||||
|
manual_dep, = manual_scope._checkpoint_dependencies
|
||||||
|
self.assertEqual("in_manual_scope", manual_dep.name)
|
||||||
|
self.assertIs(manual_scope_v, manual_dep.ref)
|
||||||
optimizer = adam.AdamOptimizer(0.0)
|
optimizer = adam.AdamOptimizer(0.0)
|
||||||
save_root = checkpointable_utils.Checkpoint(
|
save_root = checkpointable_utils.Checkpoint(
|
||||||
my_template=save_template, optimizer=optimizer)
|
my_template=save_template, optimizer=optimizer)
|
||||||
optimizer.minimize(v1_save.read_value)
|
optimizer.minimize(v1_save.read_value)
|
||||||
|
self.evaluate([v.initializer for v in save_template.variables])
|
||||||
self.evaluate([v.initializer for v in optimizer.variables()])
|
self.evaluate([v.initializer for v in optimizer.variables()])
|
||||||
self.evaluate(v1_save.assign([12.]))
|
self.evaluate(v1_save.assign([12.]))
|
||||||
self.evaluate(v2_save.assign([14.]))
|
self.evaluate(v2_save.assign([14.]))
|
||||||
@ -1275,11 +1296,13 @@ class TemplateTests(test.TestCase):
|
|||||||
load_root = checkpointable_utils.Checkpoint(
|
load_root = checkpointable_utils.Checkpoint(
|
||||||
my_template=load_template, optimizer=load_optimizer)
|
my_template=load_template, optimizer=load_optimizer)
|
||||||
status = load_root.restore(save_path)
|
status = load_root.restore(save_path)
|
||||||
var, var_plus_one, var2 = load_template()
|
var, var_plus_one, var2, _, _ = load_template()
|
||||||
load_optimizer.minimize(var.read_value)
|
load_optimizer.minimize(var.read_value)
|
||||||
self.assertEqual(2, len(load_template._checkpoint_dependencies))
|
self.assertEqual(3, len(load_template._checkpoint_dependencies))
|
||||||
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
|
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
|
||||||
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
|
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
|
||||||
|
self.assertEqual("ManualScope",
|
||||||
|
load_template._checkpoint_dependencies[2].name)
|
||||||
status.assert_consumed().run_restore_ops()
|
status.assert_consumed().run_restore_ops()
|
||||||
self.assertAllEqual([12.], self.evaluate(var))
|
self.assertAllEqual([12.], self.evaluate(var))
|
||||||
self.assertAllEqual([13.], self.evaluate(var_plus_one))
|
self.assertAllEqual([13.], self.evaluate(var_plus_one))
|
||||||
|
Loading…
Reference in New Issue
Block a user