Make map_structure_with_atomic work with attrs.
PiperOrigin-RevId: 316037273 Change-Id: I1316a99e657c92a8835772260744175c15d55d33
This commit is contained in:
parent
b62ad45a87
commit
b5e6999ea2
|
@ -179,6 +179,8 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
|
|||
'Received non-atomic and non-sequence element: {}'.format(nested))
|
||||
if nest._is_mapping(nested):
|
||||
values = [nested[k] for k in nest._sorted(nested)]
|
||||
elif nest._is_attrs(nested):
|
||||
values = _astuple(nested)
|
||||
else:
|
||||
values = nested
|
||||
mapped_values = [
|
||||
|
@ -533,3 +535,15 @@ def to_numpy_or_python_type(tensors):
|
|||
return t # Don't turn ragged or sparse tensors to NumPy.
|
||||
|
||||
return nest.map_structure(_to_single_numpy_or_python_type, tensors)
|
||||
|
||||
|
||||
def _astuple(attrs):
|
||||
"""Converts the given attrs to tuple non-recursively."""
|
||||
cls = type(attrs)
|
||||
fields = getattr(cls, '__attrs_attrs__', None)
|
||||
if fields is None:
|
||||
raise ValueError('%r is not an attrs-decorated class.' % cls)
|
||||
values = []
|
||||
for field in fields:
|
||||
values.append(getattr(attrs, field.name))
|
||||
return tuple(values)
|
||||
|
|
|
@ -29,6 +29,11 @@ from tensorflow.python.keras.utils import tf_utils
|
|||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
try:
|
||||
import attr # pylint:disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
attr = None
|
||||
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
|
||||
|
@ -158,5 +163,24 @@ class ConvertInnerNodeDataTest(test.TestCase):
|
|||
self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))
|
||||
|
||||
|
||||
class AttrsTest(test.TestCase):
|
||||
|
||||
def test_map_structure_with_atomic_accept_attr(self):
|
||||
if attr is None:
|
||||
self.skipTest('attr module is unavailable.')
|
||||
|
||||
@attr.s(frozen=True)
|
||||
class Foo(object):
|
||||
|
||||
bar = attr.ib()
|
||||
|
||||
self.assertEqual(
|
||||
Foo(2),
|
||||
tf_utils.map_structure_with_atomic(
|
||||
is_atomic_fn=lambda x: isinstance(x, int),
|
||||
map_fn=lambda x: x + 1,
|
||||
nested=Foo(1)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue