From db6c140a7ecea6ccf4aa608404d4e69c980e8398 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Mar 2019 02:41:05 -0700 Subject: [PATCH] Fix nest.map_structure() bug where attr objects were repacked in the wrong order. PiperOrigin-RevId: 239153412 --- tensorflow/python/util/nest.py | 2 +- tensorflow/python/util/nest_test.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 8fb187fb5b6..d22f02f1a2c 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -72,7 +72,7 @@ def _get_attrs_items(obj): A list of (attr_name, attr_value) pairs, sorted by attr_name. """ attrs = getattr(obj.__class__, "__attrs_attrs__") - attr_names = sorted([a.name for a in attrs]) + attr_names = [a.name for a in attrs] return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names] diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 0540f71f7a9..9a8f82e8d48 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -68,6 +68,12 @@ class NestTest(parameterized.TestCase, test.TestCase): field1 = attr.ib() field2 = attr.ib() + @attr.s + class UnsortedSampleAttr(object): + field3 = attr.ib() + field1 = attr.ib() + field2 = attr.ib() + @test_util.assert_no_new_pyobjects_executing_eagerly def testAttrsFlattenAndPack(self): if attr is None: @@ -87,6 +93,21 @@ class NestTest(parameterized.TestCase, test.TestCase): with self.assertRaisesRegexp(TypeError, "object is not iterable"): flat = nest.flatten(NestTest.BadAttr()) + @parameterized.parameters( + {"values": [1, 2, 3]}, + {"values": [{"B": 10, "A": 20}, [1, 2], 3]}, + {"values": [(1, 2), [3, 4], 5]}, + {"values": [PointXY(1, 2), 3, 4]}, + ) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testAttrsMapStructure(self, values): + if attr is None: + self.skipTest("attr module is unavailable.") + + structure = NestTest.UnsortedSampleAttr(*values) + new_structure = nest.map_structure(lambda x: x, structure) + self.assertEqual(structure, new_structure) + @test_util.assert_no_new_pyobjects_executing_eagerly def testFlattenAndPack(self): structure = ((3, 4), 5, (6, 7, (9, 10), 8))