Add flatten_with_tuple_paths_up_to(), the 'flatten' version of map_structure_with_tuple_paths_up_to().
PiperOrigin-RevId: 232673116
This commit is contained in:
parent
5b5636642d
commit
e183604755
@ -139,6 +139,7 @@ _nest_allowed_symbols = [
|
|||||||
'map_structure_with_tuple_paths',
|
'map_structure_with_tuple_paths',
|
||||||
'assert_shallow_structure',
|
'assert_shallow_structure',
|
||||||
'flatten_up_to',
|
'flatten_up_to',
|
||||||
|
'flatten_with_tuple_paths_up_to',
|
||||||
'map_structure_up_to',
|
'map_structure_up_to',
|
||||||
'map_structure_with_tuple_paths_up_to',
|
'map_structure_with_tuple_paths_up_to',
|
||||||
'get_traverse_shallow_structure',
|
'get_traverse_shallow_structure',
|
||||||
|
@ -761,6 +761,100 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True):
|
|||||||
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree))
|
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree))
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
|
||||||
|
"""Flattens `input_tree` up to `shallow_tree`.
|
||||||
|
|
||||||
|
Any further depth in structure in `input_tree` is retained as elements in the
|
||||||
|
partially flattened output.
|
||||||
|
|
||||||
|
Returns a list of (path, value) pairs, where value a leaf node in the
|
||||||
|
flattened tree, and path is the tuple path of that leaf in input_tree.
|
||||||
|
|
||||||
|
If `shallow_tree` and `input_tree` are not sequences, this returns a
|
||||||
|
single-element list: `[((), input_tree)]`.
|
||||||
|
|
||||||
|
Use Case:
|
||||||
|
|
||||||
|
Sometimes we may wish to partially flatten a nested sequence, retaining some
|
||||||
|
of the nested structure. We achieve this by specifying a shallow structure,
|
||||||
|
`shallow_tree`, we wish to flatten up to.
|
||||||
|
|
||||||
|
The input, `input_tree`, can be thought of as having the same structure as
|
||||||
|
`shallow_tree`, but with leaf nodes that are themselves tree structures.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
|
||||||
|
shallow_tree = [[True, True], [False, True]]
|
||||||
|
|
||||||
|
flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
|
||||||
|
shallow_tree)
|
||||||
|
|
||||||
|
# Output is:
|
||||||
|
# [((0, 0), [2, 2]),
|
||||||
|
# ((0, 1), [3, 3]),
|
||||||
|
# ((1, 0), [4, 9]),
|
||||||
|
# ((1, 1), [5, 5])]
|
||||||
|
#
|
||||||
|
# [((0, 0), True),
|
||||||
|
# ((0, 1), True),
|
||||||
|
# ((1, 0), False),
|
||||||
|
# ((1, 1), True)]
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
|
||||||
|
shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
|
||||||
|
|
||||||
|
input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
|
||||||
|
input_tree_flattened = flatten(input_tree)
|
||||||
|
|
||||||
|
# Output is:
|
||||||
|
# [((0, 0), ('a', 1)),
|
||||||
|
# ((0, 1, 0), ('b', 2)),
|
||||||
|
# ((0, 1, 1, 0), ('c', 3)),
|
||||||
|
# ((0, 1, 1, 1), ('d', 4))]
|
||||||
|
# ['a', 1, 'b', 2, 'c', 3, 'd', 4]
|
||||||
|
```
|
||||||
|
|
||||||
|
Non-Sequence Edge Cases:
|
||||||
|
|
||||||
|
```python
|
||||||
|
flatten_with_tuple_paths_up_to(0, 0) # Output: [(), 0]
|
||||||
|
|
||||||
|
flatten_with_tuple_paths_up_to(0, [0, 1, 2]) # Output: [(), [0, 1, 2]]
|
||||||
|
|
||||||
|
flatten_with_tuple_paths_up_to([0, 1, 2], 0) # Output: TypeError
|
||||||
|
|
||||||
|
flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
|
||||||
|
# Output: [((0,) 0), ((1,), 1), ((2,), 2)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shallow_tree: a possibly pruned structure of input_tree.
|
||||||
|
input_tree: an arbitrarily nested structure or a scalar object.
|
||||||
|
Note, numpy arrays are considered scalars.
|
||||||
|
check_types: bool. If True, check that each node in shallow_tree has the
|
||||||
|
same type as the corresponding node in input_tree.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Python list, the partially flattened version of `input_tree` according to
|
||||||
|
the structure of `shallow_tree`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
||||||
|
TypeError: If the sequence types of `shallow_tree` are different from
|
||||||
|
`input_tree`.
|
||||||
|
ValueError: If the sequence lengths of `shallow_tree` are different from
|
||||||
|
`input_tree`.
|
||||||
|
"""
|
||||||
|
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
|
||||||
|
return list(_yield_flat_up_to(shallow_tree, input_tree))
|
||||||
|
|
||||||
|
|
||||||
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
|
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||||
"""Applies a function or op to a number of partially flattened inputs.
|
"""Applies a function or op to a number of partially flattened inputs.
|
||||||
|
|
||||||
|
@ -686,6 +686,244 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
|
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
|
||||||
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
||||||
|
|
||||||
|
def testFlattenWithTuplePathsUpTo(self):
|
||||||
|
def get_paths_and_values(shallow_tree, input_tree):
|
||||||
|
path_value_pairs = nest.flatten_with_tuple_paths_up_to(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
paths = [p for p, _ in path_value_pairs]
|
||||||
|
values = [v for _, v in path_value_pairs]
|
||||||
|
return paths, values
|
||||||
|
|
||||||
|
# Shallow tree ends at scalar.
|
||||||
|
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
|
||||||
|
shallow_tree = [[True, True], [False, True]]
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths,
|
||||||
|
[(0, 0), (0, 1), (1, 0), (1, 1)])
|
||||||
|
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths,
|
||||||
|
[(0, 0), (0, 1), (1, 0), (1, 1)])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [True, True, False, True])
|
||||||
|
|
||||||
|
# Shallow tree ends at string.
|
||||||
|
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
|
||||||
|
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
|
||||||
|
(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
input_tree_flattened_paths = [p for p, _ in
|
||||||
|
nest.flatten_with_tuple_paths(input_tree)]
|
||||||
|
input_tree_flattened = nest.flatten(input_tree)
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
[(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||||
|
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
|
||||||
|
|
||||||
|
self.assertEqual(input_tree_flattened_paths,
|
||||||
|
[(0, 0, 0), (0, 0, 1),
|
||||||
|
(0, 1, 0, 0), (0, 1, 0, 1),
|
||||||
|
(0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
|
||||||
|
(0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
|
||||||
|
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
|
||||||
|
|
||||||
|
# Make sure dicts are correctly flattened, yielding values, not keys.
|
||||||
|
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
|
||||||
|
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
|
||||||
|
(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
[("a",), ("b",), ("d", 0), ("d", 1)])
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||||
|
[1, {"c": 2}, 3, (4, 5)])
|
||||||
|
|
||||||
|
# Namedtuples.
|
||||||
|
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
|
||||||
|
input_tree = ab_tuple(a=[0, 1], b=2)
|
||||||
|
shallow_tree = ab_tuple(a=0, b=1)
|
||||||
|
(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
[("a",), ("b",)])
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||||
|
[[0, 1], 2])
|
||||||
|
|
||||||
|
# Nested dicts, OrderedDicts and namedtuples.
|
||||||
|
input_tree = collections.OrderedDict(
|
||||||
|
[("a", ab_tuple(a=[0, {"b": 1}], b=2)),
|
||||||
|
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
|
||||||
|
shallow_tree = input_tree
|
||||||
|
(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
[("a", "a", 0),
|
||||||
|
("a", "a", 1, "b"),
|
||||||
|
("a", "b"),
|
||||||
|
("c", "d"),
|
||||||
|
("c", "e", "f")])
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
|
||||||
|
shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
|
||||||
|
(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
[("a",),
|
||||||
|
("c", "d"),
|
||||||
|
("c", "e")])
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||||
|
[ab_tuple(a=[0, {"b": 1}], b=2),
|
||||||
|
3,
|
||||||
|
collections.OrderedDict([("f", 4)])])
|
||||||
|
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
|
||||||
|
(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
|
||||||
|
input_tree)
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
|
||||||
|
[("a",), ("c",)])
|
||||||
|
self.assertEqual(input_tree_flattened_as_shallow_tree,
|
||||||
|
[ab_tuple(a=[0, {"b": 1}], b=2),
|
||||||
|
{"d": 3, "e": collections.OrderedDict([("f", 4)])}])
|
||||||
|
|
||||||
|
## Shallow non-list edge-case.
|
||||||
|
# Using iterable elements.
|
||||||
|
input_tree = ["input_tree"]
|
||||||
|
shallow_tree = "shallow_tree"
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_input_tree, [input_tree])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [shallow_tree])
|
||||||
|
|
||||||
|
input_tree = ["input_tree_0", "input_tree_1"]
|
||||||
|
shallow_tree = "shallow_tree"
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_input_tree, [input_tree])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [shallow_tree])
|
||||||
|
|
||||||
|
# Test case where len(shallow_tree) < len(input_tree)
|
||||||
|
input_tree = {"a": "A", "b": "B", "c": "C"}
|
||||||
|
shallow_tree = {"a": 1, "c": 2}
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [("a",), ("c",)])
|
||||||
|
self.assertEqual(flattened_input_tree, ["A", "C"])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [("a",), ("c",)])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [1, 2])
|
||||||
|
|
||||||
|
# Using non-iterable elements.
|
||||||
|
input_tree = [0]
|
||||||
|
shallow_tree = 9
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_input_tree, [input_tree])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [shallow_tree])
|
||||||
|
|
||||||
|
input_tree = [0, 1]
|
||||||
|
shallow_tree = 9
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_input_tree, [input_tree])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [shallow_tree])
|
||||||
|
|
||||||
|
## Both non-list edge-case.
|
||||||
|
# Using iterable elements.
|
||||||
|
input_tree = "input_tree"
|
||||||
|
shallow_tree = "shallow_tree"
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_input_tree, [input_tree])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [shallow_tree])
|
||||||
|
|
||||||
|
# Using non-iterable elements.
|
||||||
|
input_tree = 0
|
||||||
|
shallow_tree = 0
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_input_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_input_tree, [input_tree])
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [()])
|
||||||
|
self.assertEqual(flattened_shallow_tree, [shallow_tree])
|
||||||
|
|
||||||
|
## Input non-list edge-case.
|
||||||
|
# Using iterable elements.
|
||||||
|
input_tree = "input_tree"
|
||||||
|
shallow_tree = ["shallow_tree"]
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
TypeError,
|
||||||
|
nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [(0,)])
|
||||||
|
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
||||||
|
|
||||||
|
input_tree = "input_tree"
|
||||||
|
shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
TypeError,
|
||||||
|
nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
|
||||||
|
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
||||||
|
|
||||||
|
# Using non-iterable elements.
|
||||||
|
input_tree = 0
|
||||||
|
shallow_tree = [9]
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
TypeError,
|
||||||
|
nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [(0,)])
|
||||||
|
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
||||||
|
|
||||||
|
input_tree = 0
|
||||||
|
shallow_tree = [9, 8]
|
||||||
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
TypeError,
|
||||||
|
nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
|
||||||
|
(flattened_input_tree_paths,
|
||||||
|
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
|
||||||
|
(flattened_shallow_tree_paths,
|
||||||
|
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||||
|
self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
|
||||||
|
self.assertEqual(flattened_shallow_tree, shallow_tree)
|
||||||
|
|
||||||
def testMapStructureUpTo(self):
|
def testMapStructureUpTo(self):
|
||||||
# Named tuples.
|
# Named tuples.
|
||||||
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
|
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
|
||||||
|
Loading…
Reference in New Issue
Block a user