flatten_up_to should return values, not keys

PiperOrigin-RevId: 163809688
This commit is contained in:
A. Unique TensorFlower 2017-08-01 03:04:57 -07:00 committed by TensorFlower Gardener
parent 6209b4b524
commit c7b674fa28
2 changed files with 46 additions and 1 deletions

View File

@ -382,7 +382,8 @@ def map_structure(func, *structure, **check_types_dict):
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
_yield_value(input_tree)):
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
yield input_leaf
else:

View File

@ -301,6 +301,7 @@ class NestTest(test.TestCase):
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
def testFlattenUpTo(self):
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
@ -308,6 +309,7 @@ class NestTest(test.TestCase):
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
self.assertEqual(flattened_shallow_tree, [True, True, False, True])
# Shallow tree ends at string.
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
@ -317,6 +319,46 @@ class NestTest(test.TestCase):
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
# Make sure dicts are correctly flattened, yielding values, not keys.
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[1, {"c": 2}, 3, (4, 5)])
# Namedtuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
input_tree = ab_tuple(a=[0, 1], b=2)
shallow_tree = ab_tuple(a=0, b=1)
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[[0, 1], 2])
# Nested dicts, OrderedDicts and namedtuples.
input_tree = collections.OrderedDict(
[("a", ab_tuple(a=[0, {"b": 1}], b=2)),
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
shallow_tree = input_tree
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
shallow_tree = collections.OrderedDict([("a", 0),
("b", {"d": 3, "e": 1})])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
3,
collections.OrderedDict([("f", 4)])])
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
{"d": 3, "e": collections.OrderedDict([("f", 4)])}])
## Shallow non-list edge-case.
# Using iterable elements.
input_tree = ["input_tree"]
@ -401,6 +443,7 @@ class NestTest(test.TestCase):
self.assertEqual(flattened_shallow_tree, shallow_tree)
def testMapStructureUpTo(self):
# Named tuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
@ -410,6 +453,7 @@ class NestTest(test.TestCase):
self.assertEqual(out.a, 6)
self.assertEqual(out.b, 15)
# Lists.
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ["evens", ["odds", "primes"]]
out = nest.map_structure_up_to(