From e1836047555f19ea371a1e4b46101cf535c4a68b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Feb 2019 07:41:22 -0800 Subject: [PATCH] Add flatten_with_tuple_paths_up_to(), the 'flatten' version of map_structure_with_tuple_paths_up_to(). PiperOrigin-RevId: 232673116 --- tensorflow/contrib/framework/__init__.py | 1 + tensorflow/python/util/nest.py | 94 +++++++++ tensorflow/python/util/nest_test.py | 238 +++++++++++++++++++++++ 3 files changed, 333 insertions(+) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 3784631dcbf..fc2334d5d7f 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -139,6 +139,7 @@ _nest_allowed_symbols = [ 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', + 'flatten_with_tuple_paths_up_to', 'map_structure_up_to', 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index a43ec48589f..24356946317 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -761,6 +761,100 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True): return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree)) +def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True): + """Flattens `input_tree` up to `shallow_tree`. + + Any further depth in structure in `input_tree` is retained as elements in the + partially flattened output. + + Returns a list of (path, value) pairs, where value a leaf node in the + flattened tree, and path is the tuple path of that leaf in input_tree. + + If `shallow_tree` and `input_tree` are not sequences, this returns a + single-element list: `[((), input_tree)]`. + + Use Case: + + Sometimes we may wish to partially flatten a nested sequence, retaining some + of the nested structure. We achieve this by specifying a shallow structure, + `shallow_tree`, we wish to flatten up to. + + The input, `input_tree`, can be thought of as having the same structure as + `shallow_tree`, but with leaf nodes that are themselves tree structures. + + Examples: + + ```python + input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + shallow_tree = [[True, True], [False, True]] + + flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree, + input_tree) + flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree, + shallow_tree) + + # Output is: + # [((0, 0), [2, 2]), + # ((0, 1), [3, 3]), + # ((1, 0), [4, 9]), + # ((1, 1), [5, 5])] + # + # [((0, 0), True), + # ((0, 1), True), + # ((1, 0), False), + # ((1, 1), True)] + ``` + + ```python + input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] + shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] + + input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) + input_tree_flattened = flatten(input_tree) + + # Output is: + # [((0, 0), ('a', 1)), + # ((0, 1, 0), ('b', 2)), + # ((0, 1, 1, 0), ('c', 3)), + # ((0, 1, 1, 1), ('d', 4))] + # ['a', 1, 'b', 2, 'c', 3, 'd', 4] + ``` + + Non-Sequence Edge Cases: + + ```python + flatten_with_tuple_paths_up_to(0, 0) # Output: [(), 0] + + flatten_with_tuple_paths_up_to(0, [0, 1, 2]) # Output: [(), [0, 1, 2]] + + flatten_with_tuple_paths_up_to([0, 1, 2], 0) # Output: TypeError + + flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2]) + # Output: [((0,) 0), ((1,), 1), ((2,), 2)] + ``` + + Args: + shallow_tree: a possibly pruned structure of input_tree. + input_tree: an arbitrarily nested structure or a scalar object. + Note, numpy arrays are considered scalars. + check_types: bool. If True, check that each node in shallow_tree has the + same type as the corresponding node in input_tree. + + Returns: + A Python list, the partially flattened version of `input_tree` according to + the structure of `shallow_tree`. + + Raises: + TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + TypeError: If the sequence types of `shallow_tree` are different from + `input_tree`. + ValueError: If the sequence lengths of `shallow_tree` are different from + `input_tree`. + """ + assert_shallow_structure(shallow_tree, input_tree, check_types=check_types) + return list(_yield_flat_up_to(shallow_tree, input_tree)) + + def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): """Applies a function or op to a number of partially flattened inputs. diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 71034ffcb6b..ec559bd2abd 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -686,6 +686,244 @@ class NestTest(parameterized.TestCase, test.TestCase): flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree, shallow_tree) + def testFlattenWithTuplePathsUpTo(self): + def get_paths_and_values(shallow_tree, input_tree): + path_value_pairs = nest.flatten_with_tuple_paths_up_to(shallow_tree, + input_tree) + paths = [p for p, _ in path_value_pairs] + values = [v for _, v in path_value_pairs] + return paths, values + + # Shallow tree ends at scalar. + input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + shallow_tree = [[True, True], [False, True]] + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, + [(0, 0), (0, 1), (1, 0), (1, 1)]) + self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) + self.assertEqual(flattened_shallow_tree_paths, + [(0, 0), (0, 1), (1, 0), (1, 1)]) + self.assertEqual(flattened_shallow_tree, [True, True, False, True]) + + # Shallow tree ends at string. + input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] + shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] + (input_tree_flattened_as_shallow_tree_paths, + input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, + input_tree) + input_tree_flattened_paths = [p for p, _ in + nest.flatten_with_tuple_paths(input_tree)] + input_tree_flattened = nest.flatten(input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree_paths, + [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)]) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) + + self.assertEqual(input_tree_flattened_paths, + [(0, 0, 0), (0, 0, 1), + (0, 1, 0, 0), (0, 1, 0, 1), + (0, 1, 1, 0, 0), (0, 1, 1, 0, 1), + (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)]) + self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) + + # Make sure dicts are correctly flattened, yielding values, not keys. + input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} + shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} + (input_tree_flattened_as_shallow_tree_paths, + input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree_paths, + [("a",), ("b",), ("d", 0), ("d", 1)]) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [1, {"c": 2}, 3, (4, 5)]) + + # Namedtuples. + ab_tuple = collections.namedtuple("ab_tuple", "a, b") + input_tree = ab_tuple(a=[0, 1], b=2) + shallow_tree = ab_tuple(a=0, b=1) + (input_tree_flattened_as_shallow_tree_paths, + input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree_paths, + [("a",), ("b",)]) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [[0, 1], 2]) + + # Nested dicts, OrderedDicts and namedtuples. + input_tree = collections.OrderedDict( + [("a", ab_tuple(a=[0, {"b": 1}], b=2)), + ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) + shallow_tree = input_tree + (input_tree_flattened_as_shallow_tree_paths, + input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree_paths, + [("a", "a", 0), + ("a", "a", 1, "b"), + ("a", "b"), + ("c", "d"), + ("c", "e", "f")]) + self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) + shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) + (input_tree_flattened_as_shallow_tree_paths, + input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree_paths, + [("a",), + ("c", "d"), + ("c", "e")]) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [ab_tuple(a=[0, {"b": 1}], b=2), + 3, + collections.OrderedDict([("f", 4)])]) + shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) + (input_tree_flattened_as_shallow_tree_paths, + input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree_paths, + [("a",), ("c",)]) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [ab_tuple(a=[0, {"b": 1}], b=2), + {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) + + ## Shallow non-list edge-case. + # Using iterable elements. + input_tree = ["input_tree"] + shallow_tree = "shallow_tree" + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [()]) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree_paths, [()]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + input_tree = ["input_tree_0", "input_tree_1"] + shallow_tree = "shallow_tree" + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [()]) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree_paths, [()]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + # Test case where len(shallow_tree) < len(input_tree) + input_tree = {"a": "A", "b": "B", "c": "C"} + shallow_tree = {"a": 1, "c": 2} + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [("a",), ("c",)]) + self.assertEqual(flattened_input_tree, ["A", "C"]) + self.assertEqual(flattened_shallow_tree_paths, [("a",), ("c",)]) + self.assertEqual(flattened_shallow_tree, [1, 2]) + + # Using non-iterable elements. + input_tree = [0] + shallow_tree = 9 + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [()]) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree_paths, [()]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + input_tree = [0, 1] + shallow_tree = 9 + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [()]) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree_paths, [()]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + ## Both non-list edge-case. + # Using iterable elements. + input_tree = "input_tree" + shallow_tree = "shallow_tree" + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [()]) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree_paths, [()]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + # Using non-iterable elements. + input_tree = 0 + shallow_tree = 0 + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree_paths, [()]) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree_paths, [()]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + ## Input non-list edge-case. + # Using iterable elements. + input_tree = "input_tree" + shallow_tree = ["shallow_tree"] + with self.assertRaisesWithLiteralMatch( + TypeError, + nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree_paths, [(0,)]) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + input_tree = "input_tree" + shallow_tree = ["shallow_tree_9", "shallow_tree_8"] + with self.assertRaisesWithLiteralMatch( + TypeError, + nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)]) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + # Using non-iterable elements. + input_tree = 0 + shallow_tree = [9] + with self.assertRaisesWithLiteralMatch( + TypeError, + nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree_paths, [(0,)]) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + input_tree = 0 + shallow_tree = [9, 8] + with self.assertRaisesWithLiteralMatch( + TypeError, + nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): + (flattened_input_tree_paths, + flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) + (flattened_shallow_tree_paths, + flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)]) + self.assertEqual(flattened_shallow_tree, shallow_tree) + def testMapStructureUpTo(self): # Named tuples. ab_tuple = collections.namedtuple("ab_tuple", "a, b")