From eda13c84160a3280316c80ca6455a6edb45e3129 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 25 Jan 2019 04:43:48 -0800 Subject: [PATCH] Added map_structure_with_tuple_paths_up_to(), rewrote existing functions in tree.py in terms of it, as they are subsets of its functionality. This cleans the behavior of traversing with shallow_trees, in that it now allows for any shallow_tree that's an upper subset of the input_tree. For example, before this change, the following was not allowed, on the basis that len(shallow_tree) < len(input_tree): shallow_tree = {'a': 'A','c': 'C'} input_tree = {'a': 'A', 'b': 'B', 'c': 'C'} map_structure_up_to(shallow_tree, input_tree, func) PiperOrigin-RevId: 230883405 --- tensorflow/contrib/framework/__init__.py | 3 + tensorflow/python/util/nest.py | 327 +++++++++++++++-------- tensorflow/python/util/nest_test.py | 49 ++-- 3 files changed, 246 insertions(+), 133 deletions(-) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index ff65c1586cd..3784631dcbf 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -136,12 +136,15 @@ _nest_allowed_symbols = [ 'pack_sequence_as', 'map_structure', 'map_structure_with_paths', + 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', + 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', 'yield_flat_paths', 'flatten_with_joined_string_paths', + 'flatten_with_tuple_paths', ] remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 83fd8bdf601..4de358dec6f 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -41,10 +41,38 @@ import six as _six from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow -def _get_attrs_values(obj): - """Returns the list of values from an attrs instance.""" +_SHALLOW_TREE_HAS_INVALID_KEYS = ( + "The shallow_tree's keys are not a subset of the input_tree's keys. The " + "shallow_tree has the following keys that are not in the input_tree: {}.") + +_STRUCTURES_HAVE_MISMATCHING_TYPES = ( + "The two structures don't have the same sequence type. Input structure has " + "type {shallow_type}, while shallow structure has type {input_type}.") + +_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = ( + "The input_tree has fewer elements than the input_tree. Input structure " + "has length {input_size}, while shallow structure has length " + "{shallow_size}.") + +_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = ( + "If shallow structure is a sequence, input must also be a sequence. " + "Input has type: {}.") + + +def _get_attrs_items(obj): + """Returns a list of (name, value) pairs from an attrs instance. + + The list will be sorted by name. + + Args: + obj: an object. + + Returns: + A list of (attr_name, attr_value) pairs, sorted by attr_name. + """ attrs = getattr(obj.__class__, "__attrs_attrs__") - return [getattr(obj, a.name) for a in attrs] + attr_names = sorted([a.name for a in attrs]) + return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names] def _sorted(dict_): @@ -106,24 +134,45 @@ def _sequence_like(instance, args): def _yield_value(iterable): - """Yields the next value from the given iterable.""" - if _is_mapping(iterable): + for _, v in _yield_sorted_items(iterable): + yield v + + +def _yield_sorted_items(iterable): + """Yield (key, value) pairs for `iterable` in a deterministic order. + + For Sequences, the key will be an int, the array index of a value. + For Mappings, the key will be the dictionary key. + For objects (e.g. namedtuples), the key will be the attribute name. + + In all cases, the keys will be iterated in sorted order. + + Args: + iterable: an iterable. + + Yields: + The iterable's (key, value) pairs, in order of sorted keys. + """ + if isinstance(iterable, _collections.Mapping): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing # ordered and plain dicts (e.g., flattening a dict but using a # corresponding `OrderedDict` to pack it back). for key in _sorted(iterable): - yield iterable[key] + yield key, iterable[key] elif _is_attrs(iterable): - for value in _get_attrs_values(iterable): - yield value + for item in _get_attrs_items(iterable): + yield item + elif _is_namedtuple(iterable): + for field in iterable._fields: + yield field, getattr(iterable, field) elif _is_composite_tensor(iterable): - for value in _yield_value(iterable._to_components()): # pylint: disable=protected-access - yield value + for item in enumerate(iterable._to_components()): # pylint: disable=protected-access + yield item else: - for value in iterable: - yield value + for item in enumerate(iterable): + yield item # See the swig file (util.i) for documentation. @@ -442,8 +491,15 @@ def map_structure_with_paths(func, *structure, **kwargs): the type of sequence in any of their substructures. ValueError: If no structures are provided. """ - return _map_structure_with_tuple_or_string_paths( - use_string_paths=True, func=func, structure=structure, kwargs=kwargs) + print("wheee I'm updated") + def wrapper_func(tuple_path, *inputs, **kwargs): + string_path = "/".join(str(s) for s in tuple_path) + return func(string_path, *inputs, **kwargs) + + return map_structure_with_tuple_paths_up_to(structure[0], + wrapper_func, + *structure, + **kwargs) def map_structure_with_tuple_paths(func, *structure, **kwargs): @@ -479,52 +535,43 @@ def map_structure_with_tuple_paths(func, *structure, **kwargs): the type of sequence in any of their substructures. ValueError: If no structures are provided. """ - return _map_structure_with_tuple_or_string_paths( - use_string_paths=False, func=func, structure=structure, kwargs=kwargs) + return map_structure_with_tuple_paths_up_to(structure[0], + func, + *structure, + **kwargs) -def _map_structure_with_tuple_or_string_paths( - use_string_paths, func, structure, kwargs): - """Implements `map_structure` with either tuple or string paths.""" +def _yield_flat_up_to(shallow_tree, input_tree, path=()): + """Yields (path, value) pairs of input_tree flattened up to shallow_tree. - if not callable(func): - raise TypeError("func must be callable, got: %s" % func) - if not structure: - raise ValueError("Must provide at least one structure") + Args: + shallow_tree: Nested structure. Traverse no further than its leaf nodes. + input_tree: Nested structure. Return the paths and values from this tree. + Must have the same upper structure as shallow_tree. + path: Tuple. Optional argument, only used when recursing. The path from the + root of the original shallow_tree, down to the root of the shallow_tree + arg of this recursive call. - check_types = kwargs.pop("check_types", True) - for other in structure[1:]: - assert_same_structure(structure[0], other, check_types=check_types) - - if use_string_paths: - flatten_func = flatten_with_joined_string_paths + Yields: + Pairs of (path, value), where path the tuple path of a leaf node in + shallow_tree, and value is the value of the corresponding node in + input_tree. + """ + if (isinstance(shallow_tree, _six.string_types) or + not any([isinstance(shallow_tree, _collections.Sequence), + isinstance(shallow_tree, _collections.Mapping), + _is_namedtuple(shallow_tree), + _is_attrs(shallow_tree)])): + yield (path, input_tree) else: - flatten_func = flatten_with_tuple_paths - - # First set paths_and_values to: - # [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] - paths_and_values = [flatten_func(s) for s in structure] - - # Now zip(*paths_and_values) would be: - # [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] - # so grouped_by_path is set to: - # [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] - # Note that p1i, ... pmi must all be equal since the structures are the same. - grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] - - return pack_sequence_as(structure[0], [ - func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) - - -def _yield_flat_up_to(shallow_tree, input_tree): - """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" - if is_sequence(shallow_tree): - for shallow_branch, input_branch in zip(_yield_value(shallow_tree), - _yield_value(input_tree)): - for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): - yield input_leaf - else: - yield input_tree + input_tree = dict(_yield_sorted_items(input_tree)) + for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): + subpath = path + (shallow_key,) + input_subtree = input_tree[shallow_key] + for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, + input_subtree, + path=subpath): + yield (leaf_path, leaf_value) def assert_shallow_structure(shallow_tree, input_tree, check_types=True): @@ -538,8 +585,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): The following code will raise an exception: ```python - shallow_tree = ["a", "b"] - input_tree = ["c", ["d", "e"], "f"] + shallow_tree = {"a": "A", "b": "B"} + input_tree = {"a": 1, "c": 2} assert_shallow_structure(shallow_tree, input_tree) ``` @@ -578,40 +625,34 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): input_is_namedtuple = _is_namedtuple(input_tree, False) if shallow_is_namedtuple and input_is_namedtuple: if not _same_namedtuples(shallow_tree, input_tree): - raise TypeError( - "The two namedtuples don't have the same sequence type. Input " - "structure has type %s, while shallow structure has type %s." - % (type(input_tree), type(shallow_tree))) + raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( + input_type=type(input_tree), + shallow_type=type(shallow_tree))) + elif not (isinstance(shallow_tree, _collections.Mapping) and isinstance(input_tree, _collections.Mapping)): - raise TypeError( - "The two structures don't have the same sequence type. Input " - "structure has type %s, while shallow structure has type %s." - % (type(input_tree), type(shallow_tree))) + raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( + input_type=type(input_tree), + shallow_type=type(shallow_tree))) - if len(input_tree) != len(shallow_tree): - raise ValueError( - "The two structures don't have the same sequence length. Input " - "structure has length %s, while shallow structure has length %s." - % (len(input_tree), len(shallow_tree))) + if len(input_tree) < len(shallow_tree): + raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( + input_size=len(input_tree), + shallow_size=len(shallow_tree))) - if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): - if set(input_tree) != set(shallow_tree): - raise ValueError( - "The two structures don't have the same keys. Input " - "structure has keys %s, while shallow structure has keys %s." % - (list(_six.iterkeys(input_tree)), - list(_six.iterkeys(shallow_tree)))) + if isinstance(shallow_tree, _collections.Mapping): + absent_keys = set(shallow_tree) - set(input_tree) + if absent_keys: + raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS + .format(sorted(absent_keys))) - input_tree = list(sorted(_six.iteritems(input_tree))) - shallow_tree = list(sorted(_six.iteritems(shallow_tree))) - - for shallow_branch, input_branch in zip(shallow_tree, input_tree): + for shallow_branch, input_branch in zip(_yield_value(shallow_tree), + _yield_value(input_tree)): assert_shallow_structure(shallow_branch, input_branch, check_types=check_types) -def flatten_up_to(shallow_tree, input_tree): +def flatten_up_to(shallow_tree, input_tree, check_types=True): """Flattens `input_tree` up to `shallow_tree`. Any further depth in structure in `input_tree` is retained as elements in the @@ -668,6 +709,8 @@ def flatten_up_to(shallow_tree, input_tree): shallow_tree: a possibly pruned structure of input_tree. input_tree: an arbitrarily nested structure or a scalar object. Note, numpy arrays are considered scalars. + check_types: bool. If True, check that each node in shallow_tree has the + same type as the corresponding node in input_tree. Returns: A Python list, the partially flattened version of `input_tree` according to @@ -680,11 +723,12 @@ def flatten_up_to(shallow_tree, input_tree): ValueError: If the sequence lengths of `shallow_tree` are different from `input_tree`. """ - assert_shallow_structure(shallow_tree, input_tree) - return list(_yield_flat_up_to(shallow_tree, input_tree)) + assert_shallow_structure(shallow_tree, input_tree, check_types) + # Discard paths returned by _yield_flat_up_to. + return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree)) -def map_structure_up_to(shallow_tree, func, *inputs): +def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): """Applies a function or op to a number of partially flattened inputs. The `inputs` are flattened up to `shallow_tree` before being mapped. @@ -733,6 +777,11 @@ def map_structure_up_to(shallow_tree, func, *inputs): shallow_tree. The function `func` is applied to corresponding partially flattened elements of each input, so the function must support arity of `len(inputs)`. + **kwargs: kwargs to feed to func(). Special kwarg + `check_types` is not passed to func, but instead determines whether the + types of iterables within the structures have to be same (e.g. + `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow + this set this argument to `False`. Raises: TypeError: If `shallow_tree` is a sequence but `input_tree` is not. @@ -745,16 +794,93 @@ def map_structure_up_to(shallow_tree, func, *inputs): result of repeatedly applying `func`, with same structure as `shallow_tree`. """ + return map_structure_with_tuple_paths_up_to( + shallow_tree, + lambda _, *values: func(*values), # Discards the path arg. + *inputs, + **kwargs) + + +def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs): + """Applies a function or op to a number of partially flattened inputs. + + Like map_structure_up_to(), except that the 'func' argument takes a path + tuple as its first argument, followed by the corresponding values from + *inputs. + + Example: + + lowercase = {'a': 'a', 'b': ('b0', 'b1')} + uppercase = {'a': 'A', 'b': ('B0', 'B1')} + + def print_path_and_values(path, *values): + print("path: {}, values: {}".format(path, values)) + + shallow_tree = {'a': None} + map_structure_with_tuple_paths_up_to(shallow_tree, + print_path_and_values, + lowercase, + uppercase) + >>> path: ('a',), values: ('a', 'A') + >>> path: ('b', 0), values: ('b0', 'B0') + >>> path: ('b', 1), values: ('b1', 'B1') + + shallow_tree = {'b': None} + map_structure_with_tuple_paths_up_to(shallow_tree, + print_path_and_values, + lowercase, + uppercase, + check_types=False) + >>> path: ('b', 1), values: (('bo', 'b1'), ('B0', 'B1')) + + shallow_tree = {'a': None, 'b': {1: None}} + map_structure_with_tuple_paths_up_to(shallow_tree, + print_path_and_values, + lowercase, + uppercase, + check_types=False) + >>> path: ('a',), values: ('a', 'A') + >>> path: ('b', 1), values: ('b1', B1') + + Args: + shallow_tree: a shallow tree, common to all the inputs. + func: callable that takes args (path, inputs_0_value, ... , inputs_N_value), + where path is a tuple path to a leaf node in shallow_tree, and + inputs_i_value is the corresponding value from inputs[i]. + *inputs: nested structures that are all structurally compatible with + shallow_tree. + **kwargs: kwargs to feed to func(). Special kwarg + `check_types` is not passed to func, but instead determines whether the + types of iterables within the structures have to be same (e.g. + `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow + this set this argument to `False`. + + Raises: + TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not. + TypeError: If the sequence types of `shallow_tree` are different from + `input_tree`. + ValueError: If the sequence lengths of `shallow_tree` are different from + `input_tree`. + + Returns: + Result of repeatedly applying `func`. Has same structure as `shallow_tree`. + """ if not inputs: raise ValueError("Cannot map over no sequences") + + check_types = kwargs.pop("check_types", True) + for input_tree in inputs: - assert_shallow_structure(shallow_tree, input_tree) + assert_shallow_structure(shallow_tree, input_tree, check_types=check_types) # Flatten each input separately, apply the function to corresponding elements, # then repack based on the structure of the first input. - all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) - for input_tree in inputs] - results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] + flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types) + for input_tree in inputs] + flat_path_list = [path for path, _ + in _yield_flat_up_to(shallow_tree, inputs[0])] + results = [func(*args, **kwargs) for args in zip(flat_path_list, + *flat_value_lists)] return pack_sequence_as(structure=shallow_tree, flat_sequence=results) @@ -853,27 +979,8 @@ def yield_flat_paths(nest): Tuples containing index or key values which form the path to a specific leaf value in the nested structure. """ - - # The _maybe_add_final_path_element function is used below in order to avoid - # adding trailing slashes when the sub-element recursed into is a leaf. - if isinstance(nest, (dict, _collections.Mapping)): - for key in _sorted(nest): - value = nest[key] - for sub_path in yield_flat_paths(value): - yield (key,) + sub_path - elif _is_namedtuple(nest): - for key in nest._fields: - value = getattr(nest, key) - for sub_path in yield_flat_paths(value): - yield (key,) + sub_path - elif isinstance(nest, _six.string_types): - yield () - elif isinstance(nest, _collections.Sequence): - for idx, value in enumerate(nest): - for sub_path in yield_flat_paths(value): - yield (idx,) + sub_path - else: - yield () + for k, _ in _yield_flat_up_to(nest, nest): + yield k def flatten_with_joined_string_paths(structure, separator="/"): diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 83fa5dd6608..71034ffcb6b 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -510,30 +510,28 @@ class NestTest(parameterized.TestCase, test.TestCase): def testAssertShallowStructure(self): inp_ab = ["a", "b"] inp_abc = ["a", "b", "c"] - expected_message = ( - "The two structures don't have the same sequence length. Input " - "structure has length 2, while shallow structure has length 3.") - with self.assertRaisesRegexp(ValueError, expected_message): - nest.assert_shallow_structure(inp_abc, inp_ab) + with self.assertRaisesWithLiteralMatch( + ValueError, + nest._INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( + shallow_size=len(inp_abc), + input_size=len(inp_ab))): + nest.assert_shallow_structure(shallow_tree=inp_abc, input_tree=inp_ab) inp_ab1 = [(1, 1), (2, 2)] inp_ab2 = [[1, 1], [2, 2]] - expected_message = ( - "The two structures don't have the same sequence type. Input structure " - "has type <(type|class) 'tuple'>, while shallow structure has type " - "<(type|class) 'list'>.") - with self.assertRaisesRegexp(TypeError, expected_message): + with self.assertRaisesWithLiteralMatch( + TypeError, + nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format( + shallow_type=type(inp_ab2[0]), + input_type=type(inp_ab1[0]))): nest.assert_shallow_structure(inp_ab2, inp_ab1) nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} - expected_message = ( - r"The two structures don't have the same keys. Input " - r"structure has keys \['c'\], while shallow structure has " - r"keys \['d'\].") - - with self.assertRaisesRegexp(ValueError, expected_message): + with self.assertRaisesWithLiteralMatch( + ValueError, + nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])): nest.assert_shallow_structure(inp_ab2, inp_ab1) inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) @@ -719,7 +717,9 @@ class NestTest(parameterized.TestCase, test.TestCase): # Non-equal dicts. inp_val = dict(a=2, b=3) inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) - with self.assertRaisesRegexp(ValueError, "same keys"): + with self.assertRaisesWithLiteralMatch( + ValueError, + nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])): nest.map_structure_up_to( inp_val, lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) @@ -736,7 +736,9 @@ class NestTest(parameterized.TestCase, test.TestCase): # Non-equal dict/mapping. inp_val = dict(a=2, b=3) inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) - with self.assertRaisesRegexp(ValueError, "same keys"): + with self.assertRaisesWithLiteralMatch( + ValueError, + nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])): nest.map_structure_up_to( inp_val, lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) @@ -849,12 +851,12 @@ class NestTest(parameterized.TestCase, test.TestCase): self.assertEqual(expected, result) @parameterized.named_parameters( - ("tuples", (1, 2), (3, 4, 5), ValueError), + ("tuples", (1, 2, 3), (4, 5), ValueError), ("dicts", {"a": 1}, {"b": 2}, ValueError), ("mixed", (1, 2), [3, 4], TypeError), ("nested", - {"a": [2, 3], "b": [1, 3]}, - {"b": [5, 6, 7], "a": [8, 9]}, + {"a": [2, 3, 4], "b": [1, 3]}, + {"b": [5, 6], "a": [8, 9]}, ValueError )) def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): @@ -884,13 +886,14 @@ class NestTest(parameterized.TestCase, test.TestCase): self.assertEqual(expected, result) @parameterized.named_parameters([ - dict(testcase_name="Tuples", s1=(1, 2), s2=(3, 4, 5), + dict(testcase_name="Tuples", s1=(1, 2, 3), s2=(4, 5), error_type=ValueError), dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2}, error_type=ValueError), dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], error_type=TypeError), dict(testcase_name="Nested", - s1={"a": [2, 3], "b": [1, 3]}, s2={"b": [5, 6, 7], "a": [8, 9]}, + s1={"a": [2, 3, 4], "b": [1, 3]}, + s2={"b": [5, 6], "a": [8, 9]}, error_type=ValueError) ]) def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):