Roll back removal of "check_subtrees_length=True" from nest; rename "ignore_extra_entries=False"
Also added unit tests to ensure they work correctly. PiperOrigin-RevId: 307482938 Change-Id: Id7510d0628b69767f8cc5b3c03a31fe031bc7527
This commit is contained in:
parent
14a5a76e76
commit
de6c0ec676
@ -742,7 +742,8 @@ def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
|
||||
def assert_shallow_structure(shallow_tree,
|
||||
input_tree,
|
||||
check_types=True,
|
||||
expand_composites=False):
|
||||
expand_composites=False,
|
||||
ignore_extra_entries=False):
|
||||
"""Asserts that `shallow_tree` is a shallow structure of `input_tree`.
|
||||
|
||||
That is, this function tests if the `input_tree` structure can be created from
|
||||
@ -765,6 +766,31 @@ def assert_shallow_structure(shallow_tree,
|
||||
assert_shallow_structure(shallow_tree, input_tree)
|
||||
```
|
||||
|
||||
The following code will **not** raise an exception:
|
||||
```python
|
||||
shallow_tree = ["a", "b"]
|
||||
input_tree = ["c", ["d", "e"], "f"]
|
||||
assert_shallow_structure(shallow_tree, input_tree,
|
||||
ignore_extra_entries=True)
|
||||
```
|
||||
|
||||
Similarly, the following code will also **not** raise an exception:
|
||||
```python
|
||||
shallow_tree = {"a": 1, "b": 2}
|
||||
input_tree = {"a": [1], "b": [2], "c": [3]}
|
||||
assert_shallow_structure(shallow_tree, input_tree,
|
||||
ignore_extra_entries=True)
|
||||
```
|
||||
|
||||
However, the following code **will** raise an exception:
|
||||
```python
|
||||
shallow_tree = {"a": [1], "b": [2], "c": [3]}
|
||||
input_tree = {"a": 1, "b": 2}
|
||||
assert_shallow_structure(shallow_tree, input_tree,
|
||||
ignore_extra_entries=True)
|
||||
```
|
||||
because `shallow_tree` now has keys which are not found in `input_tree`.
|
||||
|
||||
Args:
|
||||
shallow_tree: an arbitrarily nested structure.
|
||||
input_tree: an arbitrarily nested structure.
|
||||
@ -775,6 +801,11 @@ def assert_shallow_structure(shallow_tree,
|
||||
expand_composites: If true, then composite tensors such as
|
||||
`tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
|
||||
component tensors.
|
||||
ignore_extra_entries: if `False` (default) the subtrees `shallow_tree` and
|
||||
`input_tree` have to be the same length. If `True` sequences are treated
|
||||
as key-value like mappings allowing "larger" `input_tree` subtrees
|
||||
to be considered as valid. Note that this may drop parts of the
|
||||
`input_tree`.
|
||||
Raises:
|
||||
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
||||
TypeError: If the sequence types of `shallow_tree` are different from
|
||||
@ -840,7 +871,7 @@ def assert_shallow_structure(shallow_tree,
|
||||
"be a TypeSpec. Input has type: %s."
|
||||
% type(input_tree))
|
||||
else:
|
||||
if len(input_tree) != len(shallow_tree):
|
||||
if not ignore_extra_entries and len(input_tree) != len(shallow_tree):
|
||||
raise ValueError(
|
||||
_STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
|
||||
input_length=len(input_tree), shallow_length=len(shallow_tree)))
|
||||
@ -859,11 +890,12 @@ def assert_shallow_structure(shallow_tree,
|
||||
_yield_value(input_tree)):
|
||||
assert_shallow_structure(shallow_branch, input_branch,
|
||||
check_types=check_types,
|
||||
expand_composites=expand_composites)
|
||||
expand_composites=expand_composites,
|
||||
ignore_extra_entries=ignore_extra_entries)
|
||||
|
||||
|
||||
def flatten_up_to(shallow_tree, input_tree, check_types=True,
|
||||
expand_composites=False):
|
||||
expand_composites=False, ignore_extra_entries=False):
|
||||
"""Flattens `input_tree` up to `shallow_tree`.
|
||||
|
||||
Any further depth in structure in `input_tree` is retained as elements in the
|
||||
@ -916,6 +948,18 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
|
||||
flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2]
|
||||
```
|
||||
|
||||
Non-Full-Subtree case:
|
||||
|
||||
```python
|
||||
shallow_tree = ["a", "b"]
|
||||
input_tree = ["c", ["d", "e"], "f"]
|
||||
flattened = flatten_up_to(shallow_tree, input_tree,
|
||||
ignore_extra_entries=True)
|
||||
|
||||
# Output is:
|
||||
# ["c", ["d", "e"]]
|
||||
```
|
||||
|
||||
Args:
|
||||
shallow_tree: a possibly pruned structure of input_tree.
|
||||
input_tree: an arbitrarily nested structure or a scalar object.
|
||||
@ -925,6 +969,11 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
|
||||
expand_composites: If true, then composite tensors such as
|
||||
`tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
|
||||
component tensors.
|
||||
ignore_extra_entries: if `False` (default) the subtrees `shallow_tree` and
|
||||
`input_tree` have to be the same length. If `True` sequences are treated
|
||||
as key-value like mappings allowing "larger" `input_tree` subtrees
|
||||
to be considered as valid. Note that this may drop parts of the
|
||||
`input_tree`.
|
||||
|
||||
Returns:
|
||||
A Python list, the partially flattened version of `input_tree` according to
|
||||
@ -941,7 +990,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
|
||||
assert_shallow_structure(shallow_tree,
|
||||
input_tree,
|
||||
check_types=check_types,
|
||||
expand_composites=expand_composites)
|
||||
expand_composites=expand_composites,
|
||||
ignore_extra_entries=ignore_extra_entries)
|
||||
# Discard paths returned by _yield_flat_up_to.
|
||||
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq))
|
||||
|
||||
@ -949,7 +999,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
|
||||
def flatten_with_tuple_paths_up_to(shallow_tree,
|
||||
input_tree,
|
||||
check_types=True,
|
||||
expand_composites=False):
|
||||
expand_composites=False,
|
||||
ignore_extra_entries=False):
|
||||
"""Flattens `input_tree` up to `shallow_tree`.
|
||||
|
||||
Any further depth in structure in `input_tree` is retained as elements in the
|
||||
@ -1030,6 +1081,11 @@ def flatten_with_tuple_paths_up_to(shallow_tree,
|
||||
expand_composites: If true, then composite tensors such as
|
||||
`tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
|
||||
component tensors.
|
||||
ignore_extra_entries: if `False` (default) the subtrees `shallow_tree` and
|
||||
`input_tree` have to be the same length. If `True` sequences are treated
|
||||
as key-value like mappings allowing "larger" `input_tree` subtrees
|
||||
to be considered as valid. Note that this may drop parts of the
|
||||
`input_tree`.
|
||||
|
||||
Returns:
|
||||
A Python list, the partially flattened version of `input_tree` according to
|
||||
@ -1046,7 +1102,8 @@ def flatten_with_tuple_paths_up_to(shallow_tree,
|
||||
assert_shallow_structure(shallow_tree,
|
||||
input_tree,
|
||||
check_types=check_types,
|
||||
expand_composites=expand_composites)
|
||||
expand_composites=expand_composites,
|
||||
ignore_extra_entries=ignore_extra_entries)
|
||||
return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
|
||||
|
||||
|
||||
@ -1107,11 +1164,11 @@ def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||
shallow_tree. The function `func` is applied to corresponding
|
||||
partially flattened elements of each input, so the function must support
|
||||
arity of `len(inputs)`.
|
||||
**kwargs: kwargs to feed to func(). Special kwarg
|
||||
`check_types` is not passed to func, but instead determines whether the
|
||||
types of iterables within the structures have to be same (e.g.
|
||||
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
|
||||
this set this argument to `False`.
|
||||
**kwargs: kwargs to feed to func(). Special kwargs
|
||||
`check_types`, `expand_composites`, and `ignore_extra_entries` are not
|
||||
passed to func, but instead to the structure comparison function. For
|
||||
details on these arguments, see the documentation of
|
||||
`assert_shallow_structure`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
||||
@ -1181,11 +1238,11 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||
inputs_i_value is the corresponding value from inputs[i].
|
||||
*inputs: nested structures that are all structurally compatible with
|
||||
shallow_tree.
|
||||
**kwargs: kwargs to feed to func(). Special kwarg
|
||||
`check_types` is not passed to func, but instead determines whether the
|
||||
types of iterables within the structures have to be same (e.g.
|
||||
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
|
||||
this set this argument to `False`.
|
||||
**kwargs: kwargs to feed to func(). Special kwargs
|
||||
`check_types`, `expand_composites`, and `ignore_extra_entries` are not
|
||||
passed to func, but instead to the structure comparison function. For
|
||||
details on these arguments, see the documentation of
|
||||
`assert_shallow_structure`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
|
||||
@ -1203,6 +1260,7 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||
|
||||
check_types = kwargs.pop("check_types", True)
|
||||
expand_composites = kwargs.pop("expand_composites", False)
|
||||
ignore_extra_entries = kwargs.pop("ignore_extra_entries", False)
|
||||
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||
|
||||
for input_tree in inputs:
|
||||
@ -1210,7 +1268,8 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||
shallow_tree,
|
||||
input_tree,
|
||||
check_types=check_types,
|
||||
expand_composites=expand_composites)
|
||||
expand_composites=expand_composites,
|
||||
ignore_extra_entries=ignore_extra_entries)
|
||||
|
||||
# Flatten each input separately, apply the function to corresponding elements,
|
||||
# then repack based on the structure of the first input.
|
||||
@ -1219,7 +1278,8 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||
shallow_tree,
|
||||
input_tree,
|
||||
check_types,
|
||||
expand_composites=expand_composites) for input_tree in inputs
|
||||
expand_composites=expand_composites,
|
||||
ignore_extra_entries=ignore_extra_entries) for input_tree in inputs
|
||||
]
|
||||
flat_path_list = [path for path, _
|
||||
in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)]
|
||||
|
@ -585,6 +585,10 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
||||
shallow_length=len(inp_abc))):
|
||||
nest.assert_shallow_structure(inp_abc, inp_ab)
|
||||
|
||||
# shallow structure may be "smaller" than input structure if
|
||||
# ignore_extra_entries=True.
|
||||
nest.assert_shallow_structure(inp_ab, inp_abc, ignore_extra_entries=True)
|
||||
|
||||
inp_ab1 = [(1, 1), (2, 2)]
|
||||
inp_ab2 = [[1, 1], [2, 2]]
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
@ -761,10 +765,19 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, expected_message): # pylint: disable=g-error-prone-assert-raises
|
||||
nest.assert_shallow_structure(shallow_tree, input_tree)
|
||||
|
||||
input_tree = {"c": "c", "d": "d", "b": "b", "a": "a"}
|
||||
shallow_tree = {"c": 1, "a": 2}
|
||||
flattened_shallow_tree = nest.flatten_up_to(
|
||||
shallow_tree, input_tree, ignore_extra_entries=True)
|
||||
# Returns values of input_tree associated with keys of shallow_tree; and
|
||||
# keys are in sorted lexicographic order.
|
||||
self.assertEqual(flattened_shallow_tree, ["a", "c"])
|
||||
|
||||
def testFlattenWithTuplePathsUpTo(self):
|
||||
def get_paths_and_values(shallow_tree, input_tree):
|
||||
def get_paths_and_values(shallow_tree, input_tree,
|
||||
ignore_extra_entries=False):
|
||||
path_value_pairs = nest.flatten_with_tuple_paths_up_to(
|
||||
shallow_tree, input_tree)
|
||||
shallow_tree, input_tree, ignore_extra_entries=ignore_extra_entries)
|
||||
paths = [p for p, _ in path_value_pairs]
|
||||
values = [v for _, v in path_value_pairs]
|
||||
return paths, values
|
||||
@ -899,6 +912,16 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
||||
shallow_length=len(shallow_tree))):
|
||||
get_paths_and_values(shallow_tree, input_tree)
|
||||
|
||||
(flattened_input_tree_paths,
|
||||
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree,
|
||||
ignore_extra_entries=True)
|
||||
(flattened_shallow_tree_paths,
|
||||
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
|
||||
self.assertEqual(flattened_input_tree_paths, [("a",), ("c",)])
|
||||
self.assertEqual(flattened_input_tree, ["A", "C"])
|
||||
self.assertEqual(flattened_shallow_tree_paths, [("a",), ("c",)])
|
||||
self.assertEqual(flattened_shallow_tree, [1, 2])
|
||||
|
||||
# Using non-iterable elements.
|
||||
input_tree = [0]
|
||||
shallow_tree = 9
|
||||
@ -1055,6 +1078,15 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
||||
inp_val,
|
||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
|
||||
|
||||
shallow_structure = {"c": 1, "a": 2}
|
||||
input_structure = {"a": "a", "b": "b", "c": "c", "d": "d"}
|
||||
out = nest.map_structure_up_to(
|
||||
shallow_structure,
|
||||
lambda x: x + "0",
|
||||
input_structure,
|
||||
ignore_extra_entries=True)
|
||||
self.assertEqual(out, {"a": "a0", "c": "c0"})
|
||||
|
||||
def testGetTraverseShallowStructure(self):
|
||||
scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
|
||||
scalar_traverse_r = nest.get_traverse_shallow_structure(
|
||||
|
Loading…
Reference in New Issue
Block a user