Make map_structure_with_atomic work with attrs.
PiperOrigin-RevId: 316037273 Change-Id: I1316a99e657c92a8835772260744175c15d55d33
This commit is contained in:
parent
b62ad45a87
commit
b5e6999ea2
@ -179,6 +179,8 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
|
|||||||
'Received non-atomic and non-sequence element: {}'.format(nested))
|
'Received non-atomic and non-sequence element: {}'.format(nested))
|
||||||
if nest._is_mapping(nested):
|
if nest._is_mapping(nested):
|
||||||
values = [nested[k] for k in nest._sorted(nested)]
|
values = [nested[k] for k in nest._sorted(nested)]
|
||||||
|
elif nest._is_attrs(nested):
|
||||||
|
values = _astuple(nested)
|
||||||
else:
|
else:
|
||||||
values = nested
|
values = nested
|
||||||
mapped_values = [
|
mapped_values = [
|
||||||
@ -533,3 +535,15 @@ def to_numpy_or_python_type(tensors):
|
|||||||
return t # Don't turn ragged or sparse tensors to NumPy.
|
return t # Don't turn ragged or sparse tensors to NumPy.
|
||||||
|
|
||||||
return nest.map_structure(_to_single_numpy_or_python_type, tensors)
|
return nest.map_structure(_to_single_numpy_or_python_type, tensors)
|
||||||
|
|
||||||
|
|
||||||
|
def _astuple(attrs):
|
||||||
|
"""Converts the given attrs to tuple non-recursively."""
|
||||||
|
cls = type(attrs)
|
||||||
|
fields = getattr(cls, '__attrs_attrs__', None)
|
||||||
|
if fields is None:
|
||||||
|
raise ValueError('%r is not an attrs-decorated class.' % cls)
|
||||||
|
values = []
|
||||||
|
for field in fields:
|
||||||
|
values.append(getattr(attrs, field.name))
|
||||||
|
return tuple(values)
|
||||||
|
|||||||
@ -29,6 +29,11 @@ from tensorflow.python.keras.utils import tf_utils
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
try:
|
||||||
|
import attr # pylint:disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
attr = None
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
|
class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
|
||||||
@ -158,5 +163,24 @@ class ConvertInnerNodeDataTest(test.TestCase):
|
|||||||
self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))
|
self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))
|
||||||
|
|
||||||
|
|
||||||
|
class AttrsTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_map_structure_with_atomic_accept_attr(self):
|
||||||
|
if attr is None:
|
||||||
|
self.skipTest('attr module is unavailable.')
|
||||||
|
|
||||||
|
@attr.s(frozen=True)
|
||||||
|
class Foo(object):
|
||||||
|
|
||||||
|
bar = attr.ib()
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
Foo(2),
|
||||||
|
tf_utils.map_structure_with_atomic(
|
||||||
|
is_atomic_fn=lambda x: isinstance(x, int),
|
||||||
|
map_fn=lambda x: x + 1,
|
||||||
|
nested=Foo(1)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user