Added nest support for attr.s decorated classes.

PiperOrigin-RevId: 214781911
This commit is contained in:
A. Unique TensorFlower 2018-09-27 08:56:28 -07:00 committed by TensorFlower Gardener
parent 77244534c0
commit cd1bdeafec
5 changed files with 121 additions and 1 deletions

View File

@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts. further sequences, tuples, and dicts.
attr.s decorated classes (http://www.attrs.org) are also supported, in the
same way as `namedtuple`.
The utilities here assume (and do not check) that the nested structures form a The utilities here assume (and do not check) that the nested structures form a
'tree', i.e., no references in the structure of the input of these functions 'tree', i.e., no references in the structure of the input of these functions
should be recursive. should be recursive.
@ -38,6 +41,12 @@ import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
def _get_attrs_values(obj):
"""Returns the list of values from an attrs instance."""
attrs = getattr(obj.__class__, "__attrs_attrs__")
return [getattr(obj, a.name) for a in attrs]
def _sorted(dict_): def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable.""" """Returns a sorted list of the dict keys, with error if keys not sortable."""
try: try:
@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
# See the swig file (util.i) for documentation. # See the swig file (util.i) for documentation.
_is_mapping = _pywrap_tensorflow.IsMapping _is_mapping = _pywrap_tensorflow.IsMapping
_is_attrs = _pywrap_tensorflow.IsAttrs
def _sequence_like(instance, args): def _sequence_like(instance, args):
@ -85,7 +95,7 @@ def _sequence_like(instance, args):
# corresponding `OrderedDict` to pack it back). # corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args)) result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
elif _is_namedtuple(instance): elif _is_namedtuple(instance) or _is_attrs(instance):
return type(instance)(*args) return type(instance)(*args)
else: else:
# Not a namedtuple # Not a namedtuple
@ -93,6 +103,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable): def _yield_value(iterable):
"""Yields the next value from the given iterable."""
if _is_mapping(iterable): if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the # Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict` # keys. Notice this means that we ignore the original order of `OrderedDict`
@ -101,6 +112,9 @@ def _yield_value(iterable):
# corresponding `OrderedDict` to pack it back). # corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable): for key in _sorted(iterable):
yield iterable[key] yield iterable[key]
elif _is_attrs(iterable):
for value in _get_attrs_values(iterable):
yield value
else: else:
for value in iterable: for value in iterable:
yield value yield value

View File

@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.util import nest from tensorflow.python.util import nest
try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None
class _CustomMapping(collections.Mapping): class _CustomMapping(collections.Mapping):
@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
if attr:
class BadAttr(object):
"""Class that has a non-iterable __attrs_attrs__."""
__attrs_attrs__ = None
@attr.s
class SampleAttr(object):
field1 = attr.ib()
field2 = attr.ib()
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsFlattenAndPack(self):
if attr is None:
self.skipTest("attr module is unavailable.")
field_values = [1, 2]
sample_attr = NestTest.SampleAttr(*field_values)
self.assertFalse(nest._is_attrs(field_values))
self.assertTrue(nest._is_attrs(sample_attr))
flat = nest.flatten(sample_attr)
self.assertEqual(field_values, flat)
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
self.assertEqual(restructured_from_flat, sample_attr)
# Check that flatten fails if attributes are not iterable
with self.assertRaisesRegexp(TypeError, "object is not iterable"):
flat = nest.flatten(NestTest.BadAttr())
@test_util.assert_no_new_pyobjects_executing_eagerly @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self): def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8)) structure = ((3, 4), 5, (6, 7, (9, 10), 8))

View File

@ -192,6 +192,19 @@ int IsMappingHelper(PyObject* o) {
return check_cache->CachedLookup(o); return check_cache->CachedLookup(o);
} }
// Returns 1 if `o` is an instance of attrs-decorated class.
// Returns 0 otherwise.
int IsAttrsHelper(PyObject* o) {
Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
if (cls) {
return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
} else {
// PyObject_GetAttrString returns null on error
PyErr_Clear();
return 0;
}
}
// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise. // Returns 0 otherwise.
// Returns -1 if an error occurred. // Returns -1 if an error occurred.
@ -206,6 +219,7 @@ int IsSequenceHelper(PyObject* o) {
}); });
// We treat dicts and other mappings as special cases of sequences. // We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true; if (IsMappingHelper(o)) return true;
if (IsAttrsHelper(o)) return true;
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
LOG(WARNING) << "Sets are not currently considered sequences, " LOG(WARNING) << "Sets are not currently considered sequences, "
"but this may change in the future, " "but this may change in the future, "
@ -354,6 +368,38 @@ class SparseTensorValueIterator : public ValueIterator {
Safe_PyObjectPtr tensor_; Safe_PyObjectPtr tensor_;
}; };
class AttrsValueIterator : public ValueIterator {
public:
explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
Py_INCREF(nested);
cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
if (cls_) {
attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
if (attrs_) {
iter_.reset(PyObject_GetIter(attrs_.get()));
}
}
if (!iter_ || PyErr_Occurred()) invalidate();
}
Safe_PyObjectPtr next() override {
Safe_PyObjectPtr result;
Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
if (item) {
Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
result.reset(PyObject_GetAttr(nested_.get(), name.get()));
}
return result;
}
private:
Safe_PyObjectPtr nested_;
Safe_PyObjectPtr cls_;
Safe_PyObjectPtr attrs_;
Safe_PyObjectPtr iter_;
};
bool IsSparseTensorValueType(PyObject* o) { bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false; return false;
@ -372,6 +418,8 @@ ValueIteratorPtr GetValueIterator(PyObject* nested) {
return absl::make_unique<DictValueIterator>(nested); return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) { } else if (IsMappingHelper(nested)) {
return absl::make_unique<MappingValueIterator>(nested); return absl::make_unique<MappingValueIterator>(nested);
} else if (IsAttrsHelper(nested)) {
return absl::make_unique<AttrsValueIterator>(nested);
} else { } else {
return absl::make_unique<SequenceValueIterator>(nested); return absl::make_unique<SequenceValueIterator>(nested);
} }
@ -383,6 +431,8 @@ ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
return absl::make_unique<DictValueIterator>(nested); return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) { } else if (IsMappingHelper(nested)) {
return absl::make_unique<MappingValueIterator>(nested); return absl::make_unique<MappingValueIterator>(nested);
} else if (IsAttrsHelper(nested)) {
return absl::make_unique<AttrsValueIterator>(nested);
} else if (IsSparseTensorValueType(nested)) { } else if (IsSparseTensorValueType(nested)) {
return absl::make_unique<SparseTensorValueIterator>(nested); return absl::make_unique<SparseTensorValueIterator>(nested);
} else { } else {
@ -639,6 +689,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) { PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0); PyObject* list = PyList_New(0);

View File

@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping. // True if the sequence subclasses mapping.
bool IsMapping(PyObject* o); bool IsMapping(PyObject* o);
// Returns a true if its input is an instance of an attr.s decorated class.
//
// Args:
// o: the input to be checked.
//
// Returns:
// True if the object is an instance of an attr.s decorated class.
bool IsAttrs(PyObject* o);
// Implements the same interface as tensorflow.util.nest._same_namedtuples // Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields. // Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have

View File

@ -65,6 +65,18 @@ Returns:
%unignore tensorflow::swig::IsMapping; %unignore tensorflow::swig::IsMapping;
%noexception tensorflow::swig::IsMapping; %noexception tensorflow::swig::IsMapping;
%feature("docstring") tensorflow::swig::IsAttrs
"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
Args:
instance: An instance of a Python object.
Returns:
True if `instance` is an instance of an `attr.s` decorated class.
"""
%unignore tensorflow::swig::IsAttrs;
%noexception tensorflow::swig::IsAttrs;
%feature("docstring") tensorflow::swig::SameNamedtuples %feature("docstring") tensorflow::swig::SameNamedtuples
"Returns True if the two namedtuples have the same name and fields." "Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples; %unignore tensorflow::swig::SameNamedtuples;