Add SaveableHook, a special SaveableObject which registers callbacks.

Registers a single constant tensor in order to conform to the SaveableObject API; I feel that's cleaner than special-casing SaveableHook throughout the codebase.

PiperOrigin-RevId: 280708433
Change-Id: I5872949eca35c7fe3dcc401c52a63b66a141d865
This commit is contained in:
Revan Sopher 2019-11-15 12:08:03 -08:00 committed by TensorFlower Gardener
parent 0c04770ad1
commit 5304e1240a
4 changed files with 126 additions and 2 deletions

View File

@ -17,6 +17,7 @@ py_library(
srcs = ["functional_saver.py"],
srcs_version = "PY2AND3",
deps = [
":saveable_hook",
":saveable_object",
":saveable_object_util",
"//tensorflow/python/eager:def_function",
@ -31,6 +32,7 @@ cuda_py_test(
],
additional_deps = [
":functional_saver",
":saveable_hook",
"//tensorflow/python/eager:test",
],
)
@ -41,6 +43,15 @@ py_library(
srcs_version = "PY2AND3",
)
# SaveableHook: a SaveableObject whose callbacks run at save and restore time.
py_library(
    name = "saveable_hook",
    srcs = ["saveable_hook.py"],
    # Declared for consistency with the sibling py_library targets in this
    # package; saveable_hook.py uses __future__ imports and two-argument super().
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python/training/tracking:base",
    ],
)
py_library(
name = "saveable_object_util",
srcs = ["saveable_object_util.py"],

View File

@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest
@ -130,15 +131,31 @@ class MultiDeviceSaver(object):
Args:
saveable_objects: A list of `SaveableObject`s.
Objects extending `SaveableObject` will be saved and restored, and
objects extending `SaveableHook` will be called into at save and
restore time.
"""
self._before_save_callbacks = []
self._after_restore_callbacks = []
saveable_objects = list(saveable_objects)
saveables_by_device = {}
for saveable in saveable_objects:
if not isinstance(saveable, saveable_object.SaveableObject):
is_saveable = isinstance(saveable, saveable_object.SaveableObject)
is_hook = isinstance(saveable, saveable_hook.SaveableHook)
if not is_saveable and not is_hook:
raise ValueError(
"Expected a dictionary of SaveableObjects, got {}."
.format(saveable))
saveables_by_device.setdefault(saveable.device, []).append(saveable)
if is_hook:
self._before_save_callbacks.append(saveable.before_save)
self._after_restore_callbacks.append(saveable.after_restore)
if is_saveable:
saveables_by_device.setdefault(saveable.device, []).append(saveable)
self._single_device_savers = {
device: _SingleDeviceSaver(saveables)
for device, saveables in saveables_by_device.items()}
@ -182,6 +199,9 @@ class MultiDeviceSaver(object):
Returns:
An `Operation`, or None when executing eagerly.
"""
for callback in self._before_save_callbacks:
callback()
# IMPLEMENTATION DETAILS: most clients should skip.
#
# Suffix for any well-formed "checkpoint_prefix", when sharded.
@ -253,4 +273,8 @@ class MultiDeviceSaver(object):
for device, saver in sorted(self._single_device_savers.items()):
with ops.device(device):
restore_ops.update(saver.restore(file_prefix))
for callback in self._after_restore_callbacks:
callback()
return restore_ops

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object_util
@ -113,6 +114,35 @@ class SaverTest(test.TestCase):
self.assertEqual(2., self.evaluate(v2))
class SaveableHookTest(test.TestCase):
  """Checks that `MultiDeviceSaver` invokes `SaveableHook` callbacks."""

  def test_callbacks_run(self):
    """`before_save` runs once on save; `after_restore` once on restore."""
    # The counters live in a mutable dict: rebinding a plain int inside the
    # hook methods would create a new local instead of updating this value.
    call_counts = {"save": 0, "restore": 0}

    class DummyHook(saveable_hook.SaveableHook):
      """Hook that only records how often each callback was invoked."""

      def before_save(self):
        call_counts["save"] = call_counts["save"] + 1

      def after_restore(self):
        call_counts["restore"] = call_counts["restore"] + 1

    hook = DummyHook(name="dummy")
    saver = functional_saver.MultiDeviceSaver([hook])
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Saving must trigger only the before-save callback.
    save_result = saver.save(constant_op.constant(ckpt_prefix))
    self.evaluate(save_result)
    self.assertEqual({"save": 1, "restore": 0}, call_counts)

    # Restoring must trigger the after-restore callback and leave the
    # save counter untouched.
    restore_result = saver.restore(ckpt_prefix)
    self.evaluate(restore_result)
    self.assertEqual({"save": 1, "restore": 1}, call_counts)
if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(device_count={"CPU": 3}))

View File

@ -0,0 +1,59 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SaveableHook, for running callbacks at save and restore time."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.training.tracking import base
class SaveableHook(base.NoRestoreSaveable):
  """Saveable whose purpose is to run callbacks around save and restore.

  Override `before_save`, `after_restore`, or both in a subclass to read or
  modify variables as part of the saving process. When several `SaveableHook`
  objects are present, no ordering between them is promised; the only
  guarantee is that each callback runs before (for save) or after (for
  restore) the event it is tied to.

  A `SaveableHook` should be emitted together with the other
  `SaveableObject`s, for example from
  `Trackable._gather_saveables_for_checkpoint()`.

  To satisfy the `SaveableObject` API, the hook itself saves one constant.
  """

  def __init__(self, name):
    """Creates a `SaveableHook` object.

    Args:
      name: the name to save the object under.
    """
    # The constant carries no information; it exists only so that this object
    # has a tensor to hand to the base saveable.
    placeholder = constant_op.constant(0)
    super(SaveableHook, self).__init__(tensor=placeholder, name=name)

  @property
  def device(self):
    """The `device` attribute of this saveable's `op`."""
    return self.op.device

  def before_save(self):
    """Callback run before devices are iterated for saving; no-op here."""

  def after_restore(self):
    """Callback run after the per-device restore step; no-op here."""