Added map_structure_with_tuple_paths_up_to(), rewrote existing functions in tree.py in terms of it, as they are subsets of its functionality.
This cleans the behavior of traversing with shallow_trees, in that it now allows for any shallow_tree that's an upper subset of the input_tree. For example, before this change, the following was not allowed, on the basis that len(shallow_tree) < len(input_tree): shallow_tree = {'a': 'A','c': 'C'} input_tree = {'a': 'A', 'b': 'B', 'c': 'C'} map_structure_up_to(shallow_tree, input_tree, func) PiperOrigin-RevId: 230883405
This commit is contained in:
parent
53d3a75b7b
commit
eda13c8416
@ -136,12 +136,15 @@ _nest_allowed_symbols = [
|
|||||||
'pack_sequence_as',
|
'pack_sequence_as',
|
||||||
'map_structure',
|
'map_structure',
|
||||||
'map_structure_with_paths',
|
'map_structure_with_paths',
|
||||||
|
'map_structure_with_tuple_paths',
|
||||||
'assert_shallow_structure',
|
'assert_shallow_structure',
|
||||||
'flatten_up_to',
|
'flatten_up_to',
|
||||||
'map_structure_up_to',
|
'map_structure_up_to',
|
||||||
|
'map_structure_with_tuple_paths_up_to',
|
||||||
'get_traverse_shallow_structure',
|
'get_traverse_shallow_structure',
|
||||||
'yield_flat_paths',
|
'yield_flat_paths',
|
||||||
'flatten_with_joined_string_paths',
|
'flatten_with_joined_string_paths',
|
||||||
|
'flatten_with_tuple_paths',
|
||||||
]
|
]
|
||||||
|
|
||||||
remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols)
|
remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols)
|
||||||
|
@ -41,10 +41,38 @@ import six as _six
|
|||||||
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
|
||||||
|
|
||||||
|
|
||||||
def _get_attrs_values(obj):
|
_SHALLOW_TREE_HAS_INVALID_KEYS = (
|
||||||
"""Returns the list of values from an attrs instance."""
|
"The shallow_tree's keys are not a subset of the input_tree's keys. The "
|
||||||
|
"shallow_tree has the following keys that are not in the input_tree: {}.")
|
||||||
|
|
||||||
|
_STRUCTURES_HAVE_MISMATCHING_TYPES = (
|
||||||
|
"The two structures don't have the same sequence type. Input structure has "
|
||||||
|
"type {shallow_type}, while shallow structure has type {input_type}.")
|
||||||
|
|
||||||
|
_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
|
||||||
|
"The input_tree has fewer elements than the input_tree. Input structure "
|
||||||
|
"has length {input_size}, while shallow structure has length "
|
||||||
|
"{shallow_size}.")
|
||||||
|
|
||||||
|
_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
|
||||||
|
"If shallow structure is a sequence, input must also be a sequence. "
|
||||||
|
"Input has type: {}.")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attrs_items(obj):
|
||||||
|
"""Returns a list of (name, value) pairs from an attrs instance.
|
||||||
|
|
||||||
|
The list will be sorted by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: an object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of (attr_name, attr_value) pairs, sorted by attr_name.
|
||||||
|
"""
|
||||||
attrs = getattr(obj.__class__, "__attrs_attrs__")
|
attrs = getattr(obj.__class__, "__attrs_attrs__")
|
||||||
return [getattr(obj, a.name) for a in attrs]
|
attr_names = sorted([a.name for a in attrs])
|
||||||
|
return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names]
|
||||||
|
|
||||||
|
|
||||||
def _sorted(dict_):
|
def _sorted(dict_):
|
||||||
@ -106,24 +134,45 @@ def _sequence_like(instance, args):
|
|||||||
|
|
||||||
|
|
||||||
def _yield_value(iterable):
|
def _yield_value(iterable):
|
||||||
"""Yields the next value from the given iterable."""
|
for _, v in _yield_sorted_items(iterable):
|
||||||
if _is_mapping(iterable):
|
yield v
|
||||||
|
|
||||||
|
|
||||||
|
def _yield_sorted_items(iterable):
|
||||||
|
"""Yield (key, value) pairs for `iterable` in a deterministic order.
|
||||||
|
|
||||||
|
For Sequences, the key will be an int, the array index of a value.
|
||||||
|
For Mappings, the key will be the dictionary key.
|
||||||
|
For objects (e.g. namedtuples), the key will be the attribute name.
|
||||||
|
|
||||||
|
In all cases, the keys will be iterated in sorted order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iterable: an iterable.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The iterable's (key, value) pairs, in order of sorted keys.
|
||||||
|
"""
|
||||||
|
if isinstance(iterable, _collections.Mapping):
|
||||||
# Iterate through dictionaries in a deterministic order by sorting the
|
# Iterate through dictionaries in a deterministic order by sorting the
|
||||||
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
# keys. Notice this means that we ignore the original order of `OrderedDict`
|
||||||
# instances. This is intentional, to avoid potential bugs caused by mixing
|
# instances. This is intentional, to avoid potential bugs caused by mixing
|
||||||
# ordered and plain dicts (e.g., flattening a dict but using a
|
# ordered and plain dicts (e.g., flattening a dict but using a
|
||||||
# corresponding `OrderedDict` to pack it back).
|
# corresponding `OrderedDict` to pack it back).
|
||||||
for key in _sorted(iterable):
|
for key in _sorted(iterable):
|
||||||
yield iterable[key]
|
yield key, iterable[key]
|
||||||
elif _is_attrs(iterable):
|
elif _is_attrs(iterable):
|
||||||
for value in _get_attrs_values(iterable):
|
for item in _get_attrs_items(iterable):
|
||||||
yield value
|
yield item
|
||||||
|
elif _is_namedtuple(iterable):
|
||||||
|
for field in iterable._fields:
|
||||||
|
yield field, getattr(iterable, field)
|
||||||
elif _is_composite_tensor(iterable):
|
elif _is_composite_tensor(iterable):
|
||||||
for value in _yield_value(iterable._to_components()): # pylint: disable=protected-access
|
for item in enumerate(iterable._to_components()): # pylint: disable=protected-access
|
||||||
yield value
|
yield item
|
||||||
else:
|
else:
|
||||||
for value in iterable:
|
for item in enumerate(iterable):
|
||||||
yield value
|
yield item
|
||||||
|
|
||||||
|
|
||||||
# See the swig file (util.i) for documentation.
|
# See the swig file (util.i) for documentation.
|
||||||
@ -442,8 +491,15 @@ def map_structure_with_paths(func, *structure, **kwargs):
|
|||||||
the type of sequence in any of their substructures.
|
the type of sequence in any of their substructures.
|
||||||
ValueError: If no structures are provided.
|
ValueError: If no structures are provided.
|
||||||
"""
|
"""
|
||||||
return _map_structure_with_tuple_or_string_paths(
|
print("wheee I'm updated")
|
||||||
use_string_paths=True, func=func, structure=structure, kwargs=kwargs)
|
def wrapper_func(tuple_path, *inputs, **kwargs):
|
||||||
|
string_path = "/".join(str(s) for s in tuple_path)
|
||||||
|
return func(string_path, *inputs, **kwargs)
|
||||||
|
|
||||||
|
return map_structure_with_tuple_paths_up_to(structure[0],
|
||||||
|
wrapper_func,
|
||||||
|
*structure,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def map_structure_with_tuple_paths(func, *structure, **kwargs):
|
def map_structure_with_tuple_paths(func, *structure, **kwargs):
|
||||||
@ -479,52 +535,43 @@ def map_structure_with_tuple_paths(func, *structure, **kwargs):
|
|||||||
the type of sequence in any of their substructures.
|
the type of sequence in any of their substructures.
|
||||||
ValueError: If no structures are provided.
|
ValueError: If no structures are provided.
|
||||||
"""
|
"""
|
||||||
return _map_structure_with_tuple_or_string_paths(
|
return map_structure_with_tuple_paths_up_to(structure[0],
|
||||||
use_string_paths=False, func=func, structure=structure, kwargs=kwargs)
|
func,
|
||||||
|
*structure,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _map_structure_with_tuple_or_string_paths(
|
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
||||||
use_string_paths, func, structure, kwargs):
|
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.
|
||||||
"""Implements `map_structure` with either tuple or string paths."""
|
|
||||||
|
|
||||||
if not callable(func):
|
Args:
|
||||||
raise TypeError("func must be callable, got: %s" % func)
|
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
|
||||||
if not structure:
|
input_tree: Nested structure. Return the paths and values from this tree.
|
||||||
raise ValueError("Must provide at least one structure")
|
Must have the same upper structure as shallow_tree.
|
||||||
|
path: Tuple. Optional argument, only used when recursing. The path from the
|
||||||
|
root of the original shallow_tree, down to the root of the shallow_tree
|
||||||
|
arg of this recursive call.
|
||||||
|
|
||||||
check_types = kwargs.pop("check_types", True)
|
Yields:
|
||||||
for other in structure[1:]:
|
Pairs of (path, value), where path the tuple path of a leaf node in
|
||||||
assert_same_structure(structure[0], other, check_types=check_types)
|
shallow_tree, and value is the value of the corresponding node in
|
||||||
|
input_tree.
|
||||||
if use_string_paths:
|
"""
|
||||||
flatten_func = flatten_with_joined_string_paths
|
if (isinstance(shallow_tree, _six.string_types) or
|
||||||
|
not any([isinstance(shallow_tree, _collections.Sequence),
|
||||||
|
isinstance(shallow_tree, _collections.Mapping),
|
||||||
|
_is_namedtuple(shallow_tree),
|
||||||
|
_is_attrs(shallow_tree)])):
|
||||||
|
yield (path, input_tree)
|
||||||
else:
|
else:
|
||||||
flatten_func = flatten_with_tuple_paths
|
input_tree = dict(_yield_sorted_items(input_tree))
|
||||||
|
for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
|
||||||
# First set paths_and_values to:
|
subpath = path + (shallow_key,)
|
||||||
# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]]
|
input_subtree = input_tree[shallow_key]
|
||||||
paths_and_values = [flatten_func(s) for s in structure]
|
for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
|
||||||
|
input_subtree,
|
||||||
# Now zip(*paths_and_values) would be:
|
path=subpath):
|
||||||
# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))]
|
yield (leaf_path, leaf_value)
|
||||||
# so grouped_by_path is set to:
|
|
||||||
# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]]
|
|
||||||
# Note that p1i, ... pmi must all be equal since the structures are the same.
|
|
||||||
grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)]
|
|
||||||
|
|
||||||
return pack_sequence_as(structure[0], [
|
|
||||||
func(paths[0], *values, **kwargs) for paths, values in grouped_by_path])
|
|
||||||
|
|
||||||
|
|
||||||
def _yield_flat_up_to(shallow_tree, input_tree):
|
|
||||||
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
|
|
||||||
if is_sequence(shallow_tree):
|
|
||||||
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
|
||||||
_yield_value(input_tree)):
|
|
||||||
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
|
|
||||||
yield input_leaf
|
|
||||||
else:
|
|
||||||
yield input_tree
|
|
||||||
|
|
||||||
|
|
||||||
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
||||||
@ -538,8 +585,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
|||||||
|
|
||||||
The following code will raise an exception:
|
The following code will raise an exception:
|
||||||
```python
|
```python
|
||||||
shallow_tree = ["a", "b"]
|
shallow_tree = {"a": "A", "b": "B"}
|
||||||
input_tree = ["c", ["d", "e"], "f"]
|
input_tree = {"a": 1, "c": 2}
|
||||||
assert_shallow_structure(shallow_tree, input_tree)
|
assert_shallow_structure(shallow_tree, input_tree)
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -578,40 +625,34 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
|||||||
input_is_namedtuple = _is_namedtuple(input_tree, False)
|
input_is_namedtuple = _is_namedtuple(input_tree, False)
|
||||||
if shallow_is_namedtuple and input_is_namedtuple:
|
if shallow_is_namedtuple and input_is_namedtuple:
|
||||||
if not _same_namedtuples(shallow_tree, input_tree):
|
if not _same_namedtuples(shallow_tree, input_tree):
|
||||||
raise TypeError(
|
raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
|
||||||
"The two namedtuples don't have the same sequence type. Input "
|
input_type=type(input_tree),
|
||||||
"structure has type %s, while shallow structure has type %s."
|
shallow_type=type(shallow_tree)))
|
||||||
% (type(input_tree), type(shallow_tree)))
|
|
||||||
elif not (isinstance(shallow_tree, _collections.Mapping)
|
elif not (isinstance(shallow_tree, _collections.Mapping)
|
||||||
and isinstance(input_tree, _collections.Mapping)):
|
and isinstance(input_tree, _collections.Mapping)):
|
||||||
raise TypeError(
|
raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
|
||||||
"The two structures don't have the same sequence type. Input "
|
input_type=type(input_tree),
|
||||||
"structure has type %s, while shallow structure has type %s."
|
shallow_type=type(shallow_tree)))
|
||||||
% (type(input_tree), type(shallow_tree)))
|
|
||||||
|
|
||||||
if len(input_tree) != len(shallow_tree):
|
if len(input_tree) < len(shallow_tree):
|
||||||
raise ValueError(
|
raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
|
||||||
"The two structures don't have the same sequence length. Input "
|
input_size=len(input_tree),
|
||||||
"structure has length %s, while shallow structure has length %s."
|
shallow_size=len(shallow_tree)))
|
||||||
% (len(input_tree), len(shallow_tree)))
|
|
||||||
|
|
||||||
if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)):
|
if isinstance(shallow_tree, _collections.Mapping):
|
||||||
if set(input_tree) != set(shallow_tree):
|
absent_keys = set(shallow_tree) - set(input_tree)
|
||||||
raise ValueError(
|
if absent_keys:
|
||||||
"The two structures don't have the same keys. Input "
|
raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
|
||||||
"structure has keys %s, while shallow structure has keys %s." %
|
.format(sorted(absent_keys)))
|
||||||
(list(_six.iterkeys(input_tree)),
|
|
||||||
list(_six.iterkeys(shallow_tree))))
|
|
||||||
|
|
||||||
input_tree = list(sorted(_six.iteritems(input_tree)))
|
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
||||||
shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
|
_yield_value(input_tree)):
|
||||||
|
|
||||||
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
|
|
||||||
assert_shallow_structure(shallow_branch, input_branch,
|
assert_shallow_structure(shallow_branch, input_branch,
|
||||||
check_types=check_types)
|
check_types=check_types)
|
||||||
|
|
||||||
|
|
||||||
def flatten_up_to(shallow_tree, input_tree):
|
def flatten_up_to(shallow_tree, input_tree, check_types=True):
|
||||||
"""Flattens `input_tree` up to `shallow_tree`.
|
"""Flattens `input_tree` up to `shallow_tree`.
|
||||||
|
|
||||||
Any further depth in structure in `input_tree` is retained as elements in the
|
Any further depth in structure in `input_tree` is retained as elements in the
|
||||||
@ -668,6 +709,8 @@ def flatten_up_to(shallow_tree, input_tree):
|
|||||||
shallow_tree: a possibly pruned structure of input_tree.
|
shallow_tree: a possibly pruned structure of input_tree.
|
||||||
input_tree: an arbitrarily nested structure or a scalar object.
|
input_tree: an arbitrarily nested structure or a scalar object.
|
||||||
Note, numpy arrays are considered scalars.
|
Note, numpy arrays are considered scalars.
|
||||||
|
check_types: bool. If True, check that each node in shallow_tree has the
|
||||||
|
same type as the corresponding node in input_tree.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Python list, the partially flattened version of `input_tree` according to
|
A Python list, the partially flattened version of `input_tree` according to
|
||||||
@ -680,11 +723,12 @@ def flatten_up_to(shallow_tree, input_tree):
|
|||||||
ValueError: If the sequence lengths of `shallow_tree` are different from
|
ValueError: If the sequence lengths of `shallow_tree` are different from
|
||||||
`input_tree`.
|
`input_tree`.
|
||||||
"""
|
"""
|
||||||
assert_shallow_structure(shallow_tree, input_tree)
|
assert_shallow_structure(shallow_tree, input_tree, check_types)
|
||||||
return list(_yield_flat_up_to(shallow_tree, input_tree))
|
# Discard paths returned by _yield_flat_up_to.
|
||||||
|
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree))
|
||||||
|
|
||||||
|
|
||||||
def map_structure_up_to(shallow_tree, func, *inputs):
|
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||||
"""Applies a function or op to a number of partially flattened inputs.
|
"""Applies a function or op to a number of partially flattened inputs.
|
||||||
|
|
||||||
The `inputs` are flattened up to `shallow_tree` before being mapped.
|
The `inputs` are flattened up to `shallow_tree` before being mapped.
|
||||||
@ -733,6 +777,11 @@ def map_structure_up_to(shallow_tree, func, *inputs):
|
|||||||
shallow_tree. The function `func` is applied to corresponding
|
shallow_tree. The function `func` is applied to corresponding
|
||||||
partially flattened elements of each input, so the function must support
|
partially flattened elements of each input, so the function must support
|
||||||
arity of `len(inputs)`.
|
arity of `len(inputs)`.
|
||||||
|
**kwargs: kwargs to feed to func(). Special kwarg
|
||||||
|
`check_types` is not passed to func, but instead determines whether the
|
||||||
|
types of iterables within the structures have to be same (e.g.
|
||||||
|
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
|
||||||
|
this set this argument to `False`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
||||||
@ -745,16 +794,93 @@ def map_structure_up_to(shallow_tree, func, *inputs):
|
|||||||
result of repeatedly applying `func`, with same structure as
|
result of repeatedly applying `func`, with same structure as
|
||||||
`shallow_tree`.
|
`shallow_tree`.
|
||||||
"""
|
"""
|
||||||
|
return map_structure_with_tuple_paths_up_to(
|
||||||
|
shallow_tree,
|
||||||
|
lambda _, *values: func(*values), # Discards the path arg.
|
||||||
|
*inputs,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||||
|
"""Applies a function or op to a number of partially flattened inputs.
|
||||||
|
|
||||||
|
Like map_structure_up_to(), except that the 'func' argument takes a path
|
||||||
|
tuple as its first argument, followed by the corresponding values from
|
||||||
|
*inputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
lowercase = {'a': 'a', 'b': ('b0', 'b1')}
|
||||||
|
uppercase = {'a': 'A', 'b': ('B0', 'B1')}
|
||||||
|
|
||||||
|
def print_path_and_values(path, *values):
|
||||||
|
print("path: {}, values: {}".format(path, values))
|
||||||
|
|
||||||
|
shallow_tree = {'a': None}
|
||||||
|
map_structure_with_tuple_paths_up_to(shallow_tree,
|
||||||
|
print_path_and_values,
|
||||||
|
lowercase,
|
||||||
|
uppercase)
|
||||||
|
>>> path: ('a',), values: ('a', 'A')
|
||||||
|
>>> path: ('b', 0), values: ('b0', 'B0')
|
||||||
|
>>> path: ('b', 1), values: ('b1', 'B1')
|
||||||
|
|
||||||
|
shallow_tree = {'b': None}
|
||||||
|
map_structure_with_tuple_paths_up_to(shallow_tree,
|
||||||
|
print_path_and_values,
|
||||||
|
lowercase,
|
||||||
|
uppercase,
|
||||||
|
check_types=False)
|
||||||
|
>>> path: ('b', 1), values: (('bo', 'b1'), ('B0', 'B1'))
|
||||||
|
|
||||||
|
shallow_tree = {'a': None, 'b': {1: None}}
|
||||||
|
map_structure_with_tuple_paths_up_to(shallow_tree,
|
||||||
|
print_path_and_values,
|
||||||
|
lowercase,
|
||||||
|
uppercase,
|
||||||
|
check_types=False)
|
||||||
|
>>> path: ('a',), values: ('a', 'A')
|
||||||
|
>>> path: ('b', 1), values: ('b1', B1')
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shallow_tree: a shallow tree, common to all the inputs.
|
||||||
|
func: callable that takes args (path, inputs_0_value, ... , inputs_N_value),
|
||||||
|
where path is a tuple path to a leaf node in shallow_tree, and
|
||||||
|
inputs_i_value is the corresponding value from inputs[i].
|
||||||
|
*inputs: nested structures that are all structurally compatible with
|
||||||
|
shallow_tree.
|
||||||
|
**kwargs: kwargs to feed to func(). Special kwarg
|
||||||
|
`check_types` is not passed to func, but instead determines whether the
|
||||||
|
types of iterables within the structures have to be same (e.g.
|
||||||
|
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
|
||||||
|
this set this argument to `False`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
|
||||||
|
TypeError: If the sequence types of `shallow_tree` are different from
|
||||||
|
`input_tree`.
|
||||||
|
ValueError: If the sequence lengths of `shallow_tree` are different from
|
||||||
|
`input_tree`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of repeatedly applying `func`. Has same structure as `shallow_tree`.
|
||||||
|
"""
|
||||||
if not inputs:
|
if not inputs:
|
||||||
raise ValueError("Cannot map over no sequences")
|
raise ValueError("Cannot map over no sequences")
|
||||||
|
|
||||||
|
check_types = kwargs.pop("check_types", True)
|
||||||
|
|
||||||
for input_tree in inputs:
|
for input_tree in inputs:
|
||||||
assert_shallow_structure(shallow_tree, input_tree)
|
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
|
||||||
|
|
||||||
# Flatten each input separately, apply the function to corresponding elements,
|
# Flatten each input separately, apply the function to corresponding elements,
|
||||||
# then repack based on the structure of the first input.
|
# then repack based on the structure of the first input.
|
||||||
all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
|
flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types)
|
||||||
for input_tree in inputs]
|
for input_tree in inputs]
|
||||||
results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
|
flat_path_list = [path for path, _
|
||||||
|
in _yield_flat_up_to(shallow_tree, inputs[0])]
|
||||||
|
results = [func(*args, **kwargs) for args in zip(flat_path_list,
|
||||||
|
*flat_value_lists)]
|
||||||
return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
|
return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
|
||||||
|
|
||||||
|
|
||||||
@ -853,27 +979,8 @@ def yield_flat_paths(nest):
|
|||||||
Tuples containing index or key values which form the path to a specific
|
Tuples containing index or key values which form the path to a specific
|
||||||
leaf value in the nested structure.
|
leaf value in the nested structure.
|
||||||
"""
|
"""
|
||||||
|
for k, _ in _yield_flat_up_to(nest, nest):
|
||||||
# The _maybe_add_final_path_element function is used below in order to avoid
|
yield k
|
||||||
# adding trailing slashes when the sub-element recursed into is a leaf.
|
|
||||||
if isinstance(nest, (dict, _collections.Mapping)):
|
|
||||||
for key in _sorted(nest):
|
|
||||||
value = nest[key]
|
|
||||||
for sub_path in yield_flat_paths(value):
|
|
||||||
yield (key,) + sub_path
|
|
||||||
elif _is_namedtuple(nest):
|
|
||||||
for key in nest._fields:
|
|
||||||
value = getattr(nest, key)
|
|
||||||
for sub_path in yield_flat_paths(value):
|
|
||||||
yield (key,) + sub_path
|
|
||||||
elif isinstance(nest, _six.string_types):
|
|
||||||
yield ()
|
|
||||||
elif isinstance(nest, _collections.Sequence):
|
|
||||||
for idx, value in enumerate(nest):
|
|
||||||
for sub_path in yield_flat_paths(value):
|
|
||||||
yield (idx,) + sub_path
|
|
||||||
else:
|
|
||||||
yield ()
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_with_joined_string_paths(structure, separator="/"):
|
def flatten_with_joined_string_paths(structure, separator="/"):
|
||||||
|
@ -510,30 +510,28 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
def testAssertShallowStructure(self):
|
def testAssertShallowStructure(self):
|
||||||
inp_ab = ["a", "b"]
|
inp_ab = ["a", "b"]
|
||||||
inp_abc = ["a", "b", "c"]
|
inp_abc = ["a", "b", "c"]
|
||||||
expected_message = (
|
with self.assertRaisesWithLiteralMatch(
|
||||||
"The two structures don't have the same sequence length. Input "
|
ValueError,
|
||||||
"structure has length 2, while shallow structure has length 3.")
|
nest._INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
|
||||||
with self.assertRaisesRegexp(ValueError, expected_message):
|
shallow_size=len(inp_abc),
|
||||||
nest.assert_shallow_structure(inp_abc, inp_ab)
|
input_size=len(inp_ab))):
|
||||||
|
nest.assert_shallow_structure(shallow_tree=inp_abc, input_tree=inp_ab)
|
||||||
|
|
||||||
inp_ab1 = [(1, 1), (2, 2)]
|
inp_ab1 = [(1, 1), (2, 2)]
|
||||||
inp_ab2 = [[1, 1], [2, 2]]
|
inp_ab2 = [[1, 1], [2, 2]]
|
||||||
expected_message = (
|
with self.assertRaisesWithLiteralMatch(
|
||||||
"The two structures don't have the same sequence type. Input structure "
|
TypeError,
|
||||||
"has type <(type|class) 'tuple'>, while shallow structure has type "
|
nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
|
||||||
"<(type|class) 'list'>.")
|
shallow_type=type(inp_ab2[0]),
|
||||||
with self.assertRaisesRegexp(TypeError, expected_message):
|
input_type=type(inp_ab1[0]))):
|
||||||
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
||||||
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
|
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
|
||||||
|
|
||||||
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
|
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
|
||||||
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
|
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
|
||||||
expected_message = (
|
with self.assertRaisesWithLiteralMatch(
|
||||||
r"The two structures don't have the same keys. Input "
|
ValueError,
|
||||||
r"structure has keys \['c'\], while shallow structure has "
|
nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])):
|
||||||
r"keys \['d'\].")
|
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, expected_message):
|
|
||||||
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
||||||
|
|
||||||
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
|
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
|
||||||
@ -719,7 +717,9 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
# Non-equal dicts.
|
# Non-equal dicts.
|
||||||
inp_val = dict(a=2, b=3)
|
inp_val = dict(a=2, b=3)
|
||||||
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
|
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
|
||||||
with self.assertRaisesRegexp(ValueError, "same keys"):
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
ValueError,
|
||||||
|
nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])):
|
||||||
nest.map_structure_up_to(
|
nest.map_structure_up_to(
|
||||||
inp_val,
|
inp_val,
|
||||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
|
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
|
||||||
@ -736,7 +736,9 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
# Non-equal dict/mapping.
|
# Non-equal dict/mapping.
|
||||||
inp_val = dict(a=2, b=3)
|
inp_val = dict(a=2, b=3)
|
||||||
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
|
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
|
||||||
with self.assertRaisesRegexp(ValueError, "same keys"):
|
with self.assertRaisesWithLiteralMatch(
|
||||||
|
ValueError,
|
||||||
|
nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])):
|
||||||
nest.map_structure_up_to(
|
nest.map_structure_up_to(
|
||||||
inp_val,
|
inp_val,
|
||||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
|
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
|
||||||
@ -849,12 +851,12 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
self.assertEqual(expected, result)
|
self.assertEqual(expected, result)
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
("tuples", (1, 2), (3, 4, 5), ValueError),
|
("tuples", (1, 2, 3), (4, 5), ValueError),
|
||||||
("dicts", {"a": 1}, {"b": 2}, ValueError),
|
("dicts", {"a": 1}, {"b": 2}, ValueError),
|
||||||
("mixed", (1, 2), [3, 4], TypeError),
|
("mixed", (1, 2), [3, 4], TypeError),
|
||||||
("nested",
|
("nested",
|
||||||
{"a": [2, 3], "b": [1, 3]},
|
{"a": [2, 3, 4], "b": [1, 3]},
|
||||||
{"b": [5, 6, 7], "a": [8, 9]},
|
{"b": [5, 6], "a": [8, 9]},
|
||||||
ValueError
|
ValueError
|
||||||
))
|
))
|
||||||
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
|
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
|
||||||
@ -884,13 +886,14 @@ class NestTest(parameterized.TestCase, test.TestCase):
|
|||||||
self.assertEqual(expected, result)
|
self.assertEqual(expected, result)
|
||||||
|
|
||||||
@parameterized.named_parameters([
|
@parameterized.named_parameters([
|
||||||
dict(testcase_name="Tuples", s1=(1, 2), s2=(3, 4, 5),
|
dict(testcase_name="Tuples", s1=(1, 2, 3), s2=(4, 5),
|
||||||
error_type=ValueError),
|
error_type=ValueError),
|
||||||
dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2},
|
dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2},
|
||||||
error_type=ValueError),
|
error_type=ValueError),
|
||||||
dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], error_type=TypeError),
|
dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], error_type=TypeError),
|
||||||
dict(testcase_name="Nested",
|
dict(testcase_name="Nested",
|
||||||
s1={"a": [2, 3], "b": [1, 3]}, s2={"b": [5, 6, 7], "a": [8, 9]},
|
s1={"a": [2, 3, 4], "b": [1, 3]},
|
||||||
|
s2={"b": [5, 6], "a": [8, 9]},
|
||||||
error_type=ValueError)
|
error_type=ValueError)
|
||||||
])
|
])
|
||||||
def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):
|
def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):
|
||||||
|
Loading…
Reference in New Issue
Block a user