From dc20e15f62d6d9ac0819410a41d9db3402505a3b Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 5 Apr 2019 13:54:25 -0700 Subject: [PATCH] Add expand_composites argument to all nest.* methods. PiperOrigin-RevId: 242187912 --- .../python/framework/composite_tensor_test.py | 217 ++++++++++++++++-- tensorflow/python/util/nest.py | 108 ++++++--- 2 files changed, 275 insertions(+), 50 deletions(-) diff --git a/tensorflow/python/framework/composite_tensor_test.py b/tensorflow/python/framework/composite_tensor_test.py index f249faa5d68..65518bf650f 100644 --- a/tensorflow/python/framework/composite_tensor_test.py +++ b/tensorflow/python/framework/composite_tensor_test.py @@ -25,6 +25,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.util import nest +@test_util.run_all_in_graph_and_eager_modes class TestCompositeTensor(composite_tensor.CompositeTensor): def __init__(self, *components): @@ -41,12 +42,19 @@ class TestCompositeTensor(composite_tensor.CompositeTensor): raise NotImplementedError('CompositeTensor._shape_invariant_to_components') def _is_graph_tensor(self): - return True + return False + + def __repr__(self): + return 'TestCompositeTensor%r' % (self._components,) + + def __eq__(self, other): + return (isinstance(other, TestCompositeTensor) and + self._components == other._components) class CompositeTensorTest(test_util.TensorFlowTestCase): - def assertNestEqual(self, a, b, expand_composites=False): + def assertNestEqual(self, a, b): if isinstance(a, dict): self.assertIsInstance(b, dict) self.assertEqual(set(a), set(b)) @@ -57,36 +65,35 @@ class CompositeTensorTest(test_util.TensorFlowTestCase): self.assertEqual(len(a), len(b)) for a_val, b_val in zip(a, b): self.assertNestEqual(a_val, b_val) - elif expand_composites and isinstance(a, composite_tensor.CompositeTensor): + elif isinstance(a, composite_tensor.CompositeTensor): self.assertIsInstance(b, composite_tensor.CompositeTensor) - self.assertNestEqual(a._to_components(), - b._to_components()) + self.assertNestEqual(a._to_components(), b._to_components()) + else: + self.assertAllEqual(a, b) def testNestFlatten(self): st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10]) st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10]) structure = [[st1], 'foo', {'y': [st2]}] x = nest.flatten(structure, expand_composites=True) - self.assertEqual(len(x), 7) - self.assertIs(x[0], st1.indices) - self.assertIs(x[1], st1.values) - self.assertIs(x[2], st1.dense_shape) - self.assertEqual(x[3], 'foo') - self.assertIs(x[4], st2.indices) - self.assertIs(x[5], st2.values) - self.assertIs(x[6], st2.dense_shape) + self.assertNestEqual(x, [ + st1.indices, st1.values, st1.dense_shape, 'foo', st2.indices, + st2.values, st2.dense_shape + ]) def testNestPackSequenceAs(self): st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10]) st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10]) structure1 = [[st1], 'foo', {'y': [st2]}] - flat = [st2.indices, st2.values, st2.dense_shape, 'bar', - st1.indices, st1.values, st1.dense_shape] + flat = [ + st2.indices, st2.values, st2.dense_shape, 'bar', st1.indices, + st1.values, st1.dense_shape + ] result = nest.pack_sequence_as(structure1, flat, expand_composites=True) expected = [[st2], 'bar', {'y': [st1]}] self.assertNestEqual(expected, result) - def testAssertSameStructure(self): + def testNestAssertSameStructure(self): st1 = sparse_tensor.SparseTensor([[0]], [0], [100]) st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100]) test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape) @@ -96,6 +103,184 @@ class CompositeTensorTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): nest.assert_same_structure(st1, test, expand_composites=True) + def testNestMapStructure(self): + structure = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + + def func(x): + return x + 10 + + result = nest.map_structure(func, structure, expand_composites=True) + expected = [[TestCompositeTensor(11, 12, 13)], 110, { + 'y': TestCompositeTensor(TestCompositeTensor(14, 15), 16) + }] + self.assertEqual(result, expected) + + def testNestMapStructureWithPaths(self): + structure = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + + def func(path, x): + return '%s:%s' % (path, x) + + result = nest.map_structure_with_paths( + func, structure, expand_composites=True) + expected = [[TestCompositeTensor('0/0/0:1', '0/0/1:2', '0/0/2:3')], '1:100', + { + 'y': + TestCompositeTensor( + TestCompositeTensor('2/y/0/0:4', '2/y/0/1:5'), + '2/y/1:6') + }] + self.assertEqual(result, expected) + + def testNestMapStructureWithTuplePaths(self): + structure = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + + def func(path, x): + return (path, x) + + result = nest.map_structure_with_tuple_paths( + func, structure, expand_composites=True) + expected = [[ + TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3)) + ], ((1,), 100), { + 'y': + TestCompositeTensor( + TestCompositeTensor(((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5)), + ((2, 'y', 1), 6)) + }] + self.assertEqual(result, expected) + + def testNestAssertShallowStructure(self): + s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}] + s2 = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + nest.assert_shallow_structure(s1, s2, expand_composites=False) + nest.assert_shallow_structure(s1, s2, expand_composites=True) + nest.assert_shallow_structure(s2, s1, expand_composites=False) + with self.assertRaises(TypeError): + nest.assert_shallow_structure(s2, s1, expand_composites=True) + + def testNestFlattenUpTo(self): + s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}] + s2 = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + result1 = nest.flatten_up_to(s1, s2, expand_composites=True) + expected1 = [1, 2, 3, 100, TestCompositeTensor(4, 5), 6] + self.assertEqual(result1, expected1) + + result2 = nest.flatten_up_to(s1, s2, expand_composites=False) + expected2 = [ + TestCompositeTensor(1, 2, 3), 100, + TestCompositeTensor(TestCompositeTensor(4, 5), 6) + ] + self.assertEqual(result2, expected2) + + def testNestFlattenWithTuplePathsUpTo(self): + s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}] + s2 = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + result1 = nest.flatten_with_tuple_paths_up_to( + s1, s2, expand_composites=True) + expected1 = [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3), ((1,), 100), + ((2, 'y', 0), TestCompositeTensor(4, 5)), ((2, 'y', 1), 6)] + self.assertEqual(result1, expected1) + + result2 = nest.flatten_with_tuple_paths_up_to( + s1, s2, expand_composites=False) + expected2 = [((0, 0), TestCompositeTensor(1, 2, 3)), ((1,), 100), + ((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6))] + self.assertEqual(result2, expected2) + + def testNestMapStructureUpTo(self): + s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}] + s2 = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + + def func(x): + return x + 10 if isinstance(x, int) else x + + result = nest.map_structure_up_to(s1, func, s2, expand_composites=True) + expected = [[TestCompositeTensor(11, 12, 13)], 110, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16) + }] + self.assertEqual(result, expected) + + def testNestMapStructureWithTuplePathsUpTo(self): + s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}] + s2 = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + + def func(path, x): + return (path, x) + + result = nest.map_structure_with_tuple_paths_up_to( + s1, func, s2, expand_composites=True) + expected = [[ + TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3)) + ], ((1,), 100), { + 'y': + TestCompositeTensor(((2, 'y', 0), TestCompositeTensor(4, 5)), + ((2, 'y', 1), 6)) + }] + self.assertEqual(result, expected) + + def testNestGetTraverseShallowStructure(self): + pass + + def testNestYieldFlatPaths(self): + structure = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + result1 = list(nest.yield_flat_paths(structure, expand_composites=True)) + expected1 = [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1,), (2, 'y', 0, 0), + (2, 'y', 0, 1), (2, 'y', 1)] + self.assertEqual(result1, expected1) + + result2 = list(nest.yield_flat_paths(structure, expand_composites=False)) + expected2 = [(0, 0), (1,), (2, 'y')] + self.assertEqual(result2, expected2) + + def testNestFlattenWithJoinedStringPaths(self): + structure = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + result1 = nest.flatten_with_joined_string_paths( + structure, expand_composites=True) + expected1 = [('0/0/0', 1), ('0/0/1', 2), ('0/0/2', 3), ('1', 100), + ('2/y/0/0', 4), ('2/y/0/1', 5), ('2/y/1', 6)] + self.assertEqual(result1, expected1) + + result2 = nest.flatten_with_joined_string_paths( + structure, expand_composites=False) + expected2 = [('0/0', TestCompositeTensor(1, 2, 3)), ('1', 100), + ('2/y', TestCompositeTensor(TestCompositeTensor(4, 5), 6))] + self.assertEqual(result2, expected2) + + def testNestFlattenWithTuplePaths(self): + structure = [[TestCompositeTensor(1, 2, 3)], 100, { + 'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6) + }] + result1 = nest.flatten_with_tuple_paths(structure, expand_composites=True) + expected1 = [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3), ((1,), 100), + ((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5), ((2, 'y', 1), 6)] + self.assertEqual(result1, expected1) + + result2 = nest.flatten_with_tuple_paths(structure, expand_composites=False) + expected2 = [((0, 0), TestCompositeTensor(1, 2, 3)), ((1,), 100), + ((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6))] + self.assertEqual(result2, expected2) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index d22f02f1a2c..628e2d6bb4f 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -588,13 +588,14 @@ def map_structure_with_tuple_paths(func, *structure, **kwargs): **kwargs) -def _yield_flat_up_to(shallow_tree, input_tree, path=()): +def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()): """Yields (path, value) pairs of input_tree flattened up to shallow_tree. Args: shallow_tree: Nested structure. Traverse no further than its leaf nodes. input_tree: Nested structure. Return the paths and values from this tree. Must have the same upper structure as shallow_tree. + is_seq: Function used to test if a value should be treated as a sequence. path: Tuple. Optional argument, only used when recursing. The path from the root of the original shallow_tree, down to the root of the shallow_tree arg of this recursive call. @@ -604,11 +605,7 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()): shallow_tree, and value is the value of the corresponding node in input_tree. """ - if (isinstance(shallow_tree, _six.string_types) or - not any([isinstance(shallow_tree, _collections.Sequence), - isinstance(shallow_tree, _collections.Mapping), - _is_namedtuple(shallow_tree), - _is_attrs(shallow_tree)])): + if not is_seq(shallow_tree): yield (path, input_tree) else: input_tree = dict(_yield_sorted_items(input_tree)) @@ -616,12 +613,13 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()): subpath = path + (shallow_key,) input_subtree = input_tree[shallow_key] for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, - input_subtree, + input_subtree, is_seq, path=subpath): yield (leaf_path, leaf_value) -def assert_shallow_structure(shallow_tree, input_tree, check_types=True): +def assert_shallow_structure(shallow_tree, input_tree, check_types=True, + expand_composites=False): """Asserts that `shallow_tree` is a shallow structure of `input_tree`. That is, this function tests if the `input_tree` structure can be created from @@ -651,6 +649,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): `input_tree` have to be the same. Note that even with check_types==True, this function will consider two different namedtuple classes with the same name and _fields attribute to be the same class. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Raises: TypeError: If `shallow_tree` is a sequence but `input_tree` is not. @@ -659,8 +659,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): ValueError: If the sequence lengths of `shallow_tree` are different from `input_tree`. """ - if is_sequence(shallow_tree): - if not is_sequence(input_tree): + is_seq = is_sequence_or_composite if expand_composites else is_sequence + if is_seq(shallow_tree): + if not is_seq(input_tree): raise TypeError( "If shallow structure is a sequence, input must also be a sequence. " "Input has type: %s." % type(input_tree)) @@ -682,6 +683,11 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): input_type=type(input_tree), shallow_type=type(shallow_tree))) + while _is_composite_tensor(shallow_tree): + shallow_tree = shallow_tree._to_components() # pylint: disable=protected-access + while _is_composite_tensor(input_tree): + input_tree = input_tree._to_components() # pylint: disable=protected-access + if len(input_tree) < len(shallow_tree): raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( input_size=len(input_tree), @@ -696,10 +702,12 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): for shallow_branch, input_branch in zip(_yield_value(shallow_tree), _yield_value(input_tree)): assert_shallow_structure(shallow_branch, input_branch, - check_types=check_types) + check_types=check_types, + expand_composites=expand_composites) -def flatten_up_to(shallow_tree, input_tree, check_types=True): +def flatten_up_to(shallow_tree, input_tree, check_types=True, + expand_composites=False): """Flattens `input_tree` up to `shallow_tree`. Any further depth in structure in `input_tree` is retained as elements in the @@ -758,6 +766,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True): Note, numpy arrays are considered scalars. check_types: bool. If True, check that each node in shallow_tree has the same type as the corresponding node in input_tree. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Returns: A Python list, the partially flattened version of `input_tree` according to @@ -770,12 +780,15 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True): ValueError: If the sequence lengths of `shallow_tree` are different from `input_tree`. """ - assert_shallow_structure(shallow_tree, input_tree, check_types) + is_seq = is_sequence_or_composite if expand_composites else is_sequence + assert_shallow_structure(shallow_tree, input_tree, check_types=check_types, + expand_composites=expand_composites) # Discard paths returned by _yield_flat_up_to. - return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree)) + return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq)) -def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True): +def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True, + expand_composites=False): """Flattens `input_tree` up to `shallow_tree`. Any further depth in structure in `input_tree` is retained as elements in the @@ -853,6 +866,8 @@ def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True): Note, numpy arrays are considered scalars. check_types: bool. If True, check that each node in shallow_tree has the same type as the corresponding node in input_tree. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Returns: A Python list, the partially flattened version of `input_tree` according to @@ -865,8 +880,10 @@ def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True): ValueError: If the sequence lengths of `shallow_tree` are different from `input_tree`. """ - assert_shallow_structure(shallow_tree, input_tree, check_types=check_types) - return list(_yield_flat_up_to(shallow_tree, input_tree)) + is_seq = is_sequence_or_composite if expand_composites else is_sequence + assert_shallow_structure(shallow_tree, input_tree, check_types=check_types, + expand_composites=expand_composites) + return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq)) def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): @@ -1019,22 +1036,28 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs): raise ValueError("Cannot map over no sequences") check_types = kwargs.pop("check_types", True) + expand_composites = kwargs.pop("expand_composites", False) + is_seq = is_sequence_or_composite if expand_composites else is_sequence for input_tree in inputs: - assert_shallow_structure(shallow_tree, input_tree, check_types=check_types) + assert_shallow_structure(shallow_tree, input_tree, check_types=check_types, + expand_composites=expand_composites) # Flatten each input separately, apply the function to corresponding elements, # then repack based on the structure of the first input. - flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types) + flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types, + expand_composites=expand_composites) for input_tree in inputs] flat_path_list = [path for path, _ - in _yield_flat_up_to(shallow_tree, inputs[0])] + in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)] results = [func(*args, **kwargs) for args in zip(flat_path_list, *flat_value_lists)] - return pack_sequence_as(structure=shallow_tree, flat_sequence=results) + return pack_sequence_as(structure=shallow_tree, flat_sequence=results, + expand_composites=expand_composites) -def get_traverse_shallow_structure(traverse_fn, structure): +def get_traverse_shallow_structure(traverse_fn, structure, + expand_composites=False): """Generates a shallow structure from a `traverse_fn` and `structure`. `traverse_fn` must accept any possible subtree of `structure` and return @@ -1050,6 +1073,8 @@ def get_traverse_shallow_structure(traverse_fn, structure): shallow structure of the same type, describing which parts of the substructure to traverse. structure: The structure to traverse. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Returns: A shallow structure containing python bools, which can be passed to @@ -1061,8 +1086,9 @@ def get_traverse_shallow_structure(traverse_fn, structure): or if any leaf values in the returned structure or scalar are not type `bool`. """ + is_seq = is_sequence_or_composite if expand_composites else is_sequence to_traverse = traverse_fn(structure) - if not is_sequence(structure): + if not is_seq(structure): if not isinstance(to_traverse, bool): raise TypeError("traverse_fn returned structure: %s for non-structure: %s" % (to_traverse, structure)) @@ -1076,14 +1102,17 @@ def get_traverse_shallow_structure(traverse_fn, structure): # Traverse the entire substructure. for branch in _yield_value(structure): level_traverse.append( - get_traverse_shallow_structure(traverse_fn, branch)) - elif not is_sequence(to_traverse): + get_traverse_shallow_structure(traverse_fn, branch, + expand_composites=expand_composites)) + elif not is_seq(to_traverse): raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" % (to_traverse, structure)) else: # Traverse some subset of this substructure. - assert_shallow_structure(to_traverse, structure) - for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): + assert_shallow_structure(to_traverse, structure, + expand_composites=expand_composites) + for t, branch in zip(_yield_value(to_traverse), + _yield_value(structure)): if not isinstance(t, bool): raise TypeError( "traverse_fn didn't return a depth=1 structure of bools. saw: %s " @@ -1096,7 +1125,7 @@ def get_traverse_shallow_structure(traverse_fn, structure): return _sequence_like(structure, level_traverse) -def yield_flat_paths(nest): +def yield_flat_paths(nest, expand_composites=False): """Yields paths for some nested structure. Paths are lists of objects which can be str-converted, which may include @@ -1124,16 +1153,20 @@ def yield_flat_paths(nest): Args: nest: the value to produce a flattened paths list for. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Yields: Tuples containing index or key values which form the path to a specific leaf value in the nested structure. """ - for k, _ in _yield_flat_up_to(nest, nest): + is_seq = is_sequence_or_composite if expand_composites else is_sequence + for k, _ in _yield_flat_up_to(nest, nest, is_seq): yield k -def flatten_with_joined_string_paths(structure, separator="/"): +def flatten_with_joined_string_paths(structure, separator="/", + expand_composites=False): """Returns a list of (string path, data element) tuples. The order of tuples produced matches that of `nest.flatten`. This allows you @@ -1145,18 +1178,21 @@ def flatten_with_joined_string_paths(structure, separator="/"): structure: the nested structure to flatten. separator: string to separate levels of hierarchy in the results, defaults to '/'. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Returns: A list of (string, data element) tuples. """ - flat_paths = yield_flat_paths(structure) + flat_paths = yield_flat_paths(structure, expand_composites=expand_composites) def stringify_and_join(path_elements): return separator.join(str(path_element) for path_element in path_elements) flat_string_paths = [stringify_and_join(path) for path in flat_paths] - return list(zip(flat_string_paths, flatten(structure))) + return list(zip(flat_string_paths, + flatten(structure, expand_composites=expand_composites))) -def flatten_with_tuple_paths(structure): +def flatten_with_tuple_paths(structure, expand_composites=False): """Returns a list of `(tuple_path, leaf_element)` tuples. The order of pairs produced matches that of `nest.flatten`. This allows you @@ -1166,13 +1202,17 @@ def flatten_with_tuple_paths(structure): Args: structure: the nested structure to flatten. + expand_composites: If true, then composite tensors such as tf.SparseTensor + and tf.RaggedTensor are expanded into their component tensors. Returns: A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple of indices and/or dictionary keys that uniquely specify the path to `leaf_element` within `structure`. """ - return list(zip(yield_flat_paths(structure), flatten(structure))) + return list(zip(yield_flat_paths(structure, + expand_composites=expand_composites), + flatten(structure, expand_composites=expand_composites))) _pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)