Add SaveableHook, a special SaveableObject which registers callbacks.
Registers a single constant tensor in order to conform to the SaveableObject API; I feel that's cleaner than special-casing SaveableHook throughout the codebase.

PiperOrigin-RevId: 280708433
Change-Id: I5872949eca35c7fe3dcc401c52a63b66a141d865
This commit is contained in:
parent
0c04770ad1
commit
5304e1240a
@ -17,6 +17,7 @@ py_library(
|
||||
srcs = ["functional_saver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saveable_hook",
|
||||
":saveable_object",
|
||||
":saveable_object_util",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
@ -31,6 +32,7 @@ cuda_py_test(
|
||||
],
|
||||
additional_deps = [
|
||||
":functional_saver",
|
||||
":saveable_hook",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
@ -41,6 +43,15 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
# SaveableHook: a saveable that runs callbacks at save and restore time.
py_library(
    name = "saveable_hook",
    srcs = ["saveable_hook.py"],
    # saveable_hook.py uses `from __future__` imports and the two-argument
    # form of super(), so it is written to run under both Python 2 and 3;
    # declare that explicitly, as the sibling rules in this package do.
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python/training/tracking:base",
    ],
)
|
||||
|
||||
py_library(
|
||||
name = "saveable_object_util",
|
||||
srcs = ["saveable_object_util.py"],
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_io_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.training.saving import saveable_hook
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.util import nest
|
||||
@ -130,15 +131,31 @@ class MultiDeviceSaver(object):
|
||||
|
||||
Args:
|
||||
saveable_objects: A list of `SaveableObject`s.
|
||||
Objects extending `SaveableObject` will be saved and restored, and
|
||||
objects extending `SaveableHook` will be called into at save and
|
||||
restore time.
|
||||
"""
|
||||
self._before_save_callbacks = []
|
||||
self._after_restore_callbacks = []
|
||||
|
||||
saveable_objects = list(saveable_objects)
|
||||
saveables_by_device = {}
|
||||
for saveable in saveable_objects:
|
||||
if not isinstance(saveable, saveable_object.SaveableObject):
|
||||
is_saveable = isinstance(saveable, saveable_object.SaveableObject)
|
||||
is_hook = isinstance(saveable, saveable_hook.SaveableHook)
|
||||
|
||||
if not is_saveable and not is_hook:
|
||||
raise ValueError(
|
||||
"Expected a dictionary of SaveableObjects, got {}."
|
||||
.format(saveable))
|
||||
|
||||
if is_hook:
|
||||
self._before_save_callbacks.append(saveable.before_save)
|
||||
self._after_restore_callbacks.append(saveable.after_restore)
|
||||
|
||||
if is_saveable:
|
||||
saveables_by_device.setdefault(saveable.device, []).append(saveable)
|
||||
|
||||
self._single_device_savers = {
|
||||
device: _SingleDeviceSaver(saveables)
|
||||
for device, saveables in saveables_by_device.items()}
|
||||
@ -182,6 +199,9 @@ class MultiDeviceSaver(object):
|
||||
Returns:
|
||||
An `Operation`, or None when executing eagerly.
|
||||
"""
|
||||
for callback in self._before_save_callbacks:
|
||||
callback()
|
||||
|
||||
# IMPLEMENTATION DETAILS: most clients should skip.
|
||||
#
|
||||
# Suffix for any well-formed "checkpoint_prefix", when sharded.
|
||||
@ -253,4 +273,8 @@ class MultiDeviceSaver(object):
|
||||
for device, saver in sorted(self._single_device_savers.items()):
|
||||
with ops.device(device):
|
||||
restore_ops.update(saver.restore(file_prefix))
|
||||
|
||||
for callback in self._after_restore_callbacks:
|
||||
callback()
|
||||
|
||||
return restore_ops
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.training.saving import functional_saver
|
||||
from tensorflow.python.training.saving import saveable_hook
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
|
||||
|
||||
@ -113,6 +114,35 @@ class SaverTest(test.TestCase):
|
||||
self.assertEqual(2., self.evaluate(v2))
|
||||
|
||||
|
||||
class SaveableHookTest(test.TestCase):
  """Checks that MultiDeviceSaver invokes SaveableHook callbacks."""

  def test_callbacks_run(self):
    """Saving runs `before_save` once; restoring runs `after_restore` once."""
    # The counters live in a dict so the nested methods can mutate them in
    # place; rebinding a plain int inside a method would only create a new
    # local name instead of updating the outer value.
    call_counts = {"save": 0, "restore": 0}

    class DummyHook(saveable_hook.SaveableHook):
      """Hook that records how often each callback has fired."""

      def before_save(self):
        call_counts["save"] += 1

      def after_restore(self):
        call_counts["restore"] += 1

    hook = DummyHook(name="dummy")
    saver = functional_saver.MultiDeviceSaver([hook])
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Only the save callback should have fired after saving.
    save_result = saver.save(constant_op.constant(checkpoint_prefix))
    self.evaluate(save_result)
    self.assertEqual({"save": 1, "restore": 0}, call_counts)

    # Restoring fires the restore callback and leaves the save count alone.
    restore_result = saver.restore(checkpoint_prefix)
    self.evaluate(restore_result)
    self.assertEqual({"save": 1, "restore": 1}, call_counts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution(
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
|
||||
|
59
tensorflow/python/training/saving/saveable_hook.py
Normal file
59
tensorflow/python/training/saving/saveable_hook.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""SaveableHook, for running callbacks at save and restore time."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.training.tracking import base
|
||||
|
||||
|
||||
class SaveableHook(base.NoRestoreSaveable):
  """Saveable whose purpose is to run callbacks around save and restore.

  Override `before_save`, `after_restore`, or both in a subclass to read or
  modify variables as part of the saving process. When several `SaveableHook`
  objects are present, no ordering between them is guaranteed; each callback
  is only guaranteed to run before (for saving) or after (for restoring) the
  corresponding event.

  A `SaveableHook` is meant to be emitted alongside other SaveableObjects, for
  example from Trackable._gather_saveables_for_checkpoint().

  To stay compliant with the SaveableObject API, the hook registers a single
  constant tensor to be saved.
  """

  def __init__(self, name):
    """Initializes the hook.

    Args:
      name: the name to save the object under.
    """
    # The hook carries no state of its own; this constant exists only so that
    # there is a tensor to hand to the base-class constructor.
    placeholder = constant_op.constant(0)
    super(SaveableHook, self).__init__(tensor=placeholder, name=name)

  @property
  def device(self):
    """Returns the `device` attribute of `self.op`."""
    hook_op = self.op
    return hook_op.device

  def before_save(self):
    """Callback run before devices are iterated for saving.

    The default implementation does nothing.
    """

  def after_restore(self):
    """Callback run once, after all devices have been restored.

    The default implementation does nothing.
    """
|
Loading…
x
Reference in New Issue
Block a user