Fix nest.map_structure() bug where attr objects were repacked in the wrong order.

PiperOrigin-RevId: 239153412
This commit is contained in:
A. Unique TensorFlower 2019-03-19 02:41:05 -07:00 committed by TensorFlower Gardener
parent b244682665
commit db6c140a7e
2 changed files with 22 additions and 1 deletions

View File

@ -72,7 +72,7 @@ def _get_attrs_items(obj):
A list of (attr_name, attr_value) pairs, sorted by attr_name.
"""
attrs = getattr(obj.__class__, "__attrs_attrs__")
attr_names = sorted([a.name for a in attrs])
attr_names = [a.name for a in attrs]
return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names]

View File

@ -68,6 +68,12 @@ class NestTest(parameterized.TestCase, test.TestCase):
field1 = attr.ib()
field2 = attr.ib()
@attr.s
class UnsortedSampleAttr(object):
field3 = attr.ib()
field1 = attr.ib()
field2 = attr.ib()
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsFlattenAndPack(self):
if attr is None:
@ -87,6 +93,21 @@ class NestTest(parameterized.TestCase, test.TestCase):
with self.assertRaisesRegexp(TypeError, "object is not iterable"):
flat = nest.flatten(NestTest.BadAttr())
@parameterized.parameters(
{"values": [1, 2, 3]},
{"values": [{"B": 10, "A": 20}, [1, 2], 3]},
{"values": [(1, 2), [3, 4], 5]},
{"values": [PointXY(1, 2), 3, 4]},
)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsMapStructure(self, values):
if attr is None:
self.skipTest("attr module is unavailable.")
structure = NestTest.UnsortedSampleAttr(*values)
new_structure = nest.map_structure(lambda x: x, structure)
self.assertEqual(structure, new_structure)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))