Added nest support for attr.s decorated classes.
PiperOrigin-RevId: 214781911
This commit is contained in:
parent
77244534c0
commit
cd1bdeafec
@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
|
|||||||
Python sequence, tuple (including `namedtuple`), or dict that can contain
|
Python sequence, tuple (including `namedtuple`), or dict that can contain
|
||||||
further sequences, tuples, and dicts.
|
further sequences, tuples, and dicts.
|
||||||
|
|
||||||
|
attr.s decorated classes (http://www.attrs.org) are also supported, in the
|
||||||
|
same way as `namedtuple`.
|
||||||
|
|
||||||
The utilities here assume (and do not check) that the nested structures form a
|
The utilities here assume (and do not check) that the nested structures form a
|
||||||
'tree', i.e., no references in the structure of the input of these functions
|
'tree', i.e., no references in the structure of the input of these functions
|
||||||
should be recursive.
|
should be recursive.
|
||||||
@ -38,6 +41,12 @@ import six as _six
|
|||||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attrs_values(obj):
|
||||||
|
"""Returns the list of values from an attrs instance."""
|
||||||
|
attrs = getattr(obj.__class__, "__attrs_attrs__")
|
||||||
|
return [getattr(obj, a.name) for a in attrs]
|
||||||
|
|
||||||
|
|
||||||
def _sorted(dict_):
|
def _sorted(dict_):
|
||||||
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
|
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
|
||||||
try:
|
try:
|
||||||
@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
|
|||||||
|
|
||||||
# See the swig file (util.i) for documentation.
|
# See the swig file (util.i) for documentation.
|
||||||
_is_mapping = _pywrap_tensorflow.IsMapping
|
_is_mapping = _pywrap_tensorflow.IsMapping
|
||||||
|
_is_attrs = _pywrap_tensorflow.IsAttrs
|
||||||
|
|
||||||
|
|
||||||
def _sequence_like(instance, args):
|
def _sequence_like(instance, args):
|
||||||
@ -85,7 +95,7 @@ def _sequence_like(instance, args):
|
|||||||
# corresponding `OrderedDict` to pack it back).
|
# corresponding `OrderedDict` to pack it back).
|
||||||
result = dict(zip(_sorted(instance), args))
|
result = dict(zip(_sorted(instance), args))
|
||||||
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
|
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
|
||||||
elif _is_namedtuple(instance):
|
elif _is_namedtuple(instance) or _is_attrs(instance):
|
||||||
return type(instance)(*args)
|
return type(instance)(*args)
|
||||||
else:
|
else:
|
||||||
# Not a namedtuple
|
# Not a namedtuple
|
||||||
@ -93,6 +103,7 @@ def _sequence_like(instance, args):
|
|||||||
|
|
||||||
|
|
||||||
def _yield_value(iterable):
|
def _yield_value(iterable):
|
||||||
|
"""Yields the next value from the given iterable."""
|
||||||
if _is_mapping(iterable):
|
if _is_mapping(iterable):
|
||||||
# Iterate through dictionaries in a deterministic order by sorting the
|
# Iterate through dictionaries in a deterministic order by sorting the
|
||||||
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
||||||
@ -101,6 +112,9 @@ def _yield_value(iterable):
|
|||||||
# corresponding `OrderedDict` to pack it back).
|
# corresponding `OrderedDict` to pack it back).
|
||||||
for key in _sorted(iterable):
|
for key in _sorted(iterable):
|
||||||
yield iterable[key]
|
yield iterable[key]
|
||||||
|
elif _is_attrs(iterable):
|
||||||
|
for value in _get_attrs_values(iterable):
|
||||||
|
yield value
|
||||||
else:
|
else:
|
||||||
for value in iterable:
|
for value in iterable:
|
||||||
yield value
|
yield value
|
||||||
|
@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
try:
|
||||||
|
import attr # pylint:disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
attr = None
|
||||||
|
|
||||||
|
|
||||||
class _CustomMapping(collections.Mapping):
|
class _CustomMapping(collections.Mapping):
|
||||||
|
|
||||||
@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
|
|
||||||
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
|
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
if attr:
|
||||||
|
class BadAttr(object):
|
||||||
|
"""Class that has a non-iterable __attrs_attrs__."""
|
||||||
|
__attrs_attrs__ = None
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class SampleAttr(object):
|
||||||
|
field1 = attr.ib()
|
||||||
|
field2 = attr.ib()
|
||||||
|
|
||||||
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
|
def testAttrsFlattenAndPack(self):
|
||||||
|
if attr is None:
|
||||||
|
self.skipTest("attr module is unavailable.")
|
||||||
|
|
||||||
|
field_values = [1, 2]
|
||||||
|
sample_attr = NestTest.SampleAttr(*field_values)
|
||||||
|
self.assertFalse(nest._is_attrs(field_values))
|
||||||
|
self.assertTrue(nest._is_attrs(sample_attr))
|
||||||
|
flat = nest.flatten(sample_attr)
|
||||||
|
self.assertEqual(field_values, flat)
|
||||||
|
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
|
||||||
|
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
|
||||||
|
self.assertEqual(restructured_from_flat, sample_attr)
|
||||||
|
|
||||||
|
# Check that flatten fails if attributes are not iterable
|
||||||
|
with self.assertRaisesRegexp(TypeError, "object is not iterable"):
|
||||||
|
flat = nest.flatten(NestTest.BadAttr())
|
||||||
|
|
||||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
def testFlattenAndPack(self):
|
def testFlattenAndPack(self):
|
||||||
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
|
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
|
||||||
|
@ -192,6 +192,19 @@ int IsMappingHelper(PyObject* o) {
|
|||||||
return check_cache->CachedLookup(o);
|
return check_cache->CachedLookup(o);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns 1 if `o` is an instance of attrs-decorated class.
|
||||||
|
// Returns 0 otherwise.
|
||||||
|
int IsAttrsHelper(PyObject* o) {
|
||||||
|
Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
|
||||||
|
if (cls) {
|
||||||
|
return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
|
||||||
|
} else {
|
||||||
|
// PyObject_GetAttrString returns null on error
|
||||||
|
PyErr_Clear();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
|
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
|
||||||
// Returns 0 otherwise.
|
// Returns 0 otherwise.
|
||||||
// Returns -1 if an error occurred.
|
// Returns -1 if an error occurred.
|
||||||
@ -206,6 +219,7 @@ int IsSequenceHelper(PyObject* o) {
|
|||||||
});
|
});
|
||||||
// We treat dicts and other mappings as special cases of sequences.
|
// We treat dicts and other mappings as special cases of sequences.
|
||||||
if (IsMappingHelper(o)) return true;
|
if (IsMappingHelper(o)) return true;
|
||||||
|
if (IsAttrsHelper(o)) return true;
|
||||||
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
|
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
|
||||||
LOG(WARNING) << "Sets are not currently considered sequences, "
|
LOG(WARNING) << "Sets are not currently considered sequences, "
|
||||||
"but this may change in the future, "
|
"but this may change in the future, "
|
||||||
@ -354,6 +368,38 @@ class SparseTensorValueIterator : public ValueIterator {
|
|||||||
Safe_PyObjectPtr tensor_;
|
Safe_PyObjectPtr tensor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class AttrsValueIterator : public ValueIterator {
|
||||||
|
public:
|
||||||
|
explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
|
||||||
|
Py_INCREF(nested);
|
||||||
|
cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
|
||||||
|
if (cls_) {
|
||||||
|
attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
|
||||||
|
if (attrs_) {
|
||||||
|
iter_.reset(PyObject_GetIter(attrs_.get()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!iter_ || PyErr_Occurred()) invalidate();
|
||||||
|
}
|
||||||
|
|
||||||
|
Safe_PyObjectPtr next() override {
|
||||||
|
Safe_PyObjectPtr result;
|
||||||
|
Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
|
||||||
|
if (item) {
|
||||||
|
Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
|
||||||
|
result.reset(PyObject_GetAttr(nested_.get(), name.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Safe_PyObjectPtr nested_;
|
||||||
|
Safe_PyObjectPtr cls_;
|
||||||
|
Safe_PyObjectPtr attrs_;
|
||||||
|
Safe_PyObjectPtr iter_;
|
||||||
|
};
|
||||||
|
|
||||||
bool IsSparseTensorValueType(PyObject* o) {
|
bool IsSparseTensorValueType(PyObject* o) {
|
||||||
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
|
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
|
||||||
return false;
|
return false;
|
||||||
@ -372,6 +418,8 @@ ValueIteratorPtr GetValueIterator(PyObject* nested) {
|
|||||||
return absl::make_unique<DictValueIterator>(nested);
|
return absl::make_unique<DictValueIterator>(nested);
|
||||||
} else if (IsMappingHelper(nested)) {
|
} else if (IsMappingHelper(nested)) {
|
||||||
return absl::make_unique<MappingValueIterator>(nested);
|
return absl::make_unique<MappingValueIterator>(nested);
|
||||||
|
} else if (IsAttrsHelper(nested)) {
|
||||||
|
return absl::make_unique<AttrsValueIterator>(nested);
|
||||||
} else {
|
} else {
|
||||||
return absl::make_unique<SequenceValueIterator>(nested);
|
return absl::make_unique<SequenceValueIterator>(nested);
|
||||||
}
|
}
|
||||||
@ -383,6 +431,8 @@ ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
|
|||||||
return absl::make_unique<DictValueIterator>(nested);
|
return absl::make_unique<DictValueIterator>(nested);
|
||||||
} else if (IsMappingHelper(nested)) {
|
} else if (IsMappingHelper(nested)) {
|
||||||
return absl::make_unique<MappingValueIterator>(nested);
|
return absl::make_unique<MappingValueIterator>(nested);
|
||||||
|
} else if (IsAttrsHelper(nested)) {
|
||||||
|
return absl::make_unique<AttrsValueIterator>(nested);
|
||||||
} else if (IsSparseTensorValueType(nested)) {
|
} else if (IsSparseTensorValueType(nested)) {
|
||||||
return absl::make_unique<SparseTensorValueIterator>(nested);
|
return absl::make_unique<SparseTensorValueIterator>(nested);
|
||||||
} else {
|
} else {
|
||||||
@ -639,6 +689,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
|
|||||||
|
|
||||||
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
|
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
|
||||||
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
|
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
|
||||||
|
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
|
||||||
|
|
||||||
PyObject* Flatten(PyObject* nested) {
|
PyObject* Flatten(PyObject* nested) {
|
||||||
PyObject* list = PyList_New(0);
|
PyObject* list = PyList_New(0);
|
||||||
|
@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
|
|||||||
// True if the sequence subclasses mapping.
|
// True if the sequence subclasses mapping.
|
||||||
bool IsMapping(PyObject* o);
|
bool IsMapping(PyObject* o);
|
||||||
|
|
||||||
|
// Returns a true if its input is an instance of an attr.s decorated class.
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// o: the input to be checked.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// True if the object is an instance of an attr.s decorated class.
|
||||||
|
bool IsAttrs(PyObject* o);
|
||||||
|
|
||||||
// Implements the same interface as tensorflow.util.nest._same_namedtuples
|
// Implements the same interface as tensorflow.util.nest._same_namedtuples
|
||||||
// Returns Py_True iff the two namedtuples have the same name and fields.
|
// Returns Py_True iff the two namedtuples have the same name and fields.
|
||||||
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
|
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
|
||||||
|
@ -65,6 +65,18 @@ Returns:
|
|||||||
%unignore tensorflow::swig::IsMapping;
|
%unignore tensorflow::swig::IsMapping;
|
||||||
%noexception tensorflow::swig::IsMapping;
|
%noexception tensorflow::swig::IsMapping;
|
||||||
|
|
||||||
|
%feature("docstring") tensorflow::swig::IsAttrs
|
||||||
|
"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: An instance of a Python object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if `instance` is an instance of an `attr.s` decorated class.
|
||||||
|
"""
|
||||||
|
%unignore tensorflow::swig::IsAttrs;
|
||||||
|
%noexception tensorflow::swig::IsAttrs;
|
||||||
|
|
||||||
%feature("docstring") tensorflow::swig::SameNamedtuples
|
%feature("docstring") tensorflow::swig::SameNamedtuples
|
||||||
"Returns True if the two namedtuples have the same name and fields."
|
"Returns True if the two namedtuples have the same name and fields."
|
||||||
%unignore tensorflow::swig::SameNamedtuples;
|
%unignore tensorflow::swig::SameNamedtuples;
|
||||||
|
Loading…
Reference in New Issue
Block a user