diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index f3f3887afc5..7eecbf79e5c 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -382,7 +382,8 @@ def map_structure(func, *structure, **check_types_dict): def _yield_flat_up_to(shallow_tree, input_tree): """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" if is_sequence(shallow_tree): - for shallow_branch, input_branch in zip(shallow_tree, input_tree): + for shallow_branch, input_branch in zip(_yield_value(shallow_tree), + _yield_value(input_tree)): for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): yield input_leaf else: diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 375e30e9534..b4a0525e6c3 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -301,6 +301,7 @@ class NestTest(test.TestCase): nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) def testFlattenUpTo(self): + # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] shallow_tree = [[True, True], [False, True]] flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) @@ -308,6 +309,7 @@ class NestTest(test.TestCase): self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) self.assertEqual(flattened_shallow_tree, [True, True, False, True]) + # Shallow tree ends at string. input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, @@ -317,6 +319,46 @@ class NestTest(test.TestCase): [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) + # Make sure dicts are correctly flattened, yielding values, not keys. + input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} + shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [1, {"c": 2}, 3, (4, 5)]) + + # Namedtuples. + ab_tuple = collections.namedtuple("ab_tuple", "a, b") + input_tree = ab_tuple(a=[0, 1], b=2) + shallow_tree = ab_tuple(a=0, b=1) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [[0, 1], 2]) + + # Nested dicts, OrderedDicts and namedtuples. + input_tree = collections.OrderedDict( + [("a", ab_tuple(a=[0, {"b": 1}], b=2)), + ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) + shallow_tree = input_tree + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) + shallow_tree = collections.OrderedDict([("a", 0), + ("b", {"d": 3, "e": 1})]) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [ab_tuple(a=[0, {"b": 1}], b=2), + 3, + collections.OrderedDict([("f", 4)])]) + shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [ab_tuple(a=[0, {"b": 1}], b=2), + {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) + ## Shallow non-list edge-case. # Using iterable elements. input_tree = ["input_tree"] @@ -401,6 +443,7 @@ class NestTest(test.TestCase): self.assertEqual(flattened_shallow_tree, shallow_tree) def testMapStructureUpTo(self): + # Named tuples. ab_tuple = collections.namedtuple("ab_tuple", "a, b") op_tuple = collections.namedtuple("op_tuple", "add, mul") inp_val = ab_tuple(a=2, b=3) @@ -410,6 +453,7 @@ class NestTest(test.TestCase): self.assertEqual(out.a, 6) self.assertEqual(out.b, 15) + # Lists. data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] name_list = ["evens", ["odds", "primes"]] out = nest.map_structure_up_to(