Added map_structure_with_tuple_paths_up_to(), rewrote existing functions in tree.py in terms of it, as they are subsets of its functionality.

This cleans the behavior of traversing with shallow_trees, in that it now allows for any shallow_tree that's an upper subset of the input_tree. For example, before this change, the following was not allowed, on the basis that len(shallow_tree) < len(input_tree):

shallow_tree = {'a': 'A','c': 'C'}
input_tree = {'a': 'A', 'b': 'B', 'c': 'C'}
map_structure_up_to(shallow_tree, input_tree, func)
PiperOrigin-RevId: 230883405
This commit is contained in:
A. Unique TensorFlower 2019-01-25 04:43:48 -08:00 committed by TensorFlower Gardener
parent 53d3a75b7b
commit eda13c8416
3 changed files with 246 additions and 133 deletions

View File

@ -136,12 +136,15 @@ _nest_allowed_symbols = [
'pack_sequence_as', 'pack_sequence_as',
'map_structure', 'map_structure',
'map_structure_with_paths', 'map_structure_with_paths',
'map_structure_with_tuple_paths',
'assert_shallow_structure', 'assert_shallow_structure',
'flatten_up_to', 'flatten_up_to',
'map_structure_up_to', 'map_structure_up_to',
'map_structure_with_tuple_paths_up_to',
'get_traverse_shallow_structure', 'get_traverse_shallow_structure',
'yield_flat_paths', 'yield_flat_paths',
'flatten_with_joined_string_paths', 'flatten_with_joined_string_paths',
'flatten_with_tuple_paths',
] ]
remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols)

View File

@ -41,10 +41,38 @@ import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
def _get_attrs_values(obj): _SHALLOW_TREE_HAS_INVALID_KEYS = (
"""Returns the list of values from an attrs instance.""" "The shallow_tree's keys are not a subset of the input_tree's keys. The "
"shallow_tree has the following keys that are not in the input_tree: {}.")
_STRUCTURES_HAVE_MISMATCHING_TYPES = (
"The two structures don't have the same sequence type. Input structure has "
"type {shallow_type}, while shallow structure has type {input_type}.")
_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
"The input_tree has fewer elements than the input_tree. Input structure "
"has length {input_size}, while shallow structure has length "
"{shallow_size}.")
_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
"If shallow structure is a sequence, input must also be a sequence. "
"Input has type: {}.")
def _get_attrs_items(obj):
"""Returns a list of (name, value) pairs from an attrs instance.
The list will be sorted by name.
Args:
obj: an object.
Returns:
A list of (attr_name, attr_value) pairs, sorted by attr_name.
"""
attrs = getattr(obj.__class__, "__attrs_attrs__") attrs = getattr(obj.__class__, "__attrs_attrs__")
return [getattr(obj, a.name) for a in attrs] attr_names = sorted([a.name for a in attrs])
return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names]
def _sorted(dict_): def _sorted(dict_):
@ -106,24 +134,45 @@ def _sequence_like(instance, args):
def _yield_value(iterable): def _yield_value(iterable):
"""Yields the next value from the given iterable.""" for _, v in _yield_sorted_items(iterable):
if _is_mapping(iterable): yield v
def _yield_sorted_items(iterable):
"""Yield (key, value) pairs for `iterable` in a deterministic order.
For Sequences, the key will be an int, the array index of a value.
For Mappings, the key will be the dictionary key.
For objects (e.g. namedtuples), the key will be the attribute name.
In all cases, the keys will be iterated in sorted order.
Args:
iterable: an iterable.
Yields:
The iterable's (key, value) pairs, in order of sorted keys.
"""
if isinstance(iterable, _collections.Mapping):
# Iterate through dictionaries in a deterministic order by sorting the # Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict` # keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing # instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a # ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back). # corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable): for key in _sorted(iterable):
yield iterable[key] yield key, iterable[key]
elif _is_attrs(iterable): elif _is_attrs(iterable):
for value in _get_attrs_values(iterable): for item in _get_attrs_items(iterable):
yield value yield item
elif _is_namedtuple(iterable):
for field in iterable._fields:
yield field, getattr(iterable, field)
elif _is_composite_tensor(iterable): elif _is_composite_tensor(iterable):
for value in _yield_value(iterable._to_components()): # pylint: disable=protected-access for item in enumerate(iterable._to_components()): # pylint: disable=protected-access
yield value yield item
else: else:
for value in iterable: for item in enumerate(iterable):
yield value yield item
# See the swig file (util.i) for documentation. # See the swig file (util.i) for documentation.
@ -442,8 +491,15 @@ def map_structure_with_paths(func, *structure, **kwargs):
the type of sequence in any of their substructures. the type of sequence in any of their substructures.
ValueError: If no structures are provided. ValueError: If no structures are provided.
""" """
return _map_structure_with_tuple_or_string_paths( print("wheee I'm updated")
use_string_paths=True, func=func, structure=structure, kwargs=kwargs) def wrapper_func(tuple_path, *inputs, **kwargs):
string_path = "/".join(str(s) for s in tuple_path)
return func(string_path, *inputs, **kwargs)
return map_structure_with_tuple_paths_up_to(structure[0],
wrapper_func,
*structure,
**kwargs)
def map_structure_with_tuple_paths(func, *structure, **kwargs): def map_structure_with_tuple_paths(func, *structure, **kwargs):
@ -479,52 +535,43 @@ def map_structure_with_tuple_paths(func, *structure, **kwargs):
the type of sequence in any of their substructures. the type of sequence in any of their substructures.
ValueError: If no structures are provided. ValueError: If no structures are provided.
""" """
return _map_structure_with_tuple_or_string_paths( return map_structure_with_tuple_paths_up_to(structure[0],
use_string_paths=False, func=func, structure=structure, kwargs=kwargs) func,
*structure,
**kwargs)
def _map_structure_with_tuple_or_string_paths( def _yield_flat_up_to(shallow_tree, input_tree, path=()):
use_string_paths, func, structure, kwargs): """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
"""Implements `map_structure` with either tuple or string paths."""
if not callable(func): Args:
raise TypeError("func must be callable, got: %s" % func) shallow_tree: Nested structure. Traverse no further than its leaf nodes.
if not structure: input_tree: Nested structure. Return the paths and values from this tree.
raise ValueError("Must provide at least one structure") Must have the same upper structure as shallow_tree.
path: Tuple. Optional argument, only used when recursing. The path from the
root of the original shallow_tree, down to the root of the shallow_tree
arg of this recursive call.
check_types = kwargs.pop("check_types", True) Yields:
for other in structure[1:]: Pairs of (path, value), where path the tuple path of a leaf node in
assert_same_structure(structure[0], other, check_types=check_types) shallow_tree, and value is the value of the corresponding node in
input_tree.
if use_string_paths: """
flatten_func = flatten_with_joined_string_paths if (isinstance(shallow_tree, _six.string_types) or
not any([isinstance(shallow_tree, _collections.Sequence),
isinstance(shallow_tree, _collections.Mapping),
_is_namedtuple(shallow_tree),
_is_attrs(shallow_tree)])):
yield (path, input_tree)
else: else:
flatten_func = flatten_with_tuple_paths input_tree = dict(_yield_sorted_items(input_tree))
for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
# First set paths_and_values to: subpath = path + (shallow_key,)
# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] input_subtree = input_tree[shallow_key]
paths_and_values = [flatten_func(s) for s in structure] for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
input_subtree,
# Now zip(*paths_and_values) would be: path=subpath):
# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] yield (leaf_path, leaf_value)
# so grouped_by_path is set to:
# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]]
# Note that p1i, ... pmi must all be equal since the structures are the same.
grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)]
return pack_sequence_as(structure[0], [
func(paths[0], *values, **kwargs) for paths, values in grouped_by_path])
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
_yield_value(input_tree)):
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
yield input_leaf
else:
yield input_tree
def assert_shallow_structure(shallow_tree, input_tree, check_types=True): def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
@ -538,8 +585,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
The following code will raise an exception: The following code will raise an exception:
```python ```python
shallow_tree = ["a", "b"] shallow_tree = {"a": "A", "b": "B"}
input_tree = ["c", ["d", "e"], "f"] input_tree = {"a": 1, "c": 2}
assert_shallow_structure(shallow_tree, input_tree) assert_shallow_structure(shallow_tree, input_tree)
``` ```
@ -578,40 +625,34 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
input_is_namedtuple = _is_namedtuple(input_tree, False) input_is_namedtuple = _is_namedtuple(input_tree, False)
if shallow_is_namedtuple and input_is_namedtuple: if shallow_is_namedtuple and input_is_namedtuple:
if not _same_namedtuples(shallow_tree, input_tree): if not _same_namedtuples(shallow_tree, input_tree):
raise TypeError( raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
"The two namedtuples don't have the same sequence type. Input " input_type=type(input_tree),
"structure has type %s, while shallow structure has type %s." shallow_type=type(shallow_tree)))
% (type(input_tree), type(shallow_tree)))
elif not (isinstance(shallow_tree, _collections.Mapping) elif not (isinstance(shallow_tree, _collections.Mapping)
and isinstance(input_tree, _collections.Mapping)): and isinstance(input_tree, _collections.Mapping)):
raise TypeError( raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
"The two structures don't have the same sequence type. Input " input_type=type(input_tree),
"structure has type %s, while shallow structure has type %s." shallow_type=type(shallow_tree)))
% (type(input_tree), type(shallow_tree)))
if len(input_tree) != len(shallow_tree): if len(input_tree) < len(shallow_tree):
raise ValueError( raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
"The two structures don't have the same sequence length. Input " input_size=len(input_tree),
"structure has length %s, while shallow structure has length %s." shallow_size=len(shallow_tree)))
% (len(input_tree), len(shallow_tree)))
if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): if isinstance(shallow_tree, _collections.Mapping):
if set(input_tree) != set(shallow_tree): absent_keys = set(shallow_tree) - set(input_tree)
raise ValueError( if absent_keys:
"The two structures don't have the same keys. Input " raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
"structure has keys %s, while shallow structure has keys %s." % .format(sorted(absent_keys)))
(list(_six.iterkeys(input_tree)),
list(_six.iterkeys(shallow_tree))))
input_tree = list(sorted(_six.iteritems(input_tree))) for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
shallow_tree = list(sorted(_six.iteritems(shallow_tree))) _yield_value(input_tree)):
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
assert_shallow_structure(shallow_branch, input_branch, assert_shallow_structure(shallow_branch, input_branch,
check_types=check_types) check_types=check_types)
def flatten_up_to(shallow_tree, input_tree): def flatten_up_to(shallow_tree, input_tree, check_types=True):
"""Flattens `input_tree` up to `shallow_tree`. """Flattens `input_tree` up to `shallow_tree`.
Any further depth in structure in `input_tree` is retained as elements in the Any further depth in structure in `input_tree` is retained as elements in the
@ -668,6 +709,8 @@ def flatten_up_to(shallow_tree, input_tree):
shallow_tree: a possibly pruned structure of input_tree. shallow_tree: a possibly pruned structure of input_tree.
input_tree: an arbitrarily nested structure or a scalar object. input_tree: an arbitrarily nested structure or a scalar object.
Note, numpy arrays are considered scalars. Note, numpy arrays are considered scalars.
check_types: bool. If True, check that each node in shallow_tree has the
same type as the corresponding node in input_tree.
Returns: Returns:
A Python list, the partially flattened version of `input_tree` according to A Python list, the partially flattened version of `input_tree` according to
@ -680,11 +723,12 @@ def flatten_up_to(shallow_tree, input_tree):
ValueError: If the sequence lengths of `shallow_tree` are different from ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`. `input_tree`.
""" """
assert_shallow_structure(shallow_tree, input_tree) assert_shallow_structure(shallow_tree, input_tree, check_types)
return list(_yield_flat_up_to(shallow_tree, input_tree)) # Discard paths returned by _yield_flat_up_to.
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree))
def map_structure_up_to(shallow_tree, func, *inputs): def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
"""Applies a function or op to a number of partially flattened inputs. """Applies a function or op to a number of partially flattened inputs.
The `inputs` are flattened up to `shallow_tree` before being mapped. The `inputs` are flattened up to `shallow_tree` before being mapped.
@ -733,6 +777,11 @@ def map_structure_up_to(shallow_tree, func, *inputs):
shallow_tree. The function `func` is applied to corresponding shallow_tree. The function `func` is applied to corresponding
partially flattened elements of each input, so the function must support partially flattened elements of each input, so the function must support
arity of `len(inputs)`. arity of `len(inputs)`.
**kwargs: kwargs to feed to func(). Special kwarg
`check_types` is not passed to func, but instead determines whether the
types of iterables within the structures have to be same (e.g.
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
this set this argument to `False`.
Raises: Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not. TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
@ -745,16 +794,93 @@ def map_structure_up_to(shallow_tree, func, *inputs):
result of repeatedly applying `func`, with same structure as result of repeatedly applying `func`, with same structure as
`shallow_tree`. `shallow_tree`.
""" """
return map_structure_with_tuple_paths_up_to(
shallow_tree,
lambda _, *values: func(*values), # Discards the path arg.
*inputs,
**kwargs)
def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
"""Applies a function or op to a number of partially flattened inputs.
Like map_structure_up_to(), except that the 'func' argument takes a path
tuple as its first argument, followed by the corresponding values from
*inputs.
Example:
lowercase = {'a': 'a', 'b': ('b0', 'b1')}
uppercase = {'a': 'A', 'b': ('B0', 'B1')}
def print_path_and_values(path, *values):
print("path: {}, values: {}".format(path, values))
shallow_tree = {'a': None}
map_structure_with_tuple_paths_up_to(shallow_tree,
print_path_and_values,
lowercase,
uppercase)
>>> path: ('a',), values: ('a', 'A')
>>> path: ('b', 0), values: ('b0', 'B0')
>>> path: ('b', 1), values: ('b1', 'B1')
shallow_tree = {'b': None}
map_structure_with_tuple_paths_up_to(shallow_tree,
print_path_and_values,
lowercase,
uppercase,
check_types=False)
>>> path: ('b', 1), values: (('bo', 'b1'), ('B0', 'B1'))
shallow_tree = {'a': None, 'b': {1: None}}
map_structure_with_tuple_paths_up_to(shallow_tree,
print_path_and_values,
lowercase,
uppercase,
check_types=False)
>>> path: ('a',), values: ('a', 'A')
>>> path: ('b', 1), values: ('b1', B1')
Args:
shallow_tree: a shallow tree, common to all the inputs.
func: callable that takes args (path, inputs_0_value, ... , inputs_N_value),
where path is a tuple path to a leaf node in shallow_tree, and
inputs_i_value is the corresponding value from inputs[i].
*inputs: nested structures that are all structurally compatible with
shallow_tree.
**kwargs: kwargs to feed to func(). Special kwarg
`check_types` is not passed to func, but instead determines whether the
types of iterables within the structures have to be same (e.g.
`map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
this set this argument to `False`.
Raises:
TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
TypeError: If the sequence types of `shallow_tree` are different from
`input_tree`.
ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`.
Returns:
Result of repeatedly applying `func`. Has same structure as `shallow_tree`.
"""
if not inputs: if not inputs:
raise ValueError("Cannot map over no sequences") raise ValueError("Cannot map over no sequences")
check_types = kwargs.pop("check_types", True)
for input_tree in inputs: for input_tree in inputs:
assert_shallow_structure(shallow_tree, input_tree) assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
# Flatten each input separately, apply the function to corresponding elements, # Flatten each input separately, apply the function to corresponding elements,
# then repack based on the structure of the first input. # then repack based on the structure of the first input.
all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types)
for input_tree in inputs] for input_tree in inputs]
results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] flat_path_list = [path for path, _
in _yield_flat_up_to(shallow_tree, inputs[0])]
results = [func(*args, **kwargs) for args in zip(flat_path_list,
*flat_value_lists)]
return pack_sequence_as(structure=shallow_tree, flat_sequence=results) return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
@ -853,27 +979,8 @@ def yield_flat_paths(nest):
Tuples containing index or key values which form the path to a specific Tuples containing index or key values which form the path to a specific
leaf value in the nested structure. leaf value in the nested structure.
""" """
for k, _ in _yield_flat_up_to(nest, nest):
# The _maybe_add_final_path_element function is used below in order to avoid yield k
# adding trailing slashes when the sub-element recursed into is a leaf.
if isinstance(nest, (dict, _collections.Mapping)):
for key in _sorted(nest):
value = nest[key]
for sub_path in yield_flat_paths(value):
yield (key,) + sub_path
elif _is_namedtuple(nest):
for key in nest._fields:
value = getattr(nest, key)
for sub_path in yield_flat_paths(value):
yield (key,) + sub_path
elif isinstance(nest, _six.string_types):
yield ()
elif isinstance(nest, _collections.Sequence):
for idx, value in enumerate(nest):
for sub_path in yield_flat_paths(value):
yield (idx,) + sub_path
else:
yield ()
def flatten_with_joined_string_paths(structure, separator="/"): def flatten_with_joined_string_paths(structure, separator="/"):

View File

@ -510,30 +510,28 @@ class NestTest(parameterized.TestCase, test.TestCase):
def testAssertShallowStructure(self): def testAssertShallowStructure(self):
inp_ab = ["a", "b"] inp_ab = ["a", "b"]
inp_abc = ["a", "b", "c"] inp_abc = ["a", "b", "c"]
expected_message = ( with self.assertRaisesWithLiteralMatch(
"The two structures don't have the same sequence length. Input " ValueError,
"structure has length 2, while shallow structure has length 3.") nest._INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
with self.assertRaisesRegexp(ValueError, expected_message): shallow_size=len(inp_abc),
nest.assert_shallow_structure(inp_abc, inp_ab) input_size=len(inp_ab))):
nest.assert_shallow_structure(shallow_tree=inp_abc, input_tree=inp_ab)
inp_ab1 = [(1, 1), (2, 2)] inp_ab1 = [(1, 1), (2, 2)]
inp_ab2 = [[1, 1], [2, 2]] inp_ab2 = [[1, 1], [2, 2]]
expected_message = ( with self.assertRaisesWithLiteralMatch(
"The two structures don't have the same sequence type. Input structure " TypeError,
"has type <(type|class) 'tuple'>, while shallow structure has type " nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
"<(type|class) 'list'>.") shallow_type=type(inp_ab2[0]),
with self.assertRaisesRegexp(TypeError, expected_message): input_type=type(inp_ab1[0]))):
nest.assert_shallow_structure(inp_ab2, inp_ab1) nest.assert_shallow_structure(inp_ab2, inp_ab1)
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
expected_message = ( with self.assertRaisesWithLiteralMatch(
r"The two structures don't have the same keys. Input " ValueError,
r"structure has keys \['c'\], while shallow structure has " nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])):
r"keys \['d'\].")
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1) nest.assert_shallow_structure(inp_ab2, inp_ab1)
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
@ -719,7 +717,9 @@ class NestTest(parameterized.TestCase, test.TestCase):
# Non-equal dicts. # Non-equal dicts.
inp_val = dict(a=2, b=3) inp_val = dict(a=2, b=3)
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
with self.assertRaisesRegexp(ValueError, "same keys"): with self.assertRaisesWithLiteralMatch(
ValueError,
nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])):
nest.map_structure_up_to( nest.map_structure_up_to(
inp_val, inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
@ -736,7 +736,9 @@ class NestTest(parameterized.TestCase, test.TestCase):
# Non-equal dict/mapping. # Non-equal dict/mapping.
inp_val = dict(a=2, b=3) inp_val = dict(a=2, b=3)
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
with self.assertRaisesRegexp(ValueError, "same keys"): with self.assertRaisesWithLiteralMatch(
ValueError,
nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])):
nest.map_structure_up_to( nest.map_structure_up_to(
inp_val, inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
@ -849,12 +851,12 @@ class NestTest(parameterized.TestCase, test.TestCase):
self.assertEqual(expected, result) self.assertEqual(expected, result)
@parameterized.named_parameters( @parameterized.named_parameters(
("tuples", (1, 2), (3, 4, 5), ValueError), ("tuples", (1, 2, 3), (4, 5), ValueError),
("dicts", {"a": 1}, {"b": 2}, ValueError), ("dicts", {"a": 1}, {"b": 2}, ValueError),
("mixed", (1, 2), [3, 4], TypeError), ("mixed", (1, 2), [3, 4], TypeError),
("nested", ("nested",
{"a": [2, 3], "b": [1, 3]}, {"a": [2, 3, 4], "b": [1, 3]},
{"b": [5, 6, 7], "a": [8, 9]}, {"b": [5, 6], "a": [8, 9]},
ValueError ValueError
)) ))
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
@ -884,13 +886,14 @@ class NestTest(parameterized.TestCase, test.TestCase):
self.assertEqual(expected, result) self.assertEqual(expected, result)
@parameterized.named_parameters([ @parameterized.named_parameters([
dict(testcase_name="Tuples", s1=(1, 2), s2=(3, 4, 5), dict(testcase_name="Tuples", s1=(1, 2, 3), s2=(4, 5),
error_type=ValueError), error_type=ValueError),
dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2}, dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2},
error_type=ValueError), error_type=ValueError),
dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], error_type=TypeError), dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], error_type=TypeError),
dict(testcase_name="Nested", dict(testcase_name="Nested",
s1={"a": [2, 3], "b": [1, 3]}, s2={"b": [5, 6, 7], "a": [8, 9]}, s1={"a": [2, 3, 4], "b": [1, 3]},
s2={"b": [5, 6], "a": [8, 9]},
error_type=ValueError) error_type=ValueError)
]) ])
def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type): def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):