flatten_up_to should return values, not keys
PiperOrigin-RevId: 163809688
This commit is contained in:
parent
6209b4b524
commit
c7b674fa28
@ -382,7 +382,8 @@ def map_structure(func, *structure, **check_types_dict):
|
||||
def _yield_flat_up_to(shallow_tree, input_tree):
|
||||
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
|
||||
if is_sequence(shallow_tree):
|
||||
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
|
||||
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
||||
_yield_value(input_tree)):
|
||||
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
|
||||
yield input_leaf
|
||||
else:
|
||||
|
@ -301,6 +301,7 @@ class NestTest(test.TestCase):
|
||||
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
|
||||
|
||||
def testFlattenUpTo(self):
|
||||
# Shallow tree ends at scalar.
|
||||
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
|
||||
shallow_tree = [[True, True], [False, True]]
|
||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
|
||||
@ -308,6 +309,7 @@ class NestTest(test.TestCase):
|
||||
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
|
||||
self.assertEqual(flattened_shallow_tree, [True, True, False, True])
|
||||
|
||||
# Shallow tree ends at string.
|
||||
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
|
||||
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
|
||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
|
||||
@ -317,6 +319,46 @@ class NestTest(test.TestCase):
|
||||
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
|
||||
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
|
||||
|
||||
# Make sure dicts are correctly flattened, yielding values, not keys.
|
||||
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
|
||||
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
|
||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
|
||||
input_tree)
|
||||
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||
[1, {"c": 2}, 3, (4, 5)])
|
||||
|
||||
# Namedtuples.
|
||||
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
|
||||
input_tree = ab_tuple(a=[0, 1], b=2)
|
||||
shallow_tree = ab_tuple(a=0, b=1)
|
||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
|
||||
input_tree)
|
||||
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||
[[0, 1], 2])
|
||||
|
||||
# Nested dicts, OrderedDicts and namedtuples.
|
||||
input_tree = collections.OrderedDict(
|
||||
[("a", ab_tuple(a=[0, {"b": 1}], b=2)),
|
||||
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
|
||||
shallow_tree = input_tree
|
||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
|
||||
input_tree)
|
||||
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
|
||||
shallow_tree = collections.OrderedDict([("a", 0),
|
||||
("b", {"d": 3, "e": 1})])
|
||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
|
||||
input_tree)
|
||||
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||
[ab_tuple(a=[0, {"b": 1}], b=2),
|
||||
3,
|
||||
collections.OrderedDict([("f", 4)])])
|
||||
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
|
||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
|
||||
input_tree)
|
||||
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||
[ab_tuple(a=[0, {"b": 1}], b=2),
|
||||
{"d": 3, "e": collections.OrderedDict([("f", 4)])}])
|
||||
|
||||
## Shallow non-list edge-case.
|
||||
# Using iterable elements.
|
||||
input_tree = ["input_tree"]
|
||||
@ -401,6 +443,7 @@ class NestTest(test.TestCase):
|
||||
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
||||
|
||||
def testMapStructureUpTo(self):
|
||||
# Named tuples.
|
||||
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
|
||||
op_tuple = collections.namedtuple("op_tuple", "add, mul")
|
||||
inp_val = ab_tuple(a=2, b=3)
|
||||
@ -410,6 +453,7 @@ class NestTest(test.TestCase):
|
||||
self.assertEqual(out.a, 6)
|
||||
self.assertEqual(out.b, 15)
|
||||
|
||||
# Lists.
|
||||
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
|
||||
name_list = ["evens", ["odds", "primes"]]
|
||||
out = nest.map_structure_up_to(
|
||||
|
Loading…
Reference in New Issue
Block a user