Adds tf.train.experimental.PythonState
Allows users to hook into TF's object-based checkpointing with arbitrary Python state. Doing this now for the TF Agents migration. This is a reasonable place to cutoff and expose a public API; the NumPy stuff can be copied out of contrib. Experimental for now pending Trackable being exposed, since it may make more sense to combine the Variable/Tensor saving API with the non-Tensor Python state API. PiperOrigin-RevId: 235579652
This commit is contained in:
parent
06f414573f
commit
4d6f8114b0
tensorflow
@ -46,7 +46,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
|
||||
from tensorflow.contrib.checkpoint.python.python_state import NumpyState
|
||||
from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
|
||||
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
|
||||
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
|
||||
from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph
|
||||
@ -55,6 +54,7 @@ from tensorflow.python.training.tracking.base import Trackable as Checkpointable
|
||||
from tensorflow.python.training.tracking.data_structures import List
|
||||
from tensorflow.python.training.tracking.data_structures import Mapping
|
||||
from tensorflow.python.training.tracking.data_structures import NoDependency
|
||||
from tensorflow.python.training.tracking.python_state import PythonState as PythonStateWrapper
|
||||
from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable
|
||||
from tensorflow.python.training.tracking.util import capture_dependencies
|
||||
from tensorflow.python.training.tracking.util import list_objects
|
||||
@ -62,3 +62,4 @@ from tensorflow.python.training.tracking.util import object_metadata
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
remove_undocumented(module_name=__name__)
|
||||
|
||||
|
@ -17,13 +17,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import functools
|
||||
import six
|
||||
|
||||
import numpy
|
||||
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import python_state as core_python_state
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
@ -129,29 +126,7 @@ class NumpyState(base.Trackable):
|
||||
super(NumpyState, self).__setattr__(name, value)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class PythonStateWrapper(base.Trackable):
|
||||
"""Wraps a Python object for storage in an object-based checkpoint."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _serialize(self):
|
||||
"""Callback for `PythonStringStateSaveable` to serialize the object."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _deserialize(self, string_value):
|
||||
"""Callback for `PythonStringStateSaveable` to deserialize the object."""
|
||||
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
"""Specify callbacks for saving and restoring `array`."""
|
||||
return {
|
||||
"py_state": functools.partial(
|
||||
base.PythonStringStateSaveable,
|
||||
state_callback=self._serialize,
|
||||
restore_callback=self._deserialize)
|
||||
}
|
||||
|
||||
|
||||
class _NumpyWrapper(PythonStateWrapper):
|
||||
class _NumpyWrapper(core_python_state.PythonState):
|
||||
"""Wraps a NumPy array for storage in an object-based checkpoint."""
|
||||
|
||||
def __init__(self, array):
|
||||
@ -162,7 +137,7 @@ class _NumpyWrapper(PythonStateWrapper):
|
||||
"""
|
||||
self.array = array
|
||||
|
||||
def _serialize(self):
|
||||
def serialize(self):
|
||||
"""Callback to serialize the array."""
|
||||
string_file = BytesIO()
|
||||
try:
|
||||
@ -172,7 +147,7 @@ class _NumpyWrapper(PythonStateWrapper):
|
||||
string_file.close()
|
||||
return serialized
|
||||
|
||||
def _deserialize(self, string_value):
|
||||
def deserialize(self, string_value):
|
||||
"""Callback to deserialize the array."""
|
||||
string_file = BytesIO(string_value)
|
||||
try:
|
||||
|
@ -3816,6 +3816,7 @@ py_library(
|
||||
"//tensorflow/python/keras/optimizer_v2:learning_rate_schedule",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
"//tensorflow/python/training/tracking:python_state",
|
||||
"//tensorflow/python/training/tracking:util",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
|
@ -48,6 +48,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"sysconfig/__init__.py",
|
||||
"test/__init__.py",
|
||||
"train/__init__.py",
|
||||
"train/experimental/__init__.py",
|
||||
"version/__init__.py",
|
||||
# END GENERATED FILES
|
||||
]
|
||||
|
@ -70,6 +70,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
|
||||
"sysconfig/__init__.py",
|
||||
"test/__init__.py",
|
||||
"train/__init__.py",
|
||||
"train/experimental/__init__.py",
|
||||
"train/queue_runner/__init__.py",
|
||||
"user_ops/__init__.py",
|
||||
"version/__init__.py",
|
||||
|
@ -257,3 +257,23 @@ tf_py_test(
|
||||
"notsan", # b/74395663
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "python_state",
|
||||
srcs = ["python_state.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "python_state_test",
|
||||
srcs = ["python_state_test.py"],
|
||||
additional_deps = [
|
||||
":base",
|
||||
":util",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
92
tensorflow/python/training/tracking/python_state.py
Normal file
92
tensorflow/python/training/tracking/python_state.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""Utilities for including Python state in TensorFlow checkpoints."""
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import functools
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("train.experimental.PythonState")
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class PythonState(base.Trackable):
|
||||
"""A mixin for putting Python state in an object-based checkpoint.
|
||||
|
||||
This is an abstract class which allows extensions to TensorFlow's object-based
|
||||
checkpointing (see `tf.train.Checkpoint`). For example a wrapper for NumPy
|
||||
arrays:
|
||||
|
||||
```python
|
||||
import io
|
||||
import numpy
|
||||
|
||||
class NumpyWrapper(tf.train.experimental.PythonState):
|
||||
|
||||
def __init__(self, array):
|
||||
self.array = array
|
||||
|
||||
def serialize(self):
|
||||
string_file = io.BytesIO()
|
||||
try:
|
||||
numpy.save(string_file, self.array, allow_pickle=False)
|
||||
serialized = string_file.getvalue()
|
||||
finally:
|
||||
string_file.close()
|
||||
return serialized
|
||||
|
||||
def deserialize(self, string_value):
|
||||
string_file = io.BytesIO(string_value)
|
||||
try:
|
||||
self.array = numpy.load(string_file, allow_pickle=False)
|
||||
finally:
|
||||
string_file.close()
|
||||
```
|
||||
|
||||
Instances of `NumpyWrapper` are checkpointable objects, and will be saved and
|
||||
restored from checkpoints along with TensorFlow state like variables.
|
||||
|
||||
```python
|
||||
root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
|
||||
save_path = root.save(prefix)
|
||||
root.numpy.array *= 2.
|
||||
assert [2.] == root.numpy.array
|
||||
root.restore(save_path)
|
||||
assert [1.] == root.numpy.array
|
||||
```
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def serialize(self):
|
||||
"""Callback to serialize the object. Returns a string."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def deserialize(self, string_value):
|
||||
"""Callback to deserialize the object."""
|
||||
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
"""Specify callbacks for saving and restoring `array`."""
|
||||
return {
|
||||
"py_state": functools.partial(
|
||||
base.PythonStringStateSaveable,
|
||||
state_callback=self.serialize,
|
||||
restore_callback=self.deserialize)
|
||||
}
|
244
tensorflow/python/training/tracking/python_state_test.py
Normal file
244
tensorflow/python/training/tracking/python_state_test.py
Normal file
@ -0,0 +1,244 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import io
|
||||
import os
|
||||
|
||||
import numpy
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import python_state
|
||||
from tensorflow.python.training.tracking import util
|
||||
|
||||
|
||||
class _NumpyState(base.Trackable):
|
||||
"""A checkpointable object whose NumPy array attributes are saved/restored.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
arrays = _NumpyState()
|
||||
checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
|
||||
arrays.x = numpy.zeros([3, 4])
|
||||
save_path = checkpoint.save("/tmp/ckpt")
|
||||
arrays.x[1, 1] = 4.
|
||||
checkpoint.restore(save_path)
|
||||
assert (arrays.x == numpy.zeros([3, 4])).all()
|
||||
|
||||
second_checkpoint = tf.train.Checkpoint(
|
||||
numpy_arrays=_NumpyState())
|
||||
# Attributes of NumpyState objects are created automatically by restore()
|
||||
second_checkpoint.restore(save_path)
|
||||
assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
|
||||
```
|
||||
|
||||
Note that `NumpyState` objects re-create the attributes of the previously
|
||||
saved object on `restore()`. This is in contrast to TensorFlow variables, for
|
||||
which a `Variable` object must be created and assigned to an attribute.
|
||||
|
||||
This snippet works both when graph building and when executing eagerly. On
|
||||
save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
|
||||
a placeholder when graph building, or as a string constant when executing
|
||||
eagerly). When restoring they skip the TensorFlow graph entirely, and so no
|
||||
restore ops need be run. This means that restoration always happens eagerly,
|
||||
rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
|
||||
TensorFlow variables when graph building.
|
||||
"""
|
||||
|
||||
def _lookup_dependency(self, name):
|
||||
"""Create placeholder NumPy arrays for to-be-restored attributes.
|
||||
|
||||
Typically `_lookup_dependency` is used to check by name whether a dependency
|
||||
exists. We cheat slightly by creating a checkpointable object for `name` if
|
||||
we don't already have one, giving us attribute re-creation behavior when
|
||||
loading a checkpoint.
|
||||
|
||||
Args:
|
||||
name: The name of the dependency being checked.
|
||||
Returns:
|
||||
An existing dependency if one exists, or a new `_NumpyWrapper` placeholder
|
||||
dependency (which will generally be restored immediately).
|
||||
"""
|
||||
value = super(_NumpyState, self)._lookup_dependency(name)
|
||||
if value is None:
|
||||
value = _NumpyWrapper(numpy.array([]))
|
||||
new_reference = base.TrackableReference(name=name, ref=value)
|
||||
self._unconditional_checkpoint_dependencies.append(new_reference)
|
||||
self._unconditional_dependency_names[name] = value
|
||||
super(_NumpyState, self).__setattr__(name, value)
|
||||
return value
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""Un-wrap `_NumpyWrapper` objects when accessing attributes."""
|
||||
value = super(_NumpyState, self).__getattribute__(name)
|
||||
if isinstance(value, _NumpyWrapper):
|
||||
return value.array
|
||||
return value
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Automatically wrap NumPy arrays assigned to attributes."""
|
||||
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
|
||||
# ndarrays checkpointable natively and using standard checkpointable list
|
||||
# tracking.
|
||||
if isinstance(value, (numpy.ndarray, numpy.generic)):
|
||||
try:
|
||||
existing = super(_NumpyState, self).__getattribute__(name)
|
||||
existing.array = value
|
||||
return
|
||||
except AttributeError:
|
||||
value = _NumpyWrapper(value)
|
||||
self._track_trackable(value, name=name, overwrite=True)
|
||||
elif (name not in ("_setattr_tracking", "_update_uid")
|
||||
and getattr(self, "_setattr_tracking", True)):
|
||||
# Mixing restore()-created attributes with user-added checkpointable
|
||||
# objects is tricky, since we can't use the `_lookup_dependency` trick to
|
||||
# re-create attributes (we might accidentally steal the restoration for
|
||||
# another checkpointable object). For now `_NumpyState` objects must be
|
||||
# leaf nodes. Theoretically we could add some extra arguments to
|
||||
# `_lookup_dependency` to figure out whether we should create a NumPy
|
||||
# array for the attribute or not.
|
||||
raise NotImplementedError(
|
||||
("Assigned %s to the %s property of %s, which is not a NumPy array. "
|
||||
"Currently mixing NumPy arrays and other checkpointable objects is "
|
||||
"not supported. File a feature request if this limitation bothers "
|
||||
"you.")
|
||||
% (value, name, self))
|
||||
super(_NumpyState, self).__setattr__(name, value)
|
||||
|
||||
|
||||
class _NumpyWrapper(python_state.PythonState):
|
||||
"""Wraps a NumPy array for storage in an object-based checkpoint."""
|
||||
|
||||
def __init__(self, array):
|
||||
"""Specify a NumPy array to wrap.
|
||||
|
||||
Args:
|
||||
array: The NumPy array to save and restore (may be overwritten).
|
||||
"""
|
||||
self.array = array
|
||||
|
||||
def serialize(self):
|
||||
"""Callback to serialize the array."""
|
||||
string_file = io.BytesIO()
|
||||
try:
|
||||
numpy.save(string_file, self.array, allow_pickle=False)
|
||||
serialized = string_file.getvalue()
|
||||
finally:
|
||||
string_file.close()
|
||||
return serialized
|
||||
|
||||
def deserialize(self, string_value):
|
||||
"""Callback to deserialize the array."""
|
||||
string_file = io.BytesIO(string_value)
|
||||
try:
|
||||
self.array = numpy.load(string_file, allow_pickle=False)
|
||||
finally:
|
||||
string_file.close()
|
||||
|
||||
|
||||
class NumpyStateTests(test.TestCase):
|
||||
|
||||
def testWrapper(self):
|
||||
directory = self.get_temp_dir()
|
||||
prefix = os.path.join(directory, "ckpt")
|
||||
root = util.Checkpoint(numpy=_NumpyWrapper(numpy.array([1.])))
|
||||
save_path = root.save(prefix)
|
||||
root.numpy.array *= 2.
|
||||
self.assertEqual([2.], root.numpy.array)
|
||||
root.restore(save_path)
|
||||
self.assertEqual([1.], root.numpy.array)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testSaveRestoreNumpyState(self):
|
||||
directory = self.get_temp_dir()
|
||||
prefix = os.path.join(directory, "ckpt")
|
||||
save_state = _NumpyState()
|
||||
saver = util.Checkpoint(numpy=save_state)
|
||||
save_state.a = numpy.ones([2, 2])
|
||||
save_state.b = numpy.ones([2, 2])
|
||||
save_state.b = numpy.zeros([2, 2])
|
||||
save_state.c = numpy.int64(3)
|
||||
self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
|
||||
self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
|
||||
self.assertEqual(3, save_state.c)
|
||||
first_save_path = saver.save(prefix)
|
||||
save_state.a[1, 1] = 2.
|
||||
save_state.c = numpy.int64(4)
|
||||
second_save_path = saver.save(prefix)
|
||||
|
||||
load_state = _NumpyState()
|
||||
loader = util.Checkpoint(numpy=load_state)
|
||||
loader.restore(first_save_path).initialize_or_restore()
|
||||
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
|
||||
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
|
||||
self.assertEqual(3, load_state.c)
|
||||
load_state.a[0, 0] = 42.
|
||||
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
|
||||
loader.restore(first_save_path).run_restore_ops()
|
||||
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
|
||||
loader.restore(second_save_path).run_restore_ops()
|
||||
self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
|
||||
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
|
||||
self.assertEqual(4, load_state.c)
|
||||
|
||||
def testNoGraphPollution(self):
|
||||
graph = ops.Graph()
|
||||
with graph.as_default(), session.Session():
|
||||
directory = self.get_temp_dir()
|
||||
prefix = os.path.join(directory, "ckpt")
|
||||
save_state = _NumpyState()
|
||||
saver = util.Checkpoint(numpy=save_state)
|
||||
save_state.a = numpy.ones([2, 2])
|
||||
save_path = saver.save(prefix)
|
||||
saver.restore(save_path)
|
||||
graph.finalize()
|
||||
saver.save(prefix)
|
||||
save_state.a = numpy.zeros([2, 2])
|
||||
saver.save(prefix)
|
||||
saver.restore(save_path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testNoMixedNumpyStateTF(self):
|
||||
save_state = _NumpyState()
|
||||
save_state.a = numpy.ones([2, 2])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
save_state.v = variables.Variable(1.)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDocstringExample(self):
|
||||
arrays = _NumpyState()
|
||||
checkpoint = util.Checkpoint(numpy_arrays=arrays)
|
||||
arrays.x = numpy.zeros([3, 4])
|
||||
save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
arrays.x[1, 1] = 4.
|
||||
checkpoint.restore(save_path)
|
||||
self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)
|
||||
|
||||
second_checkpoint = util.Checkpoint(numpy_arrays=_NumpyState())
|
||||
second_checkpoint.restore(save_path)
|
||||
self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
@ -68,6 +68,7 @@ from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
|
||||
from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
|
||||
from tensorflow.python.training.basic_session_run_hooks import ProfilerHook
|
||||
from tensorflow.python.training.basic_loops import basic_train_loop
|
||||
from tensorflow.python.training.tracking.python_state import PythonState
|
||||
from tensorflow.python.training.tracking.util import Checkpoint
|
||||
from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
|
||||
from tensorflow.python.training.checkpoint_utils import list_variables
|
||||
@ -142,3 +143,4 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
|
||||
tf_export("train.SequenceExample")(SequenceExample)
|
||||
tf_export("train.ServerDef")(ServerDef)
|
||||
# pylint: enable=undefined-variable
|
||||
|
||||
|
@ -0,0 +1,17 @@
|
||||
path: "tensorflow.train.experimental.PythonState"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.python_state.PythonState\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "deserialize"
|
||||
argspec: "args=[\'self\', \'string_value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "serialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.train.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "PythonState"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
}
|
@ -240,6 +240,10 @@ tf_module {
|
||||
name: "WorkerSessionCreator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "queue_runner"
|
||||
mtype: "<type \'module\'>"
|
||||
|
@ -0,0 +1,17 @@
|
||||
path: "tensorflow.train.experimental.PythonState"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.python_state.PythonState\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
}
|
||||
member_method {
|
||||
name: "deserialize"
|
||||
argspec: "args=[\'self\', \'string_value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "serialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.train.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "PythonState"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
}
|
@ -68,6 +68,10 @@ tf_module {
|
||||
name: "ServerDef"
|
||||
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_checkpoint_state"
|
||||
argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user