Switch to wrapt for trackable dict data structures instead of subclassing collections.Mapping

Adds a dependency on wrapt to TensorFlow's pip package. It's not a very large dependency, and this is a usability win when subclassing TensorFlow types (Model, Module, Layer, etc.).

wrapt fixes issues with CPython's direct access to dictionaries by inheriting the memory layout of the wrapped object, so we can now pass isinstance(obj, dict) checks while staying correct (i.e. {}.update(obj) won't look at the possibly-stale wrapper object's memory). There are several new oddities with the wrapt approach, but overall this is much better.

We're already doing dict wrapping, just not in a way that passes isinstance(obj, dict) checks. We need it to support restore-on-create for variables added to objects while executing eagerly.

I tried switching _ListWrapper to wrapt too. It works in most Python versions (with some tweaks to TF to accommodate the change), but apparently in Python 3.4 isinstance(obj, (list, tuple)) is not functioning properly. There are no correctness issues with actually subclassing list, and it means type(obj) shows up as a subclass of list, so this isn't too bad. Since ObjectProxy copies the memory layout of the wrapped object we can't do both of these at the same time.

PiperOrigin-RevId: 243118363
This commit is contained in:
Allen Lavoie 2019-04-11 12:23:41 -07:00 committed by TensorFlower Gardener
parent f8c7522bb4
commit 48cb1ae640
17 changed files with 274 additions and 137 deletions

View File

@ -108,8 +108,8 @@ class NumpyState(base.Trackable):
except AttributeError:
value = _NumpyWrapper(value)
self._track_trackable(value, name=name, overwrite=True)
elif (name not in ("_setattr_tracking", "_update_uid")
and getattr(self, "_setattr_tracking", True)):
elif (name not in ("_self_setattr_tracking", "_self_update_uid")
and getattr(self, "_self_setattr_tracking", True)):
# Mixing restore()-created attributes with user-added trackable
# objects is tricky, since we can't use the `_lookup_dependency` trick to
# re-create attributes (we might accidentally steal the restoration for
@ -154,4 +154,3 @@ class _NumpyWrapper(core_python_state.PythonState):
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()

View File

@ -216,6 +216,7 @@ tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
@ -247,6 +248,7 @@ tensorflow/third_party/codegen.BUILD
tensorflow/third_party/cub.BUILD
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/__init__.py
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/eigen.BUILD

View File

@ -3070,8 +3070,8 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
} else if (PyTuple_Check(arg)) {
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
} else if (PyDict_Check(arg)) {
tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
} else if (tensorflow::swig::IsMapping(arg)) {
tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
if (PyList_Sort(keys.get()) == -1) {
return tensorflow::errors::Internal("Unable to sort keys");
}
@ -3083,9 +3083,9 @@ tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
PyObject* key = PyList_GetItem(keys.get(), i);
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
PyObject* value = PyDict_GetItem(arg, key);
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(value, include_tensor_ranks_only, result));
tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
value.get(), include_tensor_ranks_only, result));
}
} else {
PyObject* object = PyWeakref_NewRef(arg, nullptr);

View File

@ -1921,8 +1921,8 @@ class Layer(module.Module):
[w for w in self._non_trainable_weights if w is not existing_value])
def __setattr__(self, name, value):
if (name == '_setattr_tracking' or
not getattr(self, '_setattr_tracking', True) or
if (name == '_self_setattr_tracking' or
not getattr(self, '_self_setattr_tracking', True) or
getattr(self, '_is_graph_network', False) or
# Exclude @property.setters from tracking
hasattr(self.__class__, name)):
@ -2279,3 +2279,4 @@ def default(method):
# Avoid breaking users who directly import this symbol from this file.
# TODO(fchollet): remove this.
InputSpec = input_spec.InputSpec # pylint:disable=invalid-name

View File

@ -405,7 +405,7 @@ class Network(base_layer.Layer):
layer, name='layer-%d' % layer_index, overwrite=True)
def __setattr__(self, name, value):
if not getattr(self, '_setattr_tracking', True):
if not getattr(self, '_self_setattr_tracking', True):
super(Network, self).__setattr__(name, value)
return

View File

@ -99,8 +99,8 @@ class Module(tracking.AutoTrackable):
# include when flattening (these reference dependencies reachable via other
# object attributes).
_TF_MODULE_IGNORED_PROPERTIES = frozenset((
"_unconditional_checkpoint_dependencies",
"_unconditional_dependency_names"
"_self_unconditional_checkpoint_dependencies",
"_self_unconditional_dependency_names"
))
def __init__(self, name=None):

View File

@ -76,6 +76,7 @@ py_library(
":base",
":layer_utils",
"//tensorflow/python/saved_model:revived_types",
"@wrapt",
],
)

View File

@ -452,12 +452,12 @@ def no_automatic_dependency_tracking(method):
"""
def _method_wrapper(self, *args, **kwargs):
previous_value = getattr(self, "_setattr_tracking", True)
self._setattr_tracking = False # pylint: disable=protected-access
previous_value = getattr(self, "_self_setattr_tracking", True)
self._self_setattr_tracking = False # pylint: disable=protected-access
try:
result = method(self, *args, **kwargs)
finally:
self._setattr_tracking = previous_value # pylint: disable=protected-access
self._self_setattr_tracking = previous_value # pylint: disable=protected-access
return result
return tf_decorator.make_decorator(
@ -473,6 +473,40 @@ class Trackable(object):
checks.
"""
# For compatibility with wrapt.ObjectProxy, attributes are all prefixed with
# _self_. We have some properties to forward semi-public attributes to their
# _self_ equivalents.
@property
def _setattr_tracking(self):
if not hasattr(self, "_self_setattr_tracking"):
self._self_setattr_tracking = True
return self._self_setattr_tracking
@_setattr_tracking.setter
def _setattr_tracking(self, value):
self._self_setattr_tracking = value
@property
def _update_uid(self):
return self._self_update_uid
@_update_uid.setter
def _update_uid(self, value):
self._self_update_uid = value
@property
def _unconditional_checkpoint_dependencies(self):
return self._self_unconditional_checkpoint_dependencies
@property
def _unconditional_dependency_names(self):
return self._self_unconditional_dependency_names
@property
def _name_based_restores(self):
return self._self_name_based_restores
# Trackable does not do automatic dependency tracking, but uses the
# no_automatic_dependency_tracking decorator so it can avoid adding
# dependencies if a subclass is Trackable / inherits from Model (both of
@ -483,7 +517,7 @@ class Trackable(object):
Not __init__, since most objects will forget to call it.
"""
if hasattr(self, "_unconditional_checkpoint_dependencies"):
if hasattr(self, "_self_unconditional_checkpoint_dependencies"):
# __init__ already called. This check means that we don't need
# Trackable.__init__() in the constructor of every TensorFlow object.
return
@ -491,21 +525,21 @@ class Trackable(object):
# `Trackable`, notably `Optimizer`s, may override the
# _checkpoint_dependencies property with conditional dependencies
# (e.g. based on the current graph when saving).
self._unconditional_checkpoint_dependencies = []
self._self_unconditional_checkpoint_dependencies = []
# Maps names -> Trackable objects
self._unconditional_dependency_names = {}
self._self_unconditional_dependency_names = {}
# Restorations for other Trackable objects on which this object may
# eventually depend. Maps local name -> CheckpointPosition list. Optimizers
# tack on conditional dependencies, and so need separate management of
# deferred dependencies too.
self._unconditional_deferred_dependencies = {}
self._self_unconditional_deferred_dependencies = {}
# The UID of the highest assignment to this object. Used to ensure that the
# last requested assignment determines the final value of an object.
if hasattr(self, "_update_uid"):
if hasattr(self, "_self_update_uid"):
raise AssertionError(
"Internal error: the object had an update UID set before its "
"initialization code was run.")
self._update_uid = -1
self._self_update_uid = -1
# When executing eagerly, holds a collection of _NameBasedRestoreCoordinator
# instances, which should be checked when creating variables or other
# saveables. These are passed on recursively to all dependencies, since
@ -513,7 +547,7 @@ class Trackable(object):
# being restored in advance. This mechanism is only necessary for
# restore-on-create when executing eagerly, and so is unused when graph
# building.
self._name_based_restores = set()
self._self_name_based_restores = set()
def _no_dependency(self, value):
"""If automatic dependency tracking is enabled, ignores `value`."""
@ -521,10 +555,10 @@ class Trackable(object):
def _name_based_attribute_restore(self, checkpoint):
"""Restore the object's attributes from a name-based checkpoint."""
self._name_based_restores.add(checkpoint)
if self._update_uid < checkpoint.restore_uid:
self._self_name_based_restores.add(checkpoint)
if self._self_update_uid < checkpoint.restore_uid:
checkpoint.eager_restore(self)
self._update_uid = checkpoint.restore_uid
self._self_update_uid = checkpoint.restore_uid
@property
def _checkpoint_dependencies(self):
@ -537,7 +571,7 @@ class Trackable(object):
`Trackable` dependencies which should be saved along with this
object.
"""
return self._unconditional_checkpoint_dependencies
return self._self_unconditional_checkpoint_dependencies
@property
def _deferred_dependencies(self):
@ -552,7 +586,7 @@ class Trackable(object):
A dictionary mapping from local name to a list of CheckpointPosition
objects.
"""
return self._unconditional_deferred_dependencies
return self._self_unconditional_deferred_dependencies
def _lookup_dependency(self, name):
"""Look up a dependency by name.
@ -565,7 +599,7 @@ class Trackable(object):
A `Trackable` object, or `None` if no dependency by this name was
found.
"""
return self._unconditional_dependency_names.get(name, None)
return self._self_unconditional_dependency_names.get(name, None)
def _add_variable_with_custom_getter(
self, name, shape=None, dtype=dtypes.float32,
@ -715,14 +749,15 @@ class Trackable(object):
# This is a weird thing to do, but we're not going to stop people from
# using __setattr__.
for index, (old_name, _) in enumerate(
self._unconditional_checkpoint_dependencies):
self._self_unconditional_checkpoint_dependencies):
if name == old_name:
self._unconditional_checkpoint_dependencies[index] = new_reference
self._self_unconditional_checkpoint_dependencies[
index] = new_reference
elif current_object is None:
self._unconditional_checkpoint_dependencies.append(new_reference)
self._self_unconditional_checkpoint_dependencies.append(new_reference)
self._handle_deferred_dependencies(
name=name, trackable=trackable)
self._unconditional_dependency_names[name] = trackable
self._self_unconditional_dependency_names[name] = trackable
return trackable
def _handle_deferred_dependencies(self, name, trackable):
@ -759,7 +794,7 @@ class Trackable(object):
# Pass on any name-based restores queued in this object.
for name_based_restore in sorted(
self._name_based_restores,
self._self_name_based_restores,
key=lambda checkpoint: checkpoint.restore_uid,
reverse=True):
trackable._name_based_attribute_restore(name_based_restore) # pylint: disable=protected-access
@ -789,9 +824,9 @@ class Trackable(object):
# If the UID of this restore is lower than our current update UID, we don't
# need to actually restore the object. However, we should pass the
# restoration on to our dependencies.
if checkpoint.restore_uid > self._update_uid:
if checkpoint.restore_uid > self._self_update_uid:
restore_ops = checkpoint_position.restore_ops()
self._update_uid = checkpoint.restore_uid
self._self_update_uid = checkpoint.restore_uid
else:
restore_ops = ()
for child in checkpoint_position.object_proto.children:

View File

@ -23,6 +23,11 @@ import operator
import sys
import six
try:
import wrapt
except ImportError:
# Fall back to the build-time dependency if the system package is not available.
from .....third_party import wrapt
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@ -133,15 +138,25 @@ class TrackableDataStructure(base.Trackable):
"""Base class for data structures which contain trackable objects."""
def __init__(self):
self.trainable = True
self._extra_variables = []
# Attributes prefixed with "_self_" for compatibility with
# wrapt.ObjectProxy.
self._self_trainable = True
self._self_extra_variables = []
@property
def trainable(self):
return self._self_trainable
@trainable.setter
def trainable(self, value):
self._self_trainable = value
def _track_value(self, value, name):
"""Add a dependency on `value`."""
value = sticky_attribute_assignment(
trackable=self, value=value, name=name)
if isinstance(value, variables.Variable):
self._extra_variables.append(value)
self._self_extra_variables.append(value)
if not isinstance(value, base.Trackable):
raise _UntrackableError(value)
if hasattr(value, "_use_resource_variables"):
@ -177,14 +192,14 @@ class TrackableDataStructure(base.Trackable):
return layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._extra_variables)
extra_variables=self._self_extra_variables)
@property
def non_trainable_weights(self):
return layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self._layers,
extra_variables=self._extra_variables)
extra_variables=self._self_extra_variables)
@property
def weights(self):
@ -295,6 +310,7 @@ class List(TrackableDataStructure, collections.Sequence):
@property
def _values(self):
"""Collect values for TrackableDataStructure."""
return self
def append(self, value):
@ -353,6 +369,8 @@ class List(TrackableDataStructure, collections.Sequence):
# TODO(tomhennigan) Update to collections.UserList?
# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
# Python 3.4 support (may still be tricky).
class _ListWrapper(List, collections.MutableSequence,
# Shadowed, but there for isinstance checks.
list):
@ -452,7 +470,7 @@ class _ListWrapper(List, collections.MutableSequence,
if isinstance(key, slice):
# Note: this is quite inefficient, but the list API supports a broad range
# of slice setters (e.g. truncate, extend, replace) and immitating this
# of slice setters (e.g. truncate, extend, replace) and imitating this
# for a range of Python versions is non-trivial.
storage_copy = list(self._storage)
self._storage[key] = value
@ -587,6 +605,7 @@ class Mapping(TrackableDataStructure, collections.Mapping):
@property
def _values(self):
"""Collect values for TrackableDataStructure."""
# Sort items deterministically by key
ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
if ordered:
@ -627,19 +646,7 @@ class Mapping(TrackableDataStructure, collections.Mapping):
return iter(self._storage)
# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance
# checks seems infeasible. CPython will not call Python methods/properties on
# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead
# collects elements directly from dict_subclass's C structs. So subclassing dict
# implies that the storage has to be "self" (i.e. the C structs for the object
# must be updated correctly), but we also need that storage to be the wrapped
# dictionary to avoid synchronization bugs (un-tracked external modifications
# should still show up when the dict is accessed through the wrapper). Monkey
# patching all of the "wrapped" dict's methods instead of creating a wrapper
# object is an option, but not a very attractive one (replacing methods without
# creating reference cycles is difficult, and then dicts would need to be
# special cased everywhere as being trackable).
class _DictWrapper(Mapping, collections.MutableMapping):
class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
"""Wraps built-in dicts to support restore-on-create for variables.
_DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
@ -648,52 +655,68 @@ class _DictWrapper(Mapping, collections.MutableMapping):
_DictWrapper will raise an exception on save.
"""
def __new__(cls, *args):
if len(args) == 1 and isinstance(args[0], dict):
return super(_DictWrapper, cls).__new__(cls)
else:
# Allow construction from a sequence, e.g. for nest.pack_sequence_as. In
# this case there's nothing to wrap, so we make a normal dictionary. Also
# allows constructing empty instances of the _DictWrapper type, as Session
# is wont to do (and again there's nothing to wrap, so a normal dictionary
# makes more sense).
return dict(*args)
def __init__(self, wrapped_dict):
self._non_string_key = False
self._non_append_mutation = False
self._external_modification = False
super(_DictWrapper, self).__init__(wrapped_dict)
def __init__(self, wrapped_dict=None):
if wrapped_dict is None:
# Allow zero-argument construction, e.g. from session.run's re-wrapping.
wrapped_dict = {}
if not isinstance(wrapped_dict, collections.Mapping):
# Allow construction from a sequence, e.g. from nest.pack_sequence_as.
wrapped_dict = dict(wrapped_dict)
wrapt.ObjectProxy.__init__(self, wrapped_dict)
TrackableDataStructure.__init__(self)
self._self_non_string_key = False
self._self_non_append_mutation = False
self._self_external_modification = False
self.__wrapped__.update(
{key: self._track_value(
value, name=self._name_element(key))
for key, value in self.__wrapped__.items()})
self._update_snapshot()
def __getattribute__(self, name):
if (hasattr(type(self), name)
and isinstance(getattr(type(self), name), property)):
# Bypass ObjectProxy for properties. Whether this workaround is necessary
# appears to depend on the Python version but not the wrapt version: 3.4
# in particular seems to look up properties on the wrapped object instead
# of the wrapper without this logic.
return object.__getattribute__(self, name)
else:
return super(_DictWrapper, self).__getattribute__(name)
def copy(self):
return copy.copy(self)
# pylint: disable=protected-access
def __copy__(self):
copied = super(_DictWrapper, self).__copy__()
copied._non_append_mutation = self._non_append_mutation
copied._external_modification = self._external_modification
copied._non_string_key = self._non_string_key
copied = _DictWrapper(copy.copy(self.__wrapped__))
copied._self_non_append_mutation = self._self_non_append_mutation
copied._self_external_modification = self._self_external_modification
copied._self_non_string_key = self._self_non_string_key
return copied
def __deepcopy__(self, memo):
copied = super(_DictWrapper, self).__deepcopy__(memo)
copied._non_append_mutation = self._non_append_mutation
copied._external_modification = self._external_modification
copied._non_string_key = self._non_string_key
copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo))
copied._self_non_append_mutation = self._self_non_append_mutation
copied._self_external_modification = self._self_external_modification
copied._self_non_string_key = self._self_non_string_key
return copied
# pylint: enable=protected-access
def _make_storage(self, wrapped_dict):
"""Re-use the wrapped dict for storage (to force them to be in sync)."""
return wrapped_dict
@property
def _values(self):
"""Collect values for TrackableDataStructure."""
# Sort items deterministically by key
ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
if ordered:
return ordered[1]
return []
@property
def _checkpoint_dependencies(self):
"""Check that the object is saveable before listing its dependencies."""
self._check_external_modification()
if self._non_string_key:
self._check_self_external_modification()
if self._self_non_string_key:
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). The wrapped dictionary "
@ -702,7 +725,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
"checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
"object; it will be automatically un-wrapped and subsequently "
"ignored." % (self,))
if self._non_append_mutation:
if self._self_non_append_mutation:
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). A key mapping to a "
@ -711,7 +734,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
"dictionary checkpointed, wrap it in a "
"tf.contrib.checkpoint.NoDependency object; it will be automatically "
"un-wrapped and subsequently ignored." % (self,))
if self._external_modification:
if self._self_external_modification:
raise ValueError(
"Unable to save the object %s (a dictionary wrapper constructed "
"automatically on attribute assignment). The wrapped dictionary was "
@ -721,30 +744,30 @@ class _DictWrapper(Mapping, collections.MutableMapping):
"dictionary checkpointed, wrap it in a "
"tf.contrib.checkpoint.NoDependency object; it will be automatically "
"un-wrapped and subsequently ignored." % (
self, self, self._last_wrapped_dict_snapshot))
self, self, self._self_last_wrapped_dict_snapshot))
assert not self._dirty # Any reason for dirtiness should have an exception.
return super(_DictWrapper, self)._checkpoint_dependencies
@property
def _dirty(self):
"""Check if there has already been a mutation which prevents saving."""
return (self._external_modification
or self._non_append_mutation
or self._non_string_key)
return (self._self_external_modification
or self._self_non_append_mutation
or self._self_non_string_key)
def _check_external_modification(self):
def _check_self_external_modification(self):
"""Checks for any changes to the wrapped dict not through the wrapper."""
if self._dirty:
return
if self != self._last_wrapped_dict_snapshot:
self._external_modification = True
self._last_wrapped_dict_snapshot = None
if self != self._self_last_wrapped_dict_snapshot:
self._self_external_modification = True
self._self_last_wrapped_dict_snapshot = None
def _update_snapshot(self):
"""Acknowledges tracked changes to the wrapped dict."""
if self._dirty:
return
self._last_wrapped_dict_snapshot = dict(self)
self._self_last_wrapped_dict_snapshot = dict(self)
def _track_value(self, value, name):
"""Allows storage of non-trackable objects."""
@ -759,7 +782,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
if not (string_key or no_dependency):
# A non-string key maps to a trackable value. This data structure
# is not saveable.
self._non_string_key = True
self._self_non_string_key = True
return value
except ValueError:
# Even if this value isn't trackable, we need to make sure
@ -768,15 +791,13 @@ class _DictWrapper(Mapping, collections.MutableMapping):
trackable=self, value=value, name=name)
def _name_element(self, key):
"""Don't throw errors for non-string keys."""
if isinstance(key, six.string_types):
return super(_DictWrapper, self)._name_element(key)
else:
return key
"""Tells TrackableDataStructure to use keys as names as-is."""
return key
def __setitem__(self, key, value):
"""Allow any modifications, but possibly mark the wrapper as unsaveable."""
self._check_external_modification()
self._check_self_external_modification()
self._maybe_initialize_trackable()
no_dep = isinstance(value, NoDependency)
if isinstance(key, six.string_types):
existing_dependency = self._lookup_dependency(key)
@ -788,9 +809,9 @@ class _DictWrapper(Mapping, collections.MutableMapping):
# Non-string keys are OK as long as we have no reason to add a
# dependency on the value (either because the value is not
# trackable, or because it was wrapped in a NoDependency object).
self._non_string_key = True
if key in self._storage:
previous_value = self._storage[key]
self._self_non_string_key = True
if key in self.__wrapped__:
previous_value = self.__wrapped__[key]
if previous_value is not value:
if ((not no_dep and isinstance(value, base.Trackable))
# We don't want to just check that the existing object is
@ -800,32 +821,34 @@ class _DictWrapper(Mapping, collections.MutableMapping):
# A trackable object was replaced under the same key; this means
# that restoring would be error-prone, so we'll throw an exception on
# save.
self._non_append_mutation = True
self._storage[key] = value
self._self_non_append_mutation = True
self.__wrapped__[key] = value
self._update_snapshot()
def __delitem__(self, key):
self._check_external_modification()
self._check_self_external_modification()
existing_value = self[key]
if isinstance(existing_value, base.Trackable):
# Deleting tracked trackable values means restoring is problematic,
# so we'll throw an exception on save.
self._non_append_mutation = True
del self._storage[key]
self._self_non_append_mutation = True
del self.__wrapped__[key]
self._update_snapshot()
def __repr__(self):
return "DictWrapper(%s)" % (repr(self._storage),)
return "DictWrapper(%s)" % (repr(self.__wrapped__),)
def __hash__(self):
raise TypeError("unhashable type: 'DictWrapper'")
def __eq__(self, other):
return self._storage == getattr(other, "_storage", other)
# Override the TrackableDataStructure "== -> is" forwarding and go back to
# the wrapt implementation.
return self.__wrapped__ == other
def update(self, *args, **kwargs):
for key, value in dict(*args, **kwargs).items():
for key, value in six.iteritems(dict(*args, **kwargs)):
self[key] = value
def _list_functions_for_serialization(self):
@ -838,6 +861,7 @@ class _DictWrapper(Mapping, collections.MutableMapping):
def _is_function(x):
return isinstance(x, (def_function.Function, defun.ConcreteFunction))
revived_types.register_revived_type(
"trackable_dict_wrapper",
lambda obj: isinstance(obj, _DictWrapper),

View File

@ -17,12 +17,14 @@ from __future__ import division
from __future__ import print_function
import copy
import json
import os
import numpy
import six
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
@ -37,6 +39,8 @@ from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
class HasList(training.Model):
@ -106,6 +110,11 @@ class ListTests(test.TestCase):
self.assertIn(v, model.trainable_variables)
self.assertNotIn(v, model.non_trainable_variables)
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.l = [1]
json.dumps(obj.l, default=serialization.get_json_type)
@test_util.run_v1_only("b/120545219")
def testUpdatesForwarded(self):
with context.graph_mode():
@ -284,6 +293,20 @@ class ListWrapperTest(test.TestCase):
if not_overridden:
self.fail("_ListWrapper does not override %s" % (not_overridden))
def testSameStructure(self):
l = [1]
nest.assert_same_structure(l, data_structures._ListWrapper(copy.copy(l)))
def testFunctionCaching(self):
@def_function.function
def f(list_input):
return list_input[0] + constant_op.constant(1.)
first_trace = f.get_concrete_function([constant_op.constant(2.)])
second_trace = f.get_concrete_function(
data_structures._ListWrapper([constant_op.constant(3.)]))
self.assertIs(first_trace, second_trace)
def testListWrapperBasic(self):
# _ListWrapper, unlike List, compares like the built-in list type (since it
# is used to automatically replace lists).
@ -466,6 +489,11 @@ class MappingTests(test.TestCase):
self.assertAllEqual(numpy.ones([6, 7]),
self.evaluate(test_var))
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.d = {"a": 2}
json.dumps(obj.d, default=serialization.get_json_type)
def testNoOverwrite(self):
mapping = data_structures.Mapping()
original = data_structures.List()
@ -581,7 +609,7 @@ class MappingTests(test.TestCase):
model.d["a"] = []
model.d.pop("a")
save_path = os.path.join(self.get_temp_dir(), "ckpt")
with self.assertRaisesRegexp(ValueError, "overwritten or deleted"):
with self.assertRaisesRegexp(ValueError, "Unable to save"):
model.save_weights(save_path)
def testExternalModificationNoSave(self):
@ -709,7 +737,9 @@ class MappingTests(test.TestCase):
original_sub = tracking.AutoTrackable()
original.a = [[1.]]
original.b = {"a": original_sub}
self.assertIsInstance(original.b, dict)
deep_copied = copy.deepcopy(original)
self.assertIsInstance(deep_copied.b, dict)
self.assertIsNot(original, deep_copied)
self.assertIsNot(original_sub, deep_copied.b["a"])
self.assertEqual([[1.]], deep_copied.a)
@ -736,6 +766,20 @@ class MappingTests(test.TestCase):
[1.]
+ data_structures._ListWrapper([2.]))
def testSameStructure(self):
d = {1: "a"}
nest.assert_same_structure(d, data_structures._DictWrapper(d.copy()))
def testFunctionCaching(self):
@def_function.function
def f(dict_input):
return dict_input["x"] + constant_op.constant(1.)
first_trace = f.get_concrete_function({"x": constant_op.constant(2.)})
second_trace = f.get_concrete_function(
data_structures._DictWrapper({"x": constant_op.constant(3.)}))
self.assertIs(first_trace, second_trace)
if __name__ == "__main__":
test.main()

View File

@ -71,7 +71,7 @@ class AutoTrackable(base.Trackable):
def __setattr__(self, name, value):
"""Support self.foo = trackable syntax."""
if getattr(self, "_setattr_tracking", True):
if getattr(self, "_self_setattr_tracking", True):
value = data_structures.sticky_attribute_assignment(
trackable=self, value=value, name=name)
super(AutoTrackable, self).__setattr__(name, value)

View File

@ -86,28 +86,6 @@ bool IsString(PyObject* o) {
PyUnicode_Check(o);
}
// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
// and while we're at it give them consistent behavior by making sure the
// returned value is a list.
//
// As with PyMapping_Keys, returns a new reference.
//
// On failure, returns nullptr.
PyObject* MappingKeys(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
return PyMapping_Keys(o);
#else
static char key_method_name[] = "keys";
Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
if (PyErr_Occurred() || raw_result.get() == nullptr) {
return nullptr;
}
return PySequence_Fast(
raw_result.get(),
"The '.keys()' method of a custom mapping returned a non-sequence.");
#endif
}
// Equivalent to Python's 'o.__class__.__name__'
// Note that '__class__' attribute is set only in new-style classes.
// A lot of tensorflow code uses __class__ without checks, so it seems like
@ -792,6 +770,28 @@ bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
// and while we're at it give them consistent behavior by making sure the
// returned value is a list.
//
// As with PyMapping_Keys, returns a new reference.
//
// On failure, returns nullptr.
PyObject* MappingKeys(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
return PyMapping_Keys(o);
#else
static char key_method_name[] = "keys";
Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
if (PyErr_Occurred() || raw_result.get() == nullptr) {
return nullptr;
}
return PySequence_Fast(
raw_result.get(),
"The '.keys()' method of a custom mapping returned a non-sequence.");
#endif
}
PyObject* Flatten(PyObject* nested, bool expand_composites) {
PyObject* list = PyList_New(0);
const std::function<int(PyObject*)>& is_sequence_helper =

View File

@ -80,6 +80,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);
// A version of PyMapping_Keys that works in C++11
//
// Args:
// o: The input to extract keys from
//
// Returns:
// A new reference to a list of keys in the mapping.
PyObject* MappingKeys(PyObject* o);
// Returns a true if its input is an instance of an attr.s decorated class.
//
// Args:

View File

@ -62,6 +62,7 @@ REQUIRED_PACKAGES = [
'tensorboard >= 1.13.0, < 1.14.0',
'tensorflow_estimator >= 1.13.0rc0, < 1.14.0rc0',
'termcolor >= 1.1.0',
'wrapt >= 1.11.1',
]
if sys.byteorder == 'little':

View File

@ -893,6 +893,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
build_file = clean_dep("//third_party:pybind11.BUILD"),
)
tf_http_archive(
name = "wrapt",
build_file = clean_dep("//third_party:wrapt.BUILD"),
sha256 = "8a6fb40e8f8b6a66b4ba81a4044c68e6a7b1782f21cfabc06fb765332b4c3e51",
strip_prefix = "wrapt-1.11.1/src/wrapt",
urls = [
"http://mirror.tensorflow.org/github.com/GrahamDumpleton/wrapt/archive/1.11.1.tar.gz",
"https://github.com/GrahamDumpleton/wrapt/archive/1.11.1.tar.gz",
],
)
##############################################################################
# BIND DEFINITIONS
#

0
third_party/__init__.py vendored Normal file
View File

10
third_party/wrapt.BUILD vendored Normal file
View File

@ -0,0 +1,10 @@
py_library(
name = "wrapt",
srcs = [
"__init__.py",
"decorators.py",
"importer.py",
"wrappers.py",
],
visibility = ["//visibility:public"],
)