Adds tf.train.experimental.PythonState

Allows users to hook into TF's object-based checkpointing with arbitrary Python state.

Doing this now for the TF Agents migration. This is a reasonable place to cutoff and expose a public API; the NumPy stuff can be copied out of contrib.

Experimental for now pending Trackable being exposed, since it may make more sense to combine the Variable/Tensor saving API with the non-Tensor Python state API.

PiperOrigin-RevId: 235579652
This commit is contained in:
Allen Lavoie 2019-02-25 12:21:39 -08:00 committed by TensorFlower Gardener
parent 06f414573f
commit 4d6f8114b0
15 changed files with 423 additions and 30 deletions

View File

@ -46,7 +46,6 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.python_state import NumpyState
from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph
@ -55,6 +54,7 @@ from tensorflow.python.training.tracking.base import Trackable as Checkpointable
from tensorflow.python.training.tracking.data_structures import List
from tensorflow.python.training.tracking.data_structures import Mapping
from tensorflow.python.training.tracking.data_structures import NoDependency
from tensorflow.python.training.tracking.python_state import PythonState as PythonStateWrapper
from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable
from tensorflow.python.training.tracking.util import capture_dependencies
from tensorflow.python.training.tracking.util import list_objects
@ -62,3 +62,4 @@ from tensorflow.python.training.tracking.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)

View File

@ -17,13 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import six
import numpy
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import python_state as core_python_state
# pylint: disable=g-import-not-at-top
try:
@ -129,29 +126,7 @@ class NumpyState(base.Trackable):
super(NumpyState, self).__setattr__(name, value)
@six.add_metaclass(abc.ABCMeta)
class PythonStateWrapper(base.Trackable):
"""Wraps a Python object for storage in an object-based checkpoint."""
@abc.abstractmethod
def _serialize(self):
"""Callback for `PythonStringStateSaveable` to serialize the object."""
@abc.abstractmethod
def _deserialize(self, string_value):
"""Callback for `PythonStringStateSaveable` to deserialize the object."""
def _gather_saveables_for_checkpoint(self):
"""Specify callbacks for saving and restoring `array`."""
return {
"py_state": functools.partial(
base.PythonStringStateSaveable,
state_callback=self._serialize,
restore_callback=self._deserialize)
}
class _NumpyWrapper(PythonStateWrapper):
class _NumpyWrapper(core_python_state.PythonState):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
@ -162,7 +137,7 @@ class _NumpyWrapper(PythonStateWrapper):
"""
self.array = array
def _serialize(self):
def serialize(self):
"""Callback to serialize the array."""
string_file = BytesIO()
try:
@ -172,7 +147,7 @@ class _NumpyWrapper(PythonStateWrapper):
string_file.close()
return serialized
def _deserialize(self, string_value):
def deserialize(self, string_value):
"""Callback to deserialize the array."""
string_file = BytesIO(string_value)
try:

View File

@ -3816,6 +3816,7 @@ py_library(
"//tensorflow/python/keras/optimizer_v2:learning_rate_schedule",
"//tensorflow/python/ops/losses",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:python_state",
"//tensorflow/python/training/tracking:util",
"//third_party/py/numpy",
"@six_archive//:six",

View File

@ -48,6 +48,7 @@ TENSORFLOW_API_INIT_FILES = [
"sysconfig/__init__.py",
"test/__init__.py",
"train/__init__.py",
"train/experimental/__init__.py",
"version/__init__.py",
# END GENERATED FILES
]

View File

@ -70,6 +70,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"sysconfig/__init__.py",
"test/__init__.py",
"train/__init__.py",
"train/experimental/__init__.py",
"train/queue_runner/__init__.py",
"user_ops/__init__.py",
"version/__init__.py",

View File

@ -257,3 +257,23 @@ tf_py_test(
"notsan", # b/74395663
],
)
py_library(
name = "python_state",
srcs = ["python_state.py"],
srcs_version = "PY2AND3",
deps = [
":base",
],
)
tf_py_test(
name = "python_state_test",
srcs = ["python_state_test.py"],
additional_deps = [
":base",
":util",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
],
)

View File

@ -0,0 +1,92 @@
"""Utilities for including Python state in TensorFlow checkpoints."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import six
from tensorflow.python.training.tracking import base
from tensorflow.python.util.tf_export import tf_export
@tf_export("train.experimental.PythonState")
@six.add_metaclass(abc.ABCMeta)
class PythonState(base.Trackable):
"""A mixin for putting Python state in an object-based checkpoint.
This is an abstract class which allows extensions to TensorFlow's object-based
checkpointing (see `tf.train.Checkpoint`). For example a wrapper for NumPy
arrays:
```python
import io
import numpy
class NumpyWrapper(tf.train.experimental.PythonState):
def __init__(self, array):
self.array = array
def serialize(self):
string_file = io.BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
serialized = string_file.getvalue()
finally:
string_file.close()
return serialized
def deserialize(self, string_value):
string_file = io.BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
```
Instances of `NumpyWrapper` are checkpointable objects, and will be saved and
restored from checkpoints along with TensorFlow state like variables.
```python
root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
save_path = root.save(prefix)
root.numpy.array *= 2.
assert [2.] == root.numpy.array
root.restore(save_path)
assert [1.] == root.numpy.array
```
"""
@abc.abstractmethod
def serialize(self):
"""Callback to serialize the object. Returns a string."""
@abc.abstractmethod
def deserialize(self, string_value):
"""Callback to deserialize the object."""
def _gather_saveables_for_checkpoint(self):
"""Specify callbacks for saving and restoring `array`."""
return {
"py_state": functools.partial(
base.PythonStringStateSaveable,
state_callback=self.serialize,
restore_callback=self.deserialize)
}

View File

@ -0,0 +1,244 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import numpy
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import python_state
from tensorflow.python.training.tracking import util
class _NumpyState(base.Trackable):
"""A checkpointable object whose NumPy array attributes are saved/restored.
Example usage:
```python
arrays = _NumpyState()
checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
arrays.x = numpy.zeros([3, 4])
save_path = checkpoint.save("/tmp/ckpt")
arrays.x[1, 1] = 4.
checkpoint.restore(save_path)
assert (arrays.x == numpy.zeros([3, 4])).all()
second_checkpoint = tf.train.Checkpoint(
numpy_arrays=_NumpyState())
# Attributes of NumpyState objects are created automatically by restore()
second_checkpoint.restore(save_path)
assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
```
Note that `NumpyState` objects re-create the attributes of the previously
saved object on `restore()`. This is in contrast to TensorFlow variables, for
which a `Variable` object must be created and assigned to an attribute.
This snippet works both when graph building and when executing eagerly. On
save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
a placeholder when graph building, or as a string constant when executing
eagerly). When restoring they skip the TensorFlow graph entirely, and so no
restore ops need be run. This means that restoration always happens eagerly,
rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
TensorFlow variables when graph building.
"""
def _lookup_dependency(self, name):
"""Create placeholder NumPy arrays for to-be-restored attributes.
Typically `_lookup_dependency` is used to check by name whether a dependency
exists. We cheat slightly by creating a checkpointable object for `name` if
we don't already have one, giving us attribute re-creation behavior when
loading a checkpoint.
Args:
name: The name of the dependency being checked.
Returns:
An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
dependency (which will generally be restored immediately).
"""
value = super(_NumpyState, self)._lookup_dependency(name)
if value is None:
value = _NumpyWrapper(numpy.array([]))
new_reference = base.TrackableReference(name=name, ref=value)
self._unconditional_checkpoint_dependencies.append(new_reference)
self._unconditional_dependency_names[name] = value
super(_NumpyState, self).__setattr__(name, value)
return value
def __getattribute__(self, name):
"""Un-wrap `_NumpyWrapper` objects when accessing attributes."""
value = super(_NumpyState, self).__getattribute__(name)
if isinstance(value, _NumpyWrapper):
return value.array
return value
def __setattr__(self, name, value):
"""Automatically wrap NumPy arrays assigned to attributes."""
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# tracking.
if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
existing = super(_NumpyState, self).__getattribute__(name)
existing.array = value
return
except AttributeError:
value = _NumpyWrapper(value)
self._track_trackable(value, name=name, overwrite=True)
elif (name not in ("_setattr_tracking", "_update_uid")
and getattr(self, "_setattr_tracking", True)):
# Mixing restore()-created attributes with user-added checkpointable
# objects is tricky, since we can't use the `_lookup_dependency` trick to
# re-create attributes (we might accidentally steal the restoration for
# another checkpointable object). For now `_NumpyState` objects must be
# leaf nodes. Theoretically we could add some extra arguments to
# `_lookup_dependency` to figure out whether we should create a NumPy
# array for the attribute or not.
raise NotImplementedError(
("Assigned %s to the %s property of %s, which is not a NumPy array. "
"Currently mixing NumPy arrays and other checkpointable objects is "
"not supported. File a feature request if this limitation bothers "
"you.")
% (value, name, self))
super(_NumpyState, self).__setattr__(name, value)
class _NumpyWrapper(python_state.PythonState):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
"""Specify a NumPy array to wrap.
Args:
array: The NumPy array to save and restore (may be overwritten).
"""
self.array = array
def serialize(self):
"""Callback to serialize the array."""
string_file = io.BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
serialized = string_file.getvalue()
finally:
string_file.close()
return serialized
def deserialize(self, string_value):
"""Callback to deserialize the array."""
string_file = io.BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
class NumpyStateTests(test.TestCase):
def testWrapper(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")
root = util.Checkpoint(numpy=_NumpyWrapper(numpy.array([1.])))
save_path = root.save(prefix)
root.numpy.array *= 2.
self.assertEqual([2.], root.numpy.array)
root.restore(save_path)
self.assertEqual([1.], root.numpy.array)
@test_util.run_in_graph_and_eager_modes
def testSaveRestoreNumpyState(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")
save_state = _NumpyState()
saver = util.Checkpoint(numpy=save_state)
save_state.a = numpy.ones([2, 2])
save_state.b = numpy.ones([2, 2])
save_state.b = numpy.zeros([2, 2])
save_state.c = numpy.int64(3)
self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
self.assertEqual(3, save_state.c)
first_save_path = saver.save(prefix)
save_state.a[1, 1] = 2.
save_state.c = numpy.int64(4)
second_save_path = saver.save(prefix)
load_state = _NumpyState()
loader = util.Checkpoint(numpy=load_state)
loader.restore(first_save_path).initialize_or_restore()
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
self.assertEqual(3, load_state.c)
load_state.a[0, 0] = 42.
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
loader.restore(first_save_path).run_restore_ops()
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
loader.restore(second_save_path).run_restore_ops()
self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
self.assertEqual(4, load_state.c)
def testNoGraphPollution(self):
graph = ops.Graph()
with graph.as_default(), session.Session():
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")
save_state = _NumpyState()
saver = util.Checkpoint(numpy=save_state)
save_state.a = numpy.ones([2, 2])
save_path = saver.save(prefix)
saver.restore(save_path)
graph.finalize()
saver.save(prefix)
save_state.a = numpy.zeros([2, 2])
saver.save(prefix)
saver.restore(save_path)
@test_util.run_in_graph_and_eager_modes
def testNoMixedNumpyStateTF(self):
save_state = _NumpyState()
save_state.a = numpy.ones([2, 2])
with self.assertRaises(NotImplementedError):
save_state.v = variables.Variable(1.)
@test_util.run_in_graph_and_eager_modes
def testDocstringExample(self):
arrays = _NumpyState()
checkpoint = util.Checkpoint(numpy_arrays=arrays)
arrays.x = numpy.zeros([3, 4])
save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
arrays.x[1, 1] = 4.
checkpoint.restore(save_path)
self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)
second_checkpoint = util.Checkpoint(numpy_arrays=_NumpyState())
second_checkpoint.restore(save_path)
self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

View File

@ -68,6 +68,7 @@ from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
from tensorflow.python.training.basic_loops import basic_train_loop
from tensorflow.python.training.tracking.python_state import PythonState
from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
from tensorflow.python.training.checkpoint_utils import list_variables
@ -142,3 +143,4 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable

View File

@ -0,0 +1,17 @@
path: "tensorflow.train.experimental.PythonState"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.python_state.PythonState\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "deserialize"
argspec: "args=[\'self\', \'string_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "serialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.train.experimental"
tf_module {
member {
name: "PythonState"
mtype: "<type \'type\'>"
}
}

View File

@ -240,6 +240,10 @@ tf_module {
name: "WorkerSessionCreator"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
member {
name: "queue_runner"
mtype: "<type \'module\'>"

View File

@ -0,0 +1,17 @@
path: "tensorflow.train.experimental.PythonState"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.python_state.PythonState\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
}
member_method {
name: "deserialize"
argspec: "args=[\'self\', \'string_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "serialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.train.experimental"
tf_module {
member {
name: "PythonState"
mtype: "<type \'type\'>"
}
}

View File

@ -68,6 +68,10 @@ tf_module {
name: "ServerDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
member_method {
name: "get_checkpoint_state"
argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], "