diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 653ca525dce..758cba74877 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a Python sequence, tuple (including `namedtuple`), or dict that can contain further sequences, tuples, and dicts. +attr.s decorated classes (http://www.attrs.org) are also supported, in the +same way as `namedtuple`. + The utilities here assume (and do not check) that the nested structures form a 'tree', i.e., no references in the structure of the input of these functions should be recursive. @@ -38,6 +41,12 @@ import six as _six from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow +def _get_attrs_values(obj): + """Returns the list of values from an attrs instance.""" + attrs = getattr(obj.__class__, "__attrs_attrs__") + return [getattr(obj, a.name) for a in attrs] + + def _sorted(dict_): """Returns a sorted list of the dict keys, with error if keys not sortable.""" try: @@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False): # See the swig file (util.i) for documentation. _is_mapping = _pywrap_tensorflow.IsMapping +_is_attrs = _pywrap_tensorflow.IsAttrs def _sequence_like(instance, args): @@ -85,7 +95,7 @@ def _sequence_like(instance, args): # corresponding `OrderedDict` to pack it back). result = dict(zip(_sorted(instance), args)) return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) - elif _is_namedtuple(instance): + elif _is_namedtuple(instance) or _is_attrs(instance): return type(instance)(*args) else: # Not a namedtuple @@ -93,6 +103,7 @@ def _sequence_like(instance, args): def _yield_value(iterable): + """Yields the next value from the given iterable.""" if _is_mapping(iterable): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` @@ -101,6 +112,9 @@ def _yield_value(iterable): # corresponding `OrderedDict` to pack it back). for key in _sorted(iterable): yield iterable[key] + elif _is_attrs(iterable): + for value in _get_attrs_values(iterable): + yield value else: for value in iterable: yield value diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index bfb4c6f910f..e03a8daaa19 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.util import nest +try: + import attr # pylint:disable=g-import-not-at-top +except ImportError: + attr = None + class _CustomMapping(collections.Mapping): @@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase): PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name + if attr: + class BadAttr(object): + """Class that has a non-iterable __attrs_attrs__.""" + __attrs_attrs__ = None + + @attr.s + class SampleAttr(object): + field1 = attr.ib() + field2 = attr.ib() + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testAttrsFlattenAndPack(self): + if attr is None: + self.skipTest("attr module is unavailable.") + + field_values = [1, 2] + sample_attr = NestTest.SampleAttr(*field_values) + self.assertFalse(nest._is_attrs(field_values)) + self.assertTrue(nest._is_attrs(sample_attr)) + flat = nest.flatten(sample_attr) + self.assertEqual(field_values, flat) + restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) + self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) + self.assertEqual(restructured_from_flat, sample_attr) + + # Check that flatten fails if attributes are not iterable + with self.assertRaisesRegexp(TypeError, "object is not iterable"): + flat = nest.flatten(NestTest.BadAttr()) + @test_util.assert_no_new_pyobjects_executing_eagerly def testFlattenAndPack(self): structure = ((3, 4), 5, (6, 7, (9, 10), 8)) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 2087957b311..38b8491c664 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -192,6 +192,19 @@ int IsMappingHelper(PyObject* o) { return check_cache->CachedLookup(o); } +// Returns 1 if `o` is an instance of attrs-decorated class. +// Returns 0 otherwise. +int IsAttrsHelper(PyObject* o) { + Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__")); + if (cls) { + return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); + } else { + // PyObject_GetAttrString returns null on error + PyErr_Clear(); + return 0; + } +} + // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. @@ -206,6 +219,7 @@ int IsSequenceHelper(PyObject* o) { }); // We treat dicts and other mappings as special cases of sequences. if (IsMappingHelper(o)) return true; + if (IsAttrsHelper(o)) return true; if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { LOG(WARNING) << "Sets are not currently considered sequences, " "but this may change in the future, " @@ -354,6 +368,38 @@ class SparseTensorValueIterator : public ValueIterator { Safe_PyObjectPtr tensor_; }; +class AttrsValueIterator : public ValueIterator { + public: + explicit AttrsValueIterator(PyObject* nested) : nested_(nested) { + Py_INCREF(nested); + cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__")); + if (cls_) { + attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__")); + if (attrs_) { + iter_.reset(PyObject_GetIter(attrs_.get())); + } + } + if (!iter_ || PyErr_Occurred()) invalidate(); + } + + Safe_PyObjectPtr next() override { + Safe_PyObjectPtr result; + Safe_PyObjectPtr item(PyIter_Next(iter_.get())); + if (item) { + Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name")); + result.reset(PyObject_GetAttr(nested_.get(), name.get())); + } + + return result; + } + + private: + Safe_PyObjectPtr nested_; + Safe_PyObjectPtr cls_; + Safe_PyObjectPtr attrs_; + Safe_PyObjectPtr iter_; +}; + bool IsSparseTensorValueType(PyObject* o) { if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { return false; @@ -372,6 +418,8 @@ ValueIteratorPtr GetValueIterator(PyObject* nested) { return absl::make_unique(nested); } else if (IsMappingHelper(nested)) { return absl::make_unique(nested); + } else if (IsAttrsHelper(nested)) { + return absl::make_unique(nested); } else { return absl::make_unique(nested); } @@ -383,6 +431,8 @@ ValueIteratorPtr GetValueIteratorForData(PyObject* nested) { return absl::make_unique(nested); } else if (IsMappingHelper(nested)) { return absl::make_unique(nested); + } else if (IsAttrsHelper(nested)) { + return absl::make_unique(nested); } else if (IsSparseTensorValueType(nested)) { return absl::make_unique(nested); } else { @@ -639,6 +689,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } +bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 343605285ea..01f85ea1dc9 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict); // True if the sequence subclasses mapping. bool IsMapping(PyObject* o); +// Returns a true if its input is an instance of an attr.s decorated class. +// +// Args: +// o: the input to be checked. +// +// Returns: +// True if the object is an instance of an attr.s decorated class. +bool IsAttrs(PyObject* o); + // Implements the same interface as tensorflow.util.nest._same_namedtuples // Returns Py_True iff the two namedtuples have the same name and fields. // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 104a6156367..32a6e684fa9 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -65,6 +65,18 @@ Returns: %unignore tensorflow::swig::IsMapping; %noexception tensorflow::swig::IsMapping; +%feature("docstring") tensorflow::swig::IsAttrs +"""Returns True iff `instance` is an instance of an `attr.s` decorated class. + +Args: + instance: An instance of a Python object. + +Returns: + True if `instance` is an instance of an `attr.s` decorated class. +""" +%unignore tensorflow::swig::IsAttrs; +%noexception tensorflow::swig::IsAttrs; + %feature("docstring") tensorflow::swig::SameNamedtuples "Returns True if the two namedtuples have the same name and fields." %unignore tensorflow::swig::SameNamedtuples;