Added nest support for attr.s decorated classes.
PiperOrigin-RevId: 214781911
This commit is contained in:
parent
77244534c0
commit
cd1bdeafec
@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
|
||||
Python sequence, tuple (including `namedtuple`), or dict that can contain
|
||||
further sequences, tuples, and dicts.
|
||||
|
||||
attr.s decorated classes (http://www.attrs.org) are also supported, in the
|
||||
same way as `namedtuple`.
|
||||
|
||||
The utilities here assume (and do not check) that the nested structures form a
|
||||
'tree', i.e., no references in the structure of the input of these functions
|
||||
should be recursive.
|
||||
@ -38,6 +41,12 @@ import six as _six
|
||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||
|
||||
|
||||
def _get_attrs_values(obj):
|
||||
"""Returns the list of values from an attrs instance."""
|
||||
attrs = getattr(obj.__class__, "__attrs_attrs__")
|
||||
return [getattr(obj, a.name) for a in attrs]
|
||||
|
||||
|
||||
def _sorted(dict_):
|
||||
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
|
||||
try:
|
||||
@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
|
||||
|
||||
# See the swig file (util.i) for documentation.
|
||||
_is_mapping = _pywrap_tensorflow.IsMapping
|
||||
_is_attrs = _pywrap_tensorflow.IsAttrs
|
||||
|
||||
|
||||
def _sequence_like(instance, args):
|
||||
@ -85,7 +95,7 @@ def _sequence_like(instance, args):
|
||||
# corresponding `OrderedDict` to pack it back).
|
||||
result = dict(zip(_sorted(instance), args))
|
||||
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
|
||||
elif _is_namedtuple(instance):
|
||||
elif _is_namedtuple(instance) or _is_attrs(instance):
|
||||
return type(instance)(*args)
|
||||
else:
|
||||
# Not a namedtuple
|
||||
@ -93,6 +103,7 @@ def _sequence_like(instance, args):
|
||||
|
||||
|
||||
def _yield_value(iterable):
|
||||
"""Yields the next value from the given iterable."""
|
||||
if _is_mapping(iterable):
|
||||
# Iterate through dictionaries in a deterministic order by sorting the
|
||||
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
||||
@ -101,6 +112,9 @@ def _yield_value(iterable):
|
||||
# corresponding `OrderedDict` to pack it back).
|
||||
for key in _sorted(iterable):
|
||||
yield iterable[key]
|
||||
elif _is_attrs(iterable):
|
||||
for value in _get_attrs_values(iterable):
|
||||
yield value
|
||||
else:
|
||||
for value in iterable:
|
||||
yield value
|
||||
|
@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
try:
|
||||
import attr # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
attr = None
|
||||
|
||||
|
||||
class _CustomMapping(collections.Mapping):
|
||||
|
||||
@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
|
||||
|
||||
if attr:
|
||||
class BadAttr(object):
|
||||
"""Class that has a non-iterable __attrs_attrs__."""
|
||||
__attrs_attrs__ = None
|
||||
|
||||
@attr.s
|
||||
class SampleAttr(object):
|
||||
field1 = attr.ib()
|
||||
field2 = attr.ib()
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testAttrsFlattenAndPack(self):
|
||||
if attr is None:
|
||||
self.skipTest("attr module is unavailable.")
|
||||
|
||||
field_values = [1, 2]
|
||||
sample_attr = NestTest.SampleAttr(*field_values)
|
||||
self.assertFalse(nest._is_attrs(field_values))
|
||||
self.assertTrue(nest._is_attrs(sample_attr))
|
||||
flat = nest.flatten(sample_attr)
|
||||
self.assertEqual(field_values, flat)
|
||||
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
|
||||
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
|
||||
self.assertEqual(restructured_from_flat, sample_attr)
|
||||
|
||||
# Check that flatten fails if attributes are not iterable
|
||||
with self.assertRaisesRegexp(TypeError, "object is not iterable"):
|
||||
flat = nest.flatten(NestTest.BadAttr())
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testFlattenAndPack(self):
|
||||
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
|
||||
|
@ -192,6 +192,19 @@ int IsMappingHelper(PyObject* o) {
|
||||
return check_cache->CachedLookup(o);
|
||||
}
|
||||
|
||||
// Returns 1 if `o` is an instance of attrs-decorated class.
|
||||
// Returns 0 otherwise.
|
||||
int IsAttrsHelper(PyObject* o) {
|
||||
Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
|
||||
if (cls) {
|
||||
return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
|
||||
} else {
|
||||
// PyObject_GetAttrString returns null on error
|
||||
PyErr_Clear();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
|
||||
// Returns 0 otherwise.
|
||||
// Returns -1 if an error occurred.
|
||||
@ -206,6 +219,7 @@ int IsSequenceHelper(PyObject* o) {
|
||||
});
|
||||
// We treat dicts and other mappings as special cases of sequences.
|
||||
if (IsMappingHelper(o)) return true;
|
||||
if (IsAttrsHelper(o)) return true;
|
||||
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
|
||||
LOG(WARNING) << "Sets are not currently considered sequences, "
|
||||
"but this may change in the future, "
|
||||
@ -354,6 +368,38 @@ class SparseTensorValueIterator : public ValueIterator {
|
||||
Safe_PyObjectPtr tensor_;
|
||||
};
|
||||
|
||||
class AttrsValueIterator : public ValueIterator {
|
||||
public:
|
||||
explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
|
||||
Py_INCREF(nested);
|
||||
cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
|
||||
if (cls_) {
|
||||
attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
|
||||
if (attrs_) {
|
||||
iter_.reset(PyObject_GetIter(attrs_.get()));
|
||||
}
|
||||
}
|
||||
if (!iter_ || PyErr_Occurred()) invalidate();
|
||||
}
|
||||
|
||||
Safe_PyObjectPtr next() override {
|
||||
Safe_PyObjectPtr result;
|
||||
Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
|
||||
if (item) {
|
||||
Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
|
||||
result.reset(PyObject_GetAttr(nested_.get(), name.get()));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
Safe_PyObjectPtr nested_;
|
||||
Safe_PyObjectPtr cls_;
|
||||
Safe_PyObjectPtr attrs_;
|
||||
Safe_PyObjectPtr iter_;
|
||||
};
|
||||
|
||||
bool IsSparseTensorValueType(PyObject* o) {
|
||||
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
|
||||
return false;
|
||||
@ -372,6 +418,8 @@ ValueIteratorPtr GetValueIterator(PyObject* nested) {
|
||||
return absl::make_unique<DictValueIterator>(nested);
|
||||
} else if (IsMappingHelper(nested)) {
|
||||
return absl::make_unique<MappingValueIterator>(nested);
|
||||
} else if (IsAttrsHelper(nested)) {
|
||||
return absl::make_unique<AttrsValueIterator>(nested);
|
||||
} else {
|
||||
return absl::make_unique<SequenceValueIterator>(nested);
|
||||
}
|
||||
@ -383,6 +431,8 @@ ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
|
||||
return absl::make_unique<DictValueIterator>(nested);
|
||||
} else if (IsMappingHelper(nested)) {
|
||||
return absl::make_unique<MappingValueIterator>(nested);
|
||||
} else if (IsAttrsHelper(nested)) {
|
||||
return absl::make_unique<AttrsValueIterator>(nested);
|
||||
} else if (IsSparseTensorValueType(nested)) {
|
||||
return absl::make_unique<SparseTensorValueIterator>(nested);
|
||||
} else {
|
||||
@ -639,6 +689,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
|
||||
|
||||
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
|
||||
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
|
||||
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
|
||||
|
||||
PyObject* Flatten(PyObject* nested) {
|
||||
PyObject* list = PyList_New(0);
|
||||
|
@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
|
||||
// True if the sequence subclasses mapping.
|
||||
bool IsMapping(PyObject* o);
|
||||
|
||||
// Returns a true if its input is an instance of an attr.s decorated class.
|
||||
//
|
||||
// Args:
|
||||
// o: the input to be checked.
|
||||
//
|
||||
// Returns:
|
||||
// True if the object is an instance of an attr.s decorated class.
|
||||
bool IsAttrs(PyObject* o);
|
||||
|
||||
// Implements the same interface as tensorflow.util.nest._same_namedtuples
|
||||
// Returns Py_True iff the two namedtuples have the same name and fields.
|
||||
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
|
||||
|
@ -65,6 +65,18 @@ Returns:
|
||||
%unignore tensorflow::swig::IsMapping;
|
||||
%noexception tensorflow::swig::IsMapping;
|
||||
|
||||
%feature("docstring") tensorflow::swig::IsAttrs
|
||||
"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
|
||||
|
||||
Args:
|
||||
instance: An instance of a Python object.
|
||||
|
||||
Returns:
|
||||
True if `instance` is an instance of an `attr.s` decorated class.
|
||||
"""
|
||||
%unignore tensorflow::swig::IsAttrs;
|
||||
%noexception tensorflow::swig::IsAttrs;
|
||||
|
||||
%feature("docstring") tensorflow::swig::SameNamedtuples
|
||||
"Returns True if the two namedtuples have the same name and fields."
|
||||
%unignore tensorflow::swig::SameNamedtuples;
|
||||
|
Loading…
Reference in New Issue
Block a user