Switch to wrapt for trackable dict data structures instead of subclassing collections.Mapping
Adds a dependency on wrapt to TensorFlow's pip package. It's not a very large dependency, and this is a usability win when subclassing TensorFlow types (Model, Module, Layer, etc.).

wrapt fixes issues with CPython's direct access to dictionaries by inheriting the memory layout of the wrapped object, so we can now pass isinstance(obj, dict) checks while staying correct (i.e. {}.update(obj) won't look at the possibly-stale wrapper object's memory). There are several new oddities with the wrapt approach, but overall this is much better.

We're already doing dict wrapping, just not in a way that passes isinstance(obj, dict) checks. We need the wrapping to support restore-on-create for variables added to objects while executing eagerly.

I tried switching _ListWrapper to wrapt too. It works in most Python versions (with some tweaks to TF to accommodate the change), but apparently in Python 3.4 isinstance(obj, (list, tuple)) is not functioning properly. There are no correctness issues with actually subclassing list, and it means type(obj) shows up as a subclass of list, so this isn't too bad. Since ObjectProxy copies the memory layout of the wrapped object, we can't do both of these at the same time.

PiperOrigin-RevId: 243118363
Parent: f8c7522bb4
Commit: 48cb1ae640
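For background, here is a minimal standalone sketch (not TensorFlow code; the TrackedDict name and the mutation counter are invented for illustration) of the wrapt mechanism this change relies on. An ObjectProxy subclass reports the wrapped object's class, so isinstance(obj, dict) passes, while consumers such as {}.update(...) fall back to the keys()/__getitem__ protocol and therefore read the live wrapped dict rather than stale storage on the wrapper. Per-proxy state has to use wrapt's _self_ prefix so it stays on the proxy instead of being forwarded to the wrapped object:

    import wrapt

    class TrackedDict(wrapt.ObjectProxy):
      """Hypothetical proxy that counts mutations made through it."""

      def __init__(self, wrapped):
        super(TrackedDict, self).__init__(wrapped)
        # wrapt keeps attributes with a _self_ prefix on the proxy itself.
        self._self_mutations = 0

      def __setitem__(self, key, value):
        self._self_mutations += 1
        self.__wrapped__[key] = value

    d = TrackedDict({"a": 1})
    d["b"] = 2
    assert isinstance(d, dict)   # __class__ is forwarded to the wrapped dict
    plain = {}
    plain.update(d)              # not a C-level dict, so keys()/__getitem__ are used
    assert plain == {"a": 1, "b": 2}
    assert d._self_mutations == 1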
@@ -108,8 +108,8 @@ class NumpyState(base.Trackable):
      except AttributeError:
        value = _NumpyWrapper(value)
        self._track_trackable(value, name=name, overwrite=True)
    elif (name not in ("_setattr_tracking", "_update_uid")
          and getattr(self, "_setattr_tracking", True)):
    elif (name not in ("_self_setattr_tracking", "_self_update_uid")
          and getattr(self, "_self_setattr_tracking", True)):
      # Mixing restore()-created attributes with user-added trackable
      # objects is tricky, since we can't use the `_lookup_dependency` trick to
      # re-create attributes (we might accidentally steal the restoration for

@@ -154,4 +154,3 @@ class _NumpyWrapper(core_python_state.PythonState):
        self.array = numpy.load(string_file, allow_pickle=False)
      finally:
        string_file.close()
@@ -216,6 +216,7 @@ tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD

@@ -247,6 +248,7 @@ tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/eigen.BUILD
@@ -3070,8 +3070,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
  } else if (PyTuple_Check(arg)) {
    TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
        arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
  } else if (PyDict_Check(arg)) {
    tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
  } else if (tensorflow::swig::IsMapping(arg)) {
    tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
    if (PyList_Sort(keys.get()) == -1) {
      return tensorflow::errors::Internal("Unable to sort keys");
    }

@@ -3083,9 +3083,9 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
      PyObject* key = PyList_GetItem(keys.get(), i);
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
      PyObject* value = PyDict_GetItem(arg, key);
      TF_RETURN_IF_ERROR(
          TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
      tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
      TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
          value.get(), include_tensor_ranks_only, result));
    }
  } else {
    PyObject* object = PyWeakref_NewRef(arg, nullptr);
@@ -1921,8 +1921,8 @@ class Layer(module.Module):
        [w for w in self._non_trainable_weights if w is not existing_value])

  def __setattr__(self, name, value):
    if (name == '_setattr_tracking' or
        not getattr(self, '_setattr_tracking', True) or
    if (name == '_self_setattr_tracking' or
        not getattr(self, '_self_setattr_tracking', True) or
        getattr(self, '_is_graph_network', False) or
        # Exclude @property.setters from tracking
        hasattr(self.__class__, name)):

@@ -2279,3 +2279,4 @@ def default(method):
# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name
@@ -405,7 +405,7 @@ class Network(base_layer.Layer):
          layer, name='layer-%d' % layer_index, overwrite=True)

  def __setattr__(self, name, value):
    if not getattr(self, '_setattr_tracking', True):
    if not getattr(self, '_self_setattr_tracking', True):
      super(Network, self).__setattr__(name, value)
      return
@@ -99,8 +99,8 @@ class Module(tracking.AutoTrackable):
  # include when flattening (these reference dependencies reachable via other
  # object attributes).
  _TF_MODULE_IGNORED_PROPERTIES = frozenset((
      "_unconditional_checkpoint_dependencies",
      "_unconditional_dependency_names"
      "_self_unconditional_checkpoint_dependencies",
      "_self_unconditional_dependency_names"
  ))

  def __init__(self, name=None):
@@ -76,6 +76,7 @@ py_library(
        ":base",
        ":layer_utils",
        "//tensorflow/python/saved_model:revived_types",
        "@wrapt",
    ],
)
@@ -452,12 +452,12 @@ def no_automatic_dependency_tracking(method):
  """

  def _method_wrapper(self, *args, **kwargs):
    previous_value = getattr(self, "_setattr_tracking", True)
    self._setattr_tracking = False  # pylint: disable=protected-access
    previous_value = getattr(self, "_self_setattr_tracking", True)
    self._self_setattr_tracking = False  # pylint: disable=protected-access
    try:
      result = method(self, *args, **kwargs)
    finally:
      self._setattr_tracking = previous_value  # pylint: disable=protected-access
      self._self_setattr_tracking = previous_value  # pylint: disable=protected-access
    return result

  return tf_decorator.make_decorator(

@@ -473,6 +473,40 @@ class Trackable(object):
  checks.
  """

  # For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
  # _self_. We have some properties to forward semi-public attributes to their
  # _self_ equivalents.

  @property
  def _setattr_tracking(self):
    if not hasattr(self, "_self_setattr_tracking"):
      self._self_setattr_tracking = True
    return self._self_setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._self_setattr_tracking = value

  @property
  def _update_uid(self):
    return self._self_update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._self_update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._self_unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._self_name_based_restores

  # Trackable does not do automatic dependency tracking, but uses the
  # no_automatic_dependency_tracking decorator so it can avoid adding
  # dependencies if a subclass is Trackable / inherits from Model (both of

@@ -483,7 +517,7 @@ class Trackable(object):

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_unconditional_checkpoint_dependencies"):
    if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Trackable.__init__() in the constructor of every TensorFlow object.
      return

@@ -491,21 +525,21 @@ class Trackable(object):
    # `Trackable`, notably `Optimizer`s, may override the
    # _checkpoint_dependencies property with conditional dependencies
    # (e.g. based on the current graph when saving).
    self._unconditional_checkpoint_dependencies = []
    self._self_unconditional_checkpoint_dependencies = []
    # Maps names -> Trackable objects
    self._unconditional_dependency_names = {}
    self._self_unconditional_dependency_names = {}
    # Restorations for other Trackable objects on which this object may
    # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
    # tack on conditional dependencies, and so need separate management of
    # deferred dependencies too.
    self._unconditional_deferred_dependencies = {}
    self._self_unconditional_deferred_dependencies = {}
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_update_uid"):
    if hasattr(self, "_self_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._update_uid = -1
    self._self_update_uid = -1
    # When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
    # instances, which should be checked when creating variables or other
    # saveables. These are passed on recursively to all dependencies, since

@@ -513,7 +547,7 @@ class Trackable(object):
    # being restored in advance. This mechanism is only necessary for
    # restore-on-create when executing eagerly, and so is unused when graph
    # building.
    self._name_based_restores = set()
    self._self_name_based_restores = set()

  def _no_dependency(self, value):
    """If automatic dependency tracking is enabled, ignores `value`."""

@@ -521,10 +555,10 @@ class Trackable(object):

  def _name_based_attribute_restore(self, checkpoint):
    """Restore the object's attributes from a name-based checkpoint."""
    self._name_based_restores.add(checkpoint)
    if self._update_uid < checkpoint.restore_uid:
    self._self_name_based_restores.add(checkpoint)
    if self._self_update_uid < checkpoint.restore_uid:
      checkpoint.eager_restore(self)
      self._update_uid = checkpoint.restore_uid
      self._self_update_uid = checkpoint.restore_uid

  @property
  def _checkpoint_dependencies(self):

@@ -537,7 +571,7 @@ class Trackable(object):
      `Trackable` dependencies which should be saved along with this
        object.
    """
    return self._unconditional_checkpoint_dependencies
    return self._self_unconditional_checkpoint_dependencies

  @property
  def _deferred_dependencies(self):

@@ -552,7 +586,7 @@ class Trackable(object):
      A dictionary mapping from local name to a list of CheckpointPosition
        objects.
    """
    return self._unconditional_deferred_dependencies
    return self._self_unconditional_deferred_dependencies

  def _lookup_dependency(self, name):
    """Look up a dependency by name.

@@ -565,7 +599,7 @@ class Trackable(object):
      A `Trackable` object, or `None` if no dependency by this name was
        found.
    """
    return self._unconditional_dependency_names.get(name, None)
    return self._self_unconditional_dependency_names.get(name, None)

  def _add_variable_with_custom_getter(
      self, name, shape=None, dtype=dtypes.float32,

@@ -715,14 +749,15 @@ class Trackable(object):
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
      for index, (old_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
          self._self_unconditional_checkpoint_dependencies):
        if name == old_name:
          self._unconditional_checkpoint_dependencies[index] = new_reference
          self._self_unconditional_checkpoint_dependencies[
              index] = new_reference
    elif current_object is None:
      self._unconditional_checkpoint_dependencies.append(new_reference)
      self._self_unconditional_checkpoint_dependencies.append(new_reference)
      self._handle_deferred_dependencies(
          name=name, trackable=trackable)
    self._unconditional_dependency_names[name] = trackable
    self._self_unconditional_dependency_names[name] = trackable
    return trackable

  def _handle_deferred_dependencies(self, name, trackable):

@@ -759,7 +794,7 @@ class Trackable(object):

    # Pass on any name-based restores queued in this object.
    for name_based_restore in sorted(
        self._name_based_restores,
        self._self_name_based_restores,
        key=lambda checkpoint: checkpoint.restore_uid,
        reverse=True):
      trackable._name_based_attribute_restore(name_based_restore)  # pylint: disable=protected-access

@@ -789,9 +824,9 @@ class Trackable(object):
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object. However, we should pass the
    # restoration on to our dependencies.
    if checkpoint.restore_uid > self._update_uid:
    if checkpoint.restore_uid > self._self_update_uid:
      restore_ops = checkpoint_position.restore_ops()
      self._update_uid = checkpoint.restore_uid
      self._self_update_uid = checkpoint.restore_uid
    else:
      restore_ops = ()
    for child in checkpoint_position.object_proto.children:
@@ -23,6 +23,11 @@ import operator
import sys

import six
try:
  import wrapt
except ImportError:
  # Fall back to the build-time dependency if the system package is not available.
  from .....third_party import wrapt

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@@ -133,15 +138,25 @@ class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects."""

  def __init__(self):
    self.trainable = True
    self._extra_variables = []
    # Attributes prefixed with "_self_" for compatibility with
    # wrapt.ObjectProxy.
    self._self_trainable = True
    self._self_extra_variables = []

  @property
  def trainable(self):
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Add a dependency on `value`."""
    value = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    if isinstance(value, variables.Variable):
      self._extra_variables.append(value)
      self._self_extra_variables.append(value)
    if not isinstance(value, base.Trackable):
      raise _UntrackableError(value)
    if hasattr(value, "_use_resource_variables"):

@@ -177,14 +192,14 @@ class TrackableDataStructure(base.Trackable):
    return layer_utils.gather_trainable_weights(
        trainable=self.trainable,
        sub_layers=self._layers,
        extra_variables=self._extra_variables)
        extra_variables=self._self_extra_variables)

  @property
  def non_trainable_weights(self):
    return layer_utils.gather_non_trainable_weights(
        trainable=self.trainable,
        sub_layers=self._layers,
        extra_variables=self._extra_variables)
        extra_variables=self._self_extra_variables)

  @property
  def weights(self):

@@ -295,6 +310,7 @@ class List(TrackableDataStructure, collections.Sequence):

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):

@@ -353,6 +369,8 @@ class List(TrackableDataStructure, collections.Sequence):


# TODO(tomhennigan) Update to collections.UserList?
# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
# Python 3.4 support (may still be tricky).
class _ListWrapper(List, collections.MutableSequence,
                   # Shadowed, but there for isinstance checks.
                   list):

@@ -452,7 +470,7 @@ class _ListWrapper(List, collections.MutableSequence,

    if isinstance(key, slice):
      # Note: this is quite inefficient, but the list API supports a broad range
      # of slice setters (e.g. truncate, extend, replace) and immitating this
      # of slice setters (e.g. truncate, extend, replace) and imitating this
      # for a range of Python versions is non-trivial.
      storage_copy = list(self._storage)
      self._storage[key] = value

@@ -587,6 +605,7 @@ class Mapping(TrackableDataStructure, collections.Mapping):

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:

@@ -627,19 +646,7 @@ class Mapping(TrackableDataStructure, collections.Mapping):
    return iter(self._storage)


# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance
# checks seems infeasible. CPython will not call Python methods/properties on
# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead
# collects elements directly from dict_subclass's C structs. So subclassing dict
# implies that the storage has to be "self" (i.e. the C structs for the object
# must be updated correctly), but we also need that storage to be the wrapped
# dictionary to avoid synchronization bugs (un-tracked external modifications
# should still show up when the dict is accessed through the wrapper). Monkey
# patching all of the "wrapped" dict's methods instead of creating a wrapper
# object is an option, but not a very attractive one (replacing methods without
# creating reference cycles is difficult, and then dicts would need to be
# special cased everywhere as being trackable).
class _DictWrapper(Mapping, collections.MutableMapping):
class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Wraps built-in dicts to support restore-on-create for variables.

  _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
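As an aside, the comment block removed in the hunk above records the constraint that motivates the wrapt approach. A standalone illustration (plain Python, not part of this diff): overriding methods on a dict subclass does not help once C code takes the dict fast path, because {}.update(...) copies straight from the subclass's internal storage and never calls the overridden methods:

    class LoudDict(dict):
      """A dict subclass that tries (and fails) to observe reads."""

      def keys(self):
        print("keys() called")
        return super(LoudDict, self).keys()

      def __getitem__(self, key):
        print("__getitem__ called")
        return super(LoudDict, self).__getitem__(key)

    source = LoudDict(a=1)
    target = {}
    target.update(source)   # prints nothing: the C fast path bypasses both overrides
    assert target == {"a": 1}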
@@ -648,52 +655,68 @@ class _DictWrapper(Mapping, collections.MutableMapping):
  _DictWrapper will raise an exception on save.
  """

  def __new__(cls, *args):
    if len(args) == 1 and isinstance(args[0], dict):
      return super(_DictWrapper, cls).__new__(cls)
    else:
      # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In
      # this case there's nothing to wrap, so we make a normal dictionary. Also
      # allows constructing empty instances of the _DictWrapper type, as Session
      # is wont to do (and again there's nothing to wrap, so a normal dictionary
      # makes more sense).
      return dict(*args)

  def __init__(self, wrapped_dict):
    self._non_string_key = False
    self._non_append_mutation = False
    self._external_modification = False
    super(_DictWrapper, self).__init__(wrapped_dict)
  def __init__(self, wrapped_dict=None):
    if wrapped_dict is None:
      # Allow zero-argument construction, e.g. from session.run's re-wrapping.
      wrapped_dict = {}
    if not isinstance(wrapped_dict, collections.Mapping):
      # Allow construction from a sequence, e.g. from nest.pack_sequence_as.
      wrapped_dict = dict(wrapped_dict)
    wrapt.ObjectProxy.__init__(self, wrapped_dict)
    TrackableDataStructure.__init__(self)
    self._self_non_string_key = False
    self._self_non_append_mutation = False
    self._self_external_modification = False
    self.__wrapped__.update(
        {key: self._track_value(
            value, name=self._name_element(key))
         for key, value in self.__wrapped__.items()})
    self._update_snapshot()

  def __getattribute__(self, name):
    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super(_DictWrapper, self).__getattribute__(name)

  def copy(self):
    return copy.copy(self)

  # pylint: disable=protected-access
  def __copy__(self):
    copied = super(_DictWrapper, self).__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    copied._non_string_key = self._non_string_key
    copied = _DictWrapper(copy.copy(self.__wrapped__))
    copied._self_non_append_mutation = self._self_non_append_mutation
    copied._self_external_modification = self._self_external_modification
    copied._self_non_string_key = self._self_non_string_key
    return copied

  def __deepcopy__(self, memo):
    copied = super(_DictWrapper, self).__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    copied._non_string_key = self._non_string_key
    copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
    copied._self_non_append_mutation = self._self_non_append_mutation
    copied._self_external_modification = self._self_external_modification
    copied._self_non_string_key = self._self_non_string_key
    return copied
  # pylint: enable=protected-access

  def _make_storage(self, wrapped_dict):
    """Re-use the wrapped dict for storage (to force them to be in sync)."""
    return wrapped_dict
  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    # Sort items deterministically by key
    ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
    if ordered:
      return ordered[1]
    return []

  @property
  def _checkpoint_dependencies(self):
    """Check that the object is saveable before listing its dependencies."""
    self._check_external_modification()
    if self._non_string_key:
    self._check_self_external_modification()
    if self._self_non_string_key:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary "
@@ -702,7 +725,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
          "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
          "object; it will be automatically un-wrapped and subsequently "
          "ignored." % (self,))
    if self._non_append_mutation:
    if self._self_non_append_mutation:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). A key mapping to a "

@@ -711,7 +734,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
          "un-wrapped and subsequently ignored." % (self,))
    if self._external_modification:
    if self._self_external_modification:
      raise ValueError(
          "Unable to save the object %s (a dictionary wrapper constructed "
          "automatically on attribute assignment). The wrapped dictionary was "

@@ -721,30 +744,30 @@ class _DictWrapper(Mapping, collections.MutableMapping):
          "dictionary checkpointed, wrap it in a "
          "tf.contrib.checkpoint.NoDependency object; it will be automatically "
          "un-wrapped and subsequently ignored." % (
              self, self, self._last_wrapped_dict_snapshot))
              self, self, self._self_last_wrapped_dict_snapshot))
    assert not self._dirty  # Any reason for dirtiness should have an exception.
    return super(_DictWrapper, self)._checkpoint_dependencies

  @property
  def _dirty(self):
    """Check if there has already been a mutation which prevents saving."""
    return (self._external_modification
            or self._non_append_mutation
            or self._non_string_key)
    return (self._self_external_modification
            or self._self_non_append_mutation
            or self._self_non_string_key)

  def _check_external_modification(self):
  def _check_self_external_modification(self):
    """Checks for any changes to the wrapped dict not through the wrapper."""
    if self._dirty:
      return
    if self != self._last_wrapped_dict_snapshot:
      self._external_modification = True
      self._last_wrapped_dict_snapshot = None
    if self != self._self_last_wrapped_dict_snapshot:
      self._self_external_modification = True
      self._self_last_wrapped_dict_snapshot = None

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped dict."""
    if self._dirty:
      return
    self._last_wrapped_dict_snapshot = dict(self)
    self._self_last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects."""

@@ -759,7 +782,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
        if not (string_key or no_dependency):
          # A non-string key maps to a trackable value. This data structure
          # is not saveable.
          self._non_string_key = True
          self._self_non_string_key = True
        return value
      except ValueError:
        # Even if this value isn't trackable, we need to make sure

@@ -768,15 +791,13 @@ class _DictWrapper(Mapping, collections.MutableMapping):
            trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Don't throw errors for non-string keys."""
    if isinstance(key, six.string_types):
      return super(_DictWrapper, self)._name_element(key)
    else:
    """Tells TrackableDataStructure to use keys as names as-is."""
    return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable."""
    self._check_external_modification()
    self._check_self_external_modification()
    self._maybe_initialize_trackable()
    no_dep = isinstance(value, NoDependency)
    if isinstance(key, six.string_types):
      existing_dependency = self._lookup_dependency(key)

@@ -788,9 +809,9 @@ class _DictWrapper(Mapping, collections.MutableMapping):
      # Non-string keys are OK as long as we have no reason to add a
      # dependency on the value (either because the value is not
      # trackable, or because it was wrapped in a NoDependency object).
      self._non_string_key = True
    if key in self._storage:
      previous_value = self._storage[key]
      self._self_non_string_key = True
    if key in self.__wrapped__:
      previous_value = self.__wrapped__[key]
      if previous_value is not value:
        if ((not no_dep and isinstance(value, base.Trackable))
            # We don't want to just check that the existing object is

@@ -800,32 +821,34 @@ class _DictWrapper(Mapping, collections.MutableMapping):
          # A trackable object was replaced under the same key; this means
          # that restoring would be error-prone, so we'll throw an exception on
          # save.
          self._non_append_mutation = True
    self._storage[key] = value
          self._self_non_append_mutation = True
    self.__wrapped__[key] = value

    self._update_snapshot()

  def __delitem__(self, key):
    self._check_external_modification()
    self._check_self_external_modification()
    existing_value = self[key]
    if isinstance(existing_value, base.Trackable):
      # Deleting tracked trackable values means restoring is problematic,
      # so we'll throw an exception on save.
      self._non_append_mutation = True
    del self._storage[key]
      self._self_non_append_mutation = True
    del self.__wrapped__[key]
    self._update_snapshot()

  def __repr__(self):
    return "DictWrapper(%s)" % (repr(self._storage),)
    return "DictWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    return self._storage == getattr(other, "_storage", other)
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def update(self, *args, **kwargs):
    for key, value in dict(*args, **kwargs).items():
    for key, value in six.iteritems(dict(*args, **kwargs)):
      self[key] = value

  def _list_functions_for_serialization(self):

@@ -838,6 +861,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
    def _is_function(x):
      return isinstance(x, (def_function.Function, defun.ConcreteFunction))


revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
@@ -17,12 +17,14 @@ from __future__ import division
from __future__ import print_function

import copy
import json
import os

import numpy
import six

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util

@@ -37,6 +39,8 @@ from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util import serialization


class HasList(training.Model):

@@ -106,6 +110,11 @@ class ListTests(test.TestCase):
      self.assertIn(v, model.trainable_variables)
      self.assertNotIn(v, model.non_trainable_variables)

  def testJSONSerialization(self):
    obj = tracking.AutoTrackable()
    obj.l = [1]
    json.dumps(obj.l, default=serialization.get_json_type)

  @test_util.run_v1_only("b/120545219")
  def testUpdatesForwarded(self):
    with context.graph_mode():

@@ -284,6 +293,20 @@ class ListWrapperTest(test.TestCase):
    if not_overridden:
      self.fail("_ListWrapper does not override %s" % (not_overridden))

  def testSameStructure(self):
    l = [1]
    nest.assert_same_structure(l, data_structures._ListWrapper(copy.copy(l)))

  def testFunctionCaching(self):
    @def_function.function
    def f(list_input):
      return list_input[0] + constant_op.constant(1.)

    first_trace = f.get_concrete_function([constant_op.constant(2.)])
    second_trace = f.get_concrete_function(
        data_structures._ListWrapper([constant_op.constant(3.)]))
    self.assertIs(first_trace, second_trace)

  def testListWrapperBasic(self):
    # _ListWrapper, unlike List, compares like the built-in list type (since it
    # is used to automatically replace lists).

@@ -466,6 +489,11 @@ class MappingTests(test.TestCase):
    self.assertAllEqual(numpy.ones([6, 7]),
                        self.evaluate(test_var))

  def testJSONSerialization(self):
    obj = tracking.AutoTrackable()
    obj.d = {"a": 2}
    json.dumps(obj.d, default=serialization.get_json_type)

  def testNoOverwrite(self):
    mapping = data_structures.Mapping()
    original = data_structures.List()

@@ -581,7 +609,7 @@ class MappingTests(test.TestCase):
    model.d["a"] = []
    model.d.pop("a")
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
    with self.assertRaisesRegexp(ValueError, "Unable to save"):
      model.save_weights(save_path)

  def testExternalModificationNoSave(self):

@@ -709,7 +737,9 @@ class MappingTests(test.TestCase):
    original_sub = tracking.AutoTrackable()
    original.a = [[1.]]
    original.b = {"a": original_sub}
    self.assertIsInstance(original.b, dict)
    deep_copied = copy.deepcopy(original)
    self.assertIsInstance(deep_copied.b, dict)
    self.assertIsNot(original, deep_copied)
    self.assertIsNot(original_sub, deep_copied.b["a"])
    self.assertEqual([[1.]], deep_copied.a)

@@ -736,6 +766,20 @@ class MappingTests(test.TestCase):
        [1.]
        + data_structures._ListWrapper([2.]))

  def testSameStructure(self):
    d = {1: "a"}
    nest.assert_same_structure(d, data_structures._DictWrapper(d.copy()))

  def testFunctionCaching(self):
    @def_function.function
    def f(dict_input):
      return dict_input["x"] + constant_op.constant(1.)

    first_trace = f.get_concrete_function({"x": constant_op.constant(2.)})
    second_trace = f.get_concrete_function(
        data_structures._DictWrapper({"x": constant_op.constant(3.)}))
    self.assertIs(first_trace, second_trace)


if __name__ == "__main__":
  test.main()
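The tests above pin down the user-visible behavior this commit is after: attribute assignment still wraps plain dicts for checkpoint tracking, but the wrapper now also satisfies isinstance(obj, dict) and traces like a plain dict. A hedged end-to-end sketch, separate from the diff (it assumes a build of this branch where tf.Module and eager execution are available; Holder is an invented example class):

    import tensorflow as tf

    class Holder(tf.Module):
      def __init__(self):
        super(Holder, self).__init__()
        # The assigned dict is transparently wrapped so the variable inside it
        # is tracked for checkpointing and restore-on-create.
        self.tables = {"bias": tf.Variable(1.0)}

    h = Holder()
    assert isinstance(h.tables, dict)                          # passes after the wrapt change
    assert any(v is h.tables["bias"] for v in h.variables)     # the dict's variable is tracked
    tf.train.Checkpoint(holder=h).save("/tmp/holder_ckpt")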
@@ -71,7 +71,7 @@ class AutoTrackable(base.Trackable):

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax."""
    if getattr(self, "_setattr_tracking", True):
    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)
@@ -86,28 +86,6 @@ bool IsString(PyObject* o) {
         PyUnicode_Check(o);
}

// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
// and while we're at it give them consistent behavior by making sure the
// returned value is a list.
//
// As with PyMapping_Keys, returns a new reference.
//
// On failure, returns nullptr.
PyObject* MappingKeys(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
  return PyMapping_Keys(o);
#else
  static char key_method_name[] = "keys";
  Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
  if (PyErr_Occurred() || raw_result.get() == nullptr) {
    return nullptr;
  }
  return PySequence_Fast(
      raw_result.get(),
      "The '.keys()' method of a custom mapping returned a non-sequence.");
#endif
}

// Equivalent to Python's 'o.__class__.__name__'
// Note that '__class__' attribute is set only in new-style classes.
// A lot of tensorflow code uses __class__ without checks, so it seems like

@@ -792,6 +770,28 @@ bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }

// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
// and while we're at it give them consistent behavior by making sure the
// returned value is a list.
//
// As with PyMapping_Keys, returns a new reference.
//
// On failure, returns nullptr.
PyObject* MappingKeys(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
  return PyMapping_Keys(o);
#else
  static char key_method_name[] = "keys";
  Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
  if (PyErr_Occurred() || raw_result.get() == nullptr) {
    return nullptr;
  }
  return PySequence_Fast(
      raw_result.get(),
      "The '.keys()' method of a custom mapping returned a non-sequence.");
#endif
}

PyObject* Flatten(PyObject* nested, bool expand_composites) {
  PyObject* list = PyList_New(0);
  const std::function<int(PyObject*)>& is_sequence_helper =
@@ -80,6 +80,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);

// A version of PyMapping_Keys that works in C++11
//
// Args:
//   o: The input to extract keys from
//
// Returns:
//   A new reference to a list of keys in the mapping.
PyObject* MappingKeys(PyObject* o);

// Returns a true if its input is an instance of an attr.s decorated class.
//
// Args:
@@ -62,6 +62,7 @@ REQUIRED_PACKAGES = [
    'tensorboard >= 1.13.0, < 1.14.0',
    'tensorflow_estimator >= 1.13.0rc0, < 1.14.0rc0',
    'termcolor >= 1.1.0',
    'wrapt >= 1.11.1',
]

if sys.byteorder == 'little':
@@ -893,6 +893,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
        build_file = clean_dep("//third_party:pybind11.BUILD"),
    )

    tf_http_archive(
        name = "wrapt",
        build_file = clean_dep("//third_party:wrapt.BUILD"),
        sha256 = "8a6fb40e8f8b6a66b4ba81a4044c68e6a7b1782f21cfabc06fb765332b4c3e51",
        strip_prefix = "wrapt-1.11.1/src/wrapt",
        urls = [
            "http://mirror.tensorflow.org/github.com/GrahamDumpleton/wrapt/archive/1.11.1.tar.gz",
            "https://github.com/GrahamDumpleton/wrapt/archive/1.11.1.tar.gz",
        ],
    )

    ##############################################################################
    # BIND DEFINITIONS
    #
third_party/__init__.py (new vendored file, 0 lines)
third_party/wrapt.BUILD (new vendored file, 10 lines)
@@ -0,0 +1,10 @@
py_library(
    name = "wrapt",
    srcs = [
        "__init__.py",
        "decorators.py",
        "importer.py",
        "wrappers.py",
    ],
    visibility = ["//visibility:public"],
)