Add SaveableHook, a special SaveableObject which registers callbacks.
Registers a single constant tensor in order to conform to the SaveableObject API; I feel that's cleaner than special-casing SaveableHook throughout the codebase. PiperOrigin-RevId: 280708433 Change-Id: I5872949eca35c7fe3dcc401c52a63b66a141d865
This commit is contained in:
parent
0c04770ad1
commit
5304e1240a
@ -17,6 +17,7 @@ py_library(
|
|||||||
srcs = ["functional_saver.py"],
|
srcs = ["functional_saver.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":saveable_hook",
|
||||||
":saveable_object",
|
":saveable_object",
|
||||||
":saveable_object_util",
|
":saveable_object_util",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
@ -31,6 +32,7 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":functional_saver",
|
":functional_saver",
|
||||||
|
":saveable_hook",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -41,6 +43,15 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library target for saveable_hook.py. Declares the same Python source
# compatibility (`srcs_version`) as the sibling py_library targets in this
# package.
py_library(
    name = "saveable_hook",
    srcs = ["saveable_hook.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python/training/tracking:base",
    ],
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "saveable_object_util",
|
name = "saveable_object_util",
|
||||||
srcs = ["saveable_object_util.py"],
|
srcs = ["saveable_object_util.py"],
|
||||||
|
@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import gen_io_ops
|
from tensorflow.python.ops import gen_io_ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
|
from tensorflow.python.training.saving import saveable_hook
|
||||||
from tensorflow.python.training.saving import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -130,15 +131,31 @@ class MultiDeviceSaver(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
saveable_objects: A list of `SaveableObject`s.
|
saveable_objects: A list of `SaveableObject`s.
|
||||||
|
Objects extending `SaveableObject` will be saved and restored, and
|
||||||
|
objects extending `SaveableHook` will be called into at save and
|
||||||
|
restore time.
|
||||||
"""
|
"""
|
||||||
|
self._before_save_callbacks = []
|
||||||
|
self._after_restore_callbacks = []
|
||||||
|
|
||||||
saveable_objects = list(saveable_objects)
|
saveable_objects = list(saveable_objects)
|
||||||
saveables_by_device = {}
|
saveables_by_device = {}
|
||||||
for saveable in saveable_objects:
|
for saveable in saveable_objects:
|
||||||
if not isinstance(saveable, saveable_object.SaveableObject):
|
is_saveable = isinstance(saveable, saveable_object.SaveableObject)
|
||||||
|
is_hook = isinstance(saveable, saveable_hook.SaveableHook)
|
||||||
|
|
||||||
|
if not is_saveable and not is_hook:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected a dictionary of SaveableObjects, got {}."
|
"Expected a dictionary of SaveableObjects, got {}."
|
||||||
.format(saveable))
|
.format(saveable))
|
||||||
saveables_by_device.setdefault(saveable.device, []).append(saveable)
|
|
||||||
|
if is_hook:
|
||||||
|
self._before_save_callbacks.append(saveable.before_save)
|
||||||
|
self._after_restore_callbacks.append(saveable.after_restore)
|
||||||
|
|
||||||
|
if is_saveable:
|
||||||
|
saveables_by_device.setdefault(saveable.device, []).append(saveable)
|
||||||
|
|
||||||
self._single_device_savers = {
|
self._single_device_savers = {
|
||||||
device: _SingleDeviceSaver(saveables)
|
device: _SingleDeviceSaver(saveables)
|
||||||
for device, saveables in saveables_by_device.items()}
|
for device, saveables in saveables_by_device.items()}
|
||||||
@ -182,6 +199,9 @@ class MultiDeviceSaver(object):
|
|||||||
Returns:
|
Returns:
|
||||||
An `Operation`, or None when executing eagerly.
|
An `Operation`, or None when executing eagerly.
|
||||||
"""
|
"""
|
||||||
|
for callback in self._before_save_callbacks:
|
||||||
|
callback()
|
||||||
|
|
||||||
# IMPLEMENTATION DETAILS: most clients should skip.
|
# IMPLEMENTATION DETAILS: most clients should skip.
|
||||||
#
|
#
|
||||||
# Suffix for any well-formed "checkpoint_prefix", when sharded.
|
# Suffix for any well-formed "checkpoint_prefix", when sharded.
|
||||||
@ -253,4 +273,8 @@ class MultiDeviceSaver(object):
|
|||||||
for device, saver in sorted(self._single_device_savers.items()):
|
for device, saver in sorted(self._single_device_savers.items()):
|
||||||
with ops.device(device):
|
with ops.device(device):
|
||||||
restore_ops.update(saver.restore(file_prefix))
|
restore_ops.update(saver.restore(file_prefix))
|
||||||
|
|
||||||
|
for callback in self._after_restore_callbacks:
|
||||||
|
callback()
|
||||||
|
|
||||||
return restore_ops
|
return restore_ops
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
|
|||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
|
from tensorflow.python.training.saving import saveable_hook
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
|
|
||||||
|
|
||||||
@ -113,6 +114,35 @@ class SaverTest(test.TestCase):
|
|||||||
self.assertEqual(2., self.evaluate(v2))
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
|
|
||||||
|
class SaveableHookTest(test.TestCase):
  """Checks that MultiDeviceSaver invokes SaveableHook callbacks."""

  def test_callbacks_run(self):
    """One save then one restore should each trigger exactly one callback."""
    # The counters live in a dict so the nested methods can update them in
    # place; assigning to a plain int inside a callback would only bind a new
    # local name there.
    call_counts = dict(save=0, restore=0)

    class CountingHook(saveable_hook.SaveableHook):
      """Hook that records how often each callback is invoked."""

      def before_save(self):
        call_counts["save"] += 1

      def after_restore(self):
        call_counts["restore"] += 1

    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    saver = functional_saver.MultiDeviceSaver([CountingHook(name="dummy")])

    # Saving must run only the before-save callback.
    save_result = saver.save(constant_op.constant(prefix))
    self.evaluate(save_result)
    self.assertEqual({"save": 1, "restore": 0}, call_counts)

    # Restoring must run only the after-restore callback.
    restore_result = saver.restore(prefix)
    self.evaluate(restore_result)
    self.assertEqual({"save": 1, "restore": 1}, call_counts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ops.enable_eager_execution(
|
ops.enable_eager_execution(
|
||||||
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
|
config=config_pb2.ConfigProto(device_count={"CPU": 3}))
|
||||||
|
59
tensorflow/python/training/saving/saveable_hook.py
Normal file
59
tensorflow/python/training/saving/saveable_hook.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""SaveableHook, for running callbacks at save and restore time."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.training.tracking import base
|
||||||
|
|
||||||
|
|
||||||
|
class SaveableHook(base.NoRestoreSaveable):
  """A saveable whose purpose is to run callbacks around save and restore.

  Override `before_save`, `after_restore`, or both in a subclass to read or
  modify variables as part of the saving process. When several `SaveableHook`
  objects are in play, their relative order of execution is unspecified; the
  only guarantee is that each callback runs before or after its respective
  event.

  Hooks are meant to be emitted together with other SaveableObjects, for
  example from Trackable._gather_saveables_for_checkpoint().

  To stay compliant with the SaveableObject API, a hook saves one constant.
  """

  def __init__(self, name):
    """Creates a `SaveableHook` object.

    Args:
      name: the name to save the object under.
    """
    # Placeholder value, saved only to conform to the SaveableObject API.
    placeholder = constant_op.constant(0)
    super(SaveableHook, self).__init__(tensor=placeholder, name=name)

  @property
  def device(self):
    """The device of this saveable's `op`."""
    return self.op.device

  def before_save(self):
    """Callback run before saving; does nothing unless overridden."""

  def after_restore(self):
    """Callback run after restoring; does nothing unless overridden."""
|
Loading…
Reference in New Issue
Block a user