Roll back removal of "check_subtrees_length=True" from nest; rename "ignore_extra_entries=False"

Also added unit tests to ensure they work correctly.

PiperOrigin-RevId: 307482938
Change-Id: Id7510d0628b69767f8cc5b3c03a31fe031bc7527
This commit is contained in:
Eugene Brevdo 2020-04-20 14:57:58 -07:00 committed by TensorFlower Gardener
parent 14a5a76e76
commit de6c0ec676
2 changed files with 113 additions and 21 deletions

View File

@ -742,7 +742,8 @@ def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
def assert_shallow_structure(shallow_tree,
input_tree,
check_types=True,
expand_composites=False):
expand_composites=False,
ignore_extra_entries=False):
"""Asserts that `shallow_tree` is a shallow structure of `input_tree`.
That is, this function tests if the `input_tree` structure can be created from
@ -765,6 +766,31 @@ def assert_shallow_structure(shallow_tree,
assert_shallow_structure(shallow_tree, input_tree)
```
The following code will **not** raise an exception:
```python
shallow_tree = ["a", "b"]
input_tree = ["c", ["d", "e"], "f"]
assert_shallow_structure(shallow_tree, input_tree,
ignore_extra_entries=True)
```
Similarly, the following code will also **not** raise an exception:
```python
shallow_tree = {"a": 1, "b": 2}
input_tree = {"a": [1], "b": [2], "c": [3]}
assert_shallow_structure(shallow_tree, input_tree,
ignore_extra_entries=True)
```
However, the following code **will** raise an exception:
```python
shallow_tree = {"a": [1], "b": [2], "c": [3]}
input_tree = {"a": 1, "b": 2}
assert_shallow_structure(shallow_tree, input_tree,
ignore_extra_entries=True)
```
because `shallow_tree` now has keys which are not found in `input_tree`.
Args:
shallow_tree: an arbitrarily nested structure.
input_tree: an arbitrarily nested structure.
@ -775,6 +801,11 @@ def assert_shallow_structure(shallow_tree,
expand_composites: If true, then composite tensors such as
`tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
component tensors.
ignore_extra_entries: if `False` (default) the subtrees `shallow_tree` and
`input_tree` have to be the same length. If `True` sequences are treated
as key-value like mappings allowing "larger" `input_tree` subtrees
to be considered as valid. Note that this may drop parts of the
`input_tree`.
Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
TypeError: If the sequence types of `shallow_tree` are different from
@ -840,7 +871,7 @@ def assert_shallow_structure(shallow_tree,
"be a TypeSpec. Input has type: %s."
% type(input_tree))
else:
if len(input_tree) != len(shallow_tree):
if not ignore_extra_entries and len(input_tree) != len(shallow_tree):
raise ValueError(
_STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
input_length=len(input_tree), shallow_length=len(shallow_tree)))
@ -859,11 +890,12 @@ def assert_shallow_structure(shallow_tree,
_yield_value(input_tree)):
assert_shallow_structure(shallow_branch, input_branch,
check_types=check_types,
expand_composites=expand_composites)
expand_composites=expand_composites,
ignore_extra_entries=ignore_extra_entries)
def flatten_up_to(shallow_tree, input_tree, check_types=True,
expand_composites=False):
expand_composites=False, ignore_extra_entries=False):
"""Flattens `input_tree` up to `shallow_tree`.
Any further depth in structure in `input_tree` is retained as elements in the
@ -916,6 +948,18 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2]
```
Non-Full-Subtree case:
```python
shallow_tree = ["a", "b"]
input_tree = ["c", ["d", "e"], "f"]
flattened = flatten_up_to(shallow_tree, input_tree,
ignore_extra_entries=True)
# Output is:
# ["c", ["d", "e"]]
```
Args:
shallow_tree: a possibly pruned structure of input_tree.
input_tree: an arbitrarily nested structure or a scalar object.
@ -925,6 +969,11 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
expand_composites: If true, then composite tensors such as
`tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
component tensors.
ignore_extra_entries: if `False` (default) the subtrees `shallow_tree` and
`input_tree` have to be the same length. If `True` sequences are treated
as key-value like mappings allowing "larger" `input_tree` subtrees
to be considered as valid. Note that this may drop parts of the
`input_tree`.
Returns:
A Python list, the partially flattened version of `input_tree` according to
@ -941,7 +990,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
assert_shallow_structure(shallow_tree,
input_tree,
check_types=check_types,
expand_composites=expand_composites)
expand_composites=expand_composites,
ignore_extra_entries=ignore_extra_entries)
# Discard paths returned by _yield_flat_up_to.
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq))
@ -949,7 +999,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True,
def flatten_with_tuple_paths_up_to(shallow_tree,
input_tree,
check_types=True,
expand_composites=False):
expand_composites=False,
ignore_extra_entries=False):
"""Flattens `input_tree` up to `shallow_tree`.
Any further depth in structure in `input_tree` is retained as elements in the
@ -1030,6 +1081,11 @@ def flatten_with_tuple_paths_up_to(shallow_tree,
expand_composites: If true, then composite tensors such as
`tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
component tensors.
ignore_extra_entries: if `False` (default) the subtrees `shallow_tree` and
`input_tree` have to be the same length. If `True` sequences are treated
as key-value like mappings allowing "larger" `input_tree` subtrees
to be considered as valid. Note that this may drop parts of the
`input_tree`.
Returns:
A Python list, the partially flattened version of `input_tree` according to
@ -1046,7 +1102,8 @@ def flatten_with_tuple_paths_up_to(shallow_tree,
assert_shallow_structure(shallow_tree,
input_tree,
check_types=check_types,
expand_composites=expand_composites)
expand_composites=expand_composites,
ignore_extra_entries=ignore_extra_entries)
return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
@ -1107,11 +1164,11 @@ def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
shallow_tree. The function `func` is applied to corresponding
partially flattened elements of each input, so the function must support
arity of `len(inputs)`.
**kwargs: kwargs to feed to func(). Special kwarg
`check_types` is not passed to func, but instead determines whether the
types of iterables within the structures have to be same (e.g.
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
this set this argument to `False`.
**kwargs: kwargs to feed to func(). Special kwargs
`check_types`, `expand_composites`, and `ignore_extra_entries` are not
passed to func, but instead to the structure comparison function. For
details on these arguments, see the documentation of
`assert_shallow_structure`.
Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
@ -1181,11 +1238,11 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
inputs_i_value is the corresponding value from inputs[i].
*inputs: nested structures that are all structurally compatible with
shallow_tree.
**kwargs: kwargs to feed to func(). Special kwarg
`check_types` is not passed to func, but instead determines whether the
types of iterables within the structures have to be same (e.g.
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
this set this argument to `False`.
**kwargs: kwargs to feed to func(). Special kwargs
`check_types`, `expand_composites`, and `ignore_extra_entries` are not
passed to func, but instead to the structure comparison function. For
details on these arguments, see the documentation of
`assert_shallow_structure`.
Raises:
TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
@ -1203,6 +1260,7 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
check_types = kwargs.pop("check_types", True)
expand_composites = kwargs.pop("expand_composites", False)
ignore_extra_entries = kwargs.pop("ignore_extra_entries", False)
is_seq = is_sequence_or_composite if expand_composites else is_sequence
for input_tree in inputs:
@ -1210,7 +1268,8 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
shallow_tree,
input_tree,
check_types=check_types,
expand_composites=expand_composites)
expand_composites=expand_composites,
ignore_extra_entries=ignore_extra_entries)
# Flatten each input separately, apply the function to corresponding elements,
# then repack based on the structure of the first input.
@ -1219,7 +1278,8 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
shallow_tree,
input_tree,
check_types,
expand_composites=expand_composites) for input_tree in inputs
expand_composites=expand_composites,
ignore_extra_entries=ignore_extra_entries) for input_tree in inputs
]
flat_path_list = [path for path, _
in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)]

View File

@ -585,6 +585,10 @@ class NestTest(parameterized.TestCase, test.TestCase):
shallow_length=len(inp_abc))):
nest.assert_shallow_structure(inp_abc, inp_ab)
# shallow structure may be "smaller" than input structure if
# ignore_extra_entries=True.
nest.assert_shallow_structure(inp_ab, inp_abc, ignore_extra_entries=True)
inp_ab1 = [(1, 1), (2, 2)]
inp_ab2 = [[1, 1], [2, 2]]
with self.assertRaisesWithLiteralMatch(
@ -761,10 +765,19 @@ class NestTest(parameterized.TestCase, test.TestCase):
with self.assertRaisesRegexp(ValueError, expected_message): # pylint: disable=g-error-prone-assert-raises
nest.assert_shallow_structure(shallow_tree, input_tree)
input_tree = {"c": "c", "d": "d", "b": "b", "a": "a"}
shallow_tree = {"c": 1, "a": 2}
flattened_shallow_tree = nest.flatten_up_to(
shallow_tree, input_tree, ignore_extra_entries=True)
# Returns values of input_tree associated with keys of shallow_tree; and
# keys are in sorted lexicographic order.
self.assertEqual(flattened_shallow_tree, ["a", "c"])
def testFlattenWithTuplePathsUpTo(self):
def get_paths_and_values(shallow_tree, input_tree):
def get_paths_and_values(shallow_tree, input_tree,
ignore_extra_entries=False):
path_value_pairs = nest.flatten_with_tuple_paths_up_to(
shallow_tree, input_tree)
shallow_tree, input_tree, ignore_extra_entries=ignore_extra_entries)
paths = [p for p, _ in path_value_pairs]
values = [v for _, v in path_value_pairs]
return paths, values
@ -899,6 +912,16 @@ class NestTest(parameterized.TestCase, test.TestCase):
shallow_length=len(shallow_tree))):
get_paths_and_values(shallow_tree, input_tree)
(flattened_input_tree_paths,
flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree,
ignore_extra_entries=True)
(flattened_shallow_tree_paths,
flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree_paths, [("a",), ("c",)])
self.assertEqual(flattened_input_tree, ["A", "C"])
self.assertEqual(flattened_shallow_tree_paths, [("a",), ("c",)])
self.assertEqual(flattened_shallow_tree, [1, 2])
# Using non-iterable elements.
input_tree = [0]
shallow_tree = 9
@ -1055,6 +1078,15 @@ class NestTest(parameterized.TestCase, test.TestCase):
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
shallow_structure = {"c": 1, "a": 2}
input_structure = {"a": "a", "b": "b", "c": "c", "d": "d"}
out = nest.map_structure_up_to(
shallow_structure,
lambda x: x + "0",
input_structure,
ignore_extra_entries=True)
self.assertEqual(out, {"a": "a0", "c": "c0"})
def testGetTraverseShallowStructure(self):
scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
scalar_traverse_r = nest.get_traverse_shallow_structure(