Make map_structure_with_atomic work with attrs.

PiperOrigin-RevId: 316037273
Change-Id: I1316a99e657c92a8835772260744175c15d55d33
This commit is contained in:
A. Unique TensorFlower 2020-06-11 21:03:15 -07:00 committed by TensorFlower Gardener
parent b62ad45a87
commit b5e6999ea2
2 changed files with 38 additions and 0 deletions

View File

@ -179,6 +179,8 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
'Received non-atomic and non-sequence element: {}'.format(nested))
if nest._is_mapping(nested):
values = [nested[k] for k in nest._sorted(nested)]
elif nest._is_attrs(nested):
values = _astuple(nested)
else:
values = nested
mapped_values = [
@ -533,3 +535,15 @@ def to_numpy_or_python_type(tensors):
return t # Don't turn ragged or sparse tensors to NumPy.
return nest.map_structure(_to_single_numpy_or_python_type, tensors)
def _astuple(attrs):
"""Converts the given attrs to tuple non-recursively."""
cls = type(attrs)
fields = getattr(cls, '__attrs_attrs__', None)
if fields is None:
raise ValueError('%r is not an attrs-decorated class.' % cls)
values = []
for field in fields:
values.append(getattr(attrs, field.name))
return tuple(values)

View File

@ -29,6 +29,11 @@ from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
@ -158,5 +163,24 @@ class ConvertInnerNodeDataTest(test.TestCase):
self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))
class AttrsTest(test.TestCase):
def test_map_structure_with_atomic_accept_attr(self):
if attr is None:
self.skipTest('attr module is unavailable.')
@attr.s(frozen=True)
class Foo(object):
bar = attr.ib()
self.assertEqual(
Foo(2),
tf_utils.map_structure_with_atomic(
is_atomic_fn=lambda x: isinstance(x, int),
map_fn=lambda x: x + 1,
nested=Foo(1)))
if __name__ == '__main__':
test.main()