Add expand_composites argument to all nest.* methods.

PiperOrigin-RevId: 242187912
This commit is contained in:
Edward Loper 2019-04-05 13:54:25 -07:00 committed by TensorFlower Gardener
parent f8dfa51551
commit dc20e15f62
2 changed files with 275 additions and 50 deletions

View File

@ -25,6 +25,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.util import nest from tensorflow.python.util import nest
@test_util.run_all_in_graph_and_eager_modes
class TestCompositeTensor(composite_tensor.CompositeTensor): class TestCompositeTensor(composite_tensor.CompositeTensor):
def __init__(self, *components): def __init__(self, *components):
@ -41,12 +42,19 @@ class TestCompositeTensor(composite_tensor.CompositeTensor):
raise NotImplementedError('CompositeTensor._shape_invariant_to_components') raise NotImplementedError('CompositeTensor._shape_invariant_to_components')
def _is_graph_tensor(self): def _is_graph_tensor(self):
return True return False
def __repr__(self):
return 'TestCompositeTensor%r' % (self._components,)
def __eq__(self, other):
return (isinstance(other, TestCompositeTensor) and
self._components == other._components)
class CompositeTensorTest(test_util.TensorFlowTestCase): class CompositeTensorTest(test_util.TensorFlowTestCase):
def assertNestEqual(self, a, b, expand_composites=False): def assertNestEqual(self, a, b):
if isinstance(a, dict): if isinstance(a, dict):
self.assertIsInstance(b, dict) self.assertIsInstance(b, dict)
self.assertEqual(set(a), set(b)) self.assertEqual(set(a), set(b))
@ -57,36 +65,35 @@ class CompositeTensorTest(test_util.TensorFlowTestCase):
self.assertEqual(len(a), len(b)) self.assertEqual(len(a), len(b))
for a_val, b_val in zip(a, b): for a_val, b_val in zip(a, b):
self.assertNestEqual(a_val, b_val) self.assertNestEqual(a_val, b_val)
elif expand_composites and isinstance(a, composite_tensor.CompositeTensor): elif isinstance(a, composite_tensor.CompositeTensor):
self.assertIsInstance(b, composite_tensor.CompositeTensor) self.assertIsInstance(b, composite_tensor.CompositeTensor)
self.assertNestEqual(a._to_components(), self.assertNestEqual(a._to_components(), b._to_components())
b._to_components()) else:
self.assertAllEqual(a, b)
def testNestFlatten(self): def testNestFlatten(self):
st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10]) st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10])
st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10]) st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10])
structure = [[st1], 'foo', {'y': [st2]}] structure = [[st1], 'foo', {'y': [st2]}]
x = nest.flatten(structure, expand_composites=True) x = nest.flatten(structure, expand_composites=True)
self.assertEqual(len(x), 7) self.assertNestEqual(x, [
self.assertIs(x[0], st1.indices) st1.indices, st1.values, st1.dense_shape, 'foo', st2.indices,
self.assertIs(x[1], st1.values) st2.values, st2.dense_shape
self.assertIs(x[2], st1.dense_shape) ])
self.assertEqual(x[3], 'foo')
self.assertIs(x[4], st2.indices)
self.assertIs(x[5], st2.values)
self.assertIs(x[6], st2.dense_shape)
def testNestPackSequenceAs(self): def testNestPackSequenceAs(self):
st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10]) st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10])
st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10]) st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10])
structure1 = [[st1], 'foo', {'y': [st2]}] structure1 = [[st1], 'foo', {'y': [st2]}]
flat = [st2.indices, st2.values, st2.dense_shape, 'bar', flat = [
st1.indices, st1.values, st1.dense_shape] st2.indices, st2.values, st2.dense_shape, 'bar', st1.indices,
st1.values, st1.dense_shape
]
result = nest.pack_sequence_as(structure1, flat, expand_composites=True) result = nest.pack_sequence_as(structure1, flat, expand_composites=True)
expected = [[st2], 'bar', {'y': [st1]}] expected = [[st2], 'bar', {'y': [st1]}]
self.assertNestEqual(expected, result) self.assertNestEqual(expected, result)
def testAssertSameStructure(self): def testNestAssertSameStructure(self):
st1 = sparse_tensor.SparseTensor([[0]], [0], [100]) st1 = sparse_tensor.SparseTensor([[0]], [0], [100])
st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100]) st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100])
test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape) test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape)
@ -96,6 +103,184 @@ class CompositeTensorTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
nest.assert_same_structure(st1, test, expand_composites=True) nest.assert_same_structure(st1, test, expand_composites=True)
def testNestMapStructure(self):
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
def func(x):
return x + 10
result = nest.map_structure(func, structure, expand_composites=True)
expected = [[TestCompositeTensor(11, 12, 13)], 110, {
'y': TestCompositeTensor(TestCompositeTensor(14, 15), 16)
}]
self.assertEqual(result, expected)
def testNestMapStructureWithPaths(self):
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
def func(path, x):
return '%s:%s' % (path, x)
result = nest.map_structure_with_paths(
func, structure, expand_composites=True)
expected = [[TestCompositeTensor('0/0/0:1', '0/0/1:2', '0/0/2:3')], '1:100',
{
'y':
TestCompositeTensor(
TestCompositeTensor('2/y/0/0:4', '2/y/0/1:5'),
'2/y/1:6')
}]
self.assertEqual(result, expected)
def testNestMapStructureWithTuplePaths(self):
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
def func(path, x):
return (path, x)
result = nest.map_structure_with_tuple_paths(
func, structure, expand_composites=True)
expected = [[
TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
], ((1,), 100), {
'y':
TestCompositeTensor(
TestCompositeTensor(((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5)),
((2, 'y', 1), 6))
}]
self.assertEqual(result, expected)
def testNestAssertShallowStructure(self):
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
nest.assert_shallow_structure(s1, s2, expand_composites=False)
nest.assert_shallow_structure(s1, s2, expand_composites=True)
nest.assert_shallow_structure(s2, s1, expand_composites=False)
with self.assertRaises(TypeError):
nest.assert_shallow_structure(s2, s1, expand_composites=True)
def testNestFlattenUpTo(self):
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
result1 = nest.flatten_up_to(s1, s2, expand_composites=True)
expected1 = [1, 2, 3, 100, TestCompositeTensor(4, 5), 6]
self.assertEqual(result1, expected1)
result2 = nest.flatten_up_to(s1, s2, expand_composites=False)
expected2 = [
TestCompositeTensor(1, 2, 3), 100,
TestCompositeTensor(TestCompositeTensor(4, 5), 6)
]
self.assertEqual(result2, expected2)
def testNestFlattenWithTuplePathsUpTo(self):
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
result1 = nest.flatten_with_tuple_paths_up_to(
s1, s2, expand_composites=True)
expected1 = [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3), ((1,), 100),
((2, 'y', 0), TestCompositeTensor(4, 5)), ((2, 'y', 1), 6)]
self.assertEqual(result1, expected1)
result2 = nest.flatten_with_tuple_paths_up_to(
s1, s2, expand_composites=False)
expected2 = [((0, 0), TestCompositeTensor(1, 2, 3)), ((1,), 100),
((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
self.assertEqual(result2, expected2)
def testNestMapStructureUpTo(self):
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
def func(x):
return x + 10 if isinstance(x, int) else x
result = nest.map_structure_up_to(s1, func, s2, expand_composites=True)
expected = [[TestCompositeTensor(11, 12, 13)], 110, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16)
}]
self.assertEqual(result, expected)
def testNestMapStructureWithTuplePathsUpTo(self):
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
def func(path, x):
return (path, x)
result = nest.map_structure_with_tuple_paths_up_to(
s1, func, s2, expand_composites=True)
expected = [[
TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
], ((1,), 100), {
'y':
TestCompositeTensor(((2, 'y', 0), TestCompositeTensor(4, 5)),
((2, 'y', 1), 6))
}]
self.assertEqual(result, expected)
def testNestGetTraverseShallowStructure(self):
pass
def testNestYieldFlatPaths(self):
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
result1 = list(nest.yield_flat_paths(structure, expand_composites=True))
expected1 = [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1,), (2, 'y', 0, 0),
(2, 'y', 0, 1), (2, 'y', 1)]
self.assertEqual(result1, expected1)
result2 = list(nest.yield_flat_paths(structure, expand_composites=False))
expected2 = [(0, 0), (1,), (2, 'y')]
self.assertEqual(result2, expected2)
def testNestFlattenWithJoinedStringPaths(self):
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
result1 = nest.flatten_with_joined_string_paths(
structure, expand_composites=True)
expected1 = [('0/0/0', 1), ('0/0/1', 2), ('0/0/2', 3), ('1', 100),
('2/y/0/0', 4), ('2/y/0/1', 5), ('2/y/1', 6)]
self.assertEqual(result1, expected1)
result2 = nest.flatten_with_joined_string_paths(
structure, expand_composites=False)
expected2 = [('0/0', TestCompositeTensor(1, 2, 3)), ('1', 100),
('2/y', TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
self.assertEqual(result2, expected2)
def testNestFlattenWithTuplePaths(self):
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
}]
result1 = nest.flatten_with_tuple_paths(structure, expand_composites=True)
expected1 = [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3), ((1,), 100),
((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5), ((2, 'y', 1), 6)]
self.assertEqual(result1, expected1)
result2 = nest.flatten_with_tuple_paths(structure, expand_composites=False)
expected2 = [((0, 0), TestCompositeTensor(1, 2, 3)), ((1,), 100),
((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
self.assertEqual(result2, expected2)
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() googletest.main()

View File

@ -588,13 +588,14 @@ def map_structure_with_tuple_paths(func, *structure, **kwargs):
**kwargs) **kwargs)
def _yield_flat_up_to(shallow_tree, input_tree, path=()): def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree. """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
Args: Args:
shallow_tree: Nested structure. Traverse no further than its leaf nodes. shallow_tree: Nested structure. Traverse no further than its leaf nodes.
input_tree: Nested structure. Return the paths and values from this tree. input_tree: Nested structure. Return the paths and values from this tree.
Must have the same upper structure as shallow_tree. Must have the same upper structure as shallow_tree.
is_seq: Function used to test if a value should be treated as a sequence.
path: Tuple. Optional argument, only used when recursing. The path from the path: Tuple. Optional argument, only used when recursing. The path from the
root of the original shallow_tree, down to the root of the shallow_tree root of the original shallow_tree, down to the root of the shallow_tree
arg of this recursive call. arg of this recursive call.
@ -604,11 +605,7 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
shallow_tree, and value is the value of the corresponding node in shallow_tree, and value is the value of the corresponding node in
input_tree. input_tree.
""" """
if (isinstance(shallow_tree, _six.string_types) or if not is_seq(shallow_tree):
not any([isinstance(shallow_tree, _collections.Sequence),
isinstance(shallow_tree, _collections.Mapping),
_is_namedtuple(shallow_tree),
_is_attrs(shallow_tree)])):
yield (path, input_tree) yield (path, input_tree)
else: else:
input_tree = dict(_yield_sorted_items(input_tree)) input_tree = dict(_yield_sorted_items(input_tree))
@ -616,12 +613,13 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
subpath = path + (shallow_key,) subpath = path + (shallow_key,)
input_subtree = input_tree[shallow_key] input_subtree = input_tree[shallow_key]
for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
input_subtree, input_subtree, is_seq,
path=subpath): path=subpath):
yield (leaf_path, leaf_value) yield (leaf_path, leaf_value)
def assert_shallow_structure(shallow_tree, input_tree, check_types=True): def assert_shallow_structure(shallow_tree, input_tree, check_types=True,
expand_composites=False):
"""Asserts that `shallow_tree` is a shallow structure of `input_tree`. """Asserts that `shallow_tree` is a shallow structure of `input_tree`.
That is, this function tests if the `input_tree` structure can be created from That is, this function tests if the `input_tree` structure can be created from
@ -651,6 +649,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
`input_tree` have to be the same. Note that even with check_types==True, `input_tree` have to be the same. Note that even with check_types==True,
this function will consider two different namedtuple classes with the same this function will consider two different namedtuple classes with the same
name and _fields attribute to be the same class. name and _fields attribute to be the same class.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Raises: Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not. TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
@ -659,8 +659,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
ValueError: If the sequence lengths of `shallow_tree` are different from ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`. `input_tree`.
""" """
if is_sequence(shallow_tree): is_seq = is_sequence_or_composite if expand_composites else is_sequence
if not is_sequence(input_tree): if is_seq(shallow_tree):
if not is_seq(input_tree):
raise TypeError( raise TypeError(
"If shallow structure is a sequence, input must also be a sequence. " "If shallow structure is a sequence, input must also be a sequence. "
"Input has type: %s." % type(input_tree)) "Input has type: %s." % type(input_tree))
@ -682,6 +683,11 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
input_type=type(input_tree), input_type=type(input_tree),
shallow_type=type(shallow_tree))) shallow_type=type(shallow_tree)))
while _is_composite_tensor(shallow_tree):
shallow_tree = shallow_tree._to_components() # pylint: disable=protected-access
while _is_composite_tensor(input_tree):
input_tree = input_tree._to_components() # pylint: disable=protected-access
if len(input_tree) < len(shallow_tree): if len(input_tree) < len(shallow_tree):
raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
input_size=len(input_tree), input_size=len(input_tree),
@ -696,10 +702,12 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
for shallow_branch, input_branch in zip(_yield_value(shallow_tree), for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
_yield_value(input_tree)): _yield_value(input_tree)):
assert_shallow_structure(shallow_branch, input_branch, assert_shallow_structure(shallow_branch, input_branch,
check_types=check_types) check_types=check_types,
expand_composites=expand_composites)
def flatten_up_to(shallow_tree, input_tree, check_types=True): def flatten_up_to(shallow_tree, input_tree, check_types=True,
expand_composites=False):
"""Flattens `input_tree` up to `shallow_tree`. """Flattens `input_tree` up to `shallow_tree`.
Any further depth in structure in `input_tree` is retained as elements in the Any further depth in structure in `input_tree` is retained as elements in the
@ -758,6 +766,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True):
Note, numpy arrays are considered scalars. Note, numpy arrays are considered scalars.
check_types: bool. If True, check that each node in shallow_tree has the check_types: bool. If True, check that each node in shallow_tree has the
same type as the corresponding node in input_tree. same type as the corresponding node in input_tree.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Returns: Returns:
A Python list, the partially flattened version of `input_tree` according to A Python list, the partially flattened version of `input_tree` according to
@ -770,12 +780,15 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True):
ValueError: If the sequence lengths of `shallow_tree` are different from ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`. `input_tree`.
""" """
assert_shallow_structure(shallow_tree, input_tree, check_types) is_seq = is_sequence_or_composite if expand_composites else is_sequence
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types,
expand_composites=expand_composites)
# Discard paths returned by _yield_flat_up_to. # Discard paths returned by _yield_flat_up_to.
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree)) return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq))
def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True): def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True,
expand_composites=False):
"""Flattens `input_tree` up to `shallow_tree`. """Flattens `input_tree` up to `shallow_tree`.
Any further depth in structure in `input_tree` is retained as elements in the Any further depth in structure in `input_tree` is retained as elements in the
@ -853,6 +866,8 @@ def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
Note, numpy arrays are considered scalars. Note, numpy arrays are considered scalars.
check_types: bool. If True, check that each node in shallow_tree has the check_types: bool. If True, check that each node in shallow_tree has the
same type as the corresponding node in input_tree. same type as the corresponding node in input_tree.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Returns: Returns:
A Python list, the partially flattened version of `input_tree` according to A Python list, the partially flattened version of `input_tree` according to
@ -865,8 +880,10 @@ def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
ValueError: If the sequence lengths of `shallow_tree` are different from ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`. `input_tree`.
""" """
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types) is_seq = is_sequence_or_composite if expand_composites else is_sequence
return list(_yield_flat_up_to(shallow_tree, input_tree)) assert_shallow_structure(shallow_tree, input_tree, check_types=check_types,
expand_composites=expand_composites)
return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
@ -1019,22 +1036,28 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
raise ValueError("Cannot map over no sequences") raise ValueError("Cannot map over no sequences")
check_types = kwargs.pop("check_types", True) check_types = kwargs.pop("check_types", True)
expand_composites = kwargs.pop("expand_composites", False)
is_seq = is_sequence_or_composite if expand_composites else is_sequence
for input_tree in inputs: for input_tree in inputs:
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types) assert_shallow_structure(shallow_tree, input_tree, check_types=check_types,
expand_composites=expand_composites)
# Flatten each input separately, apply the function to corresponding elements, # Flatten each input separately, apply the function to corresponding elements,
# then repack based on the structure of the first input. # then repack based on the structure of the first input.
flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types) flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types,
expand_composites=expand_composites)
for input_tree in inputs] for input_tree in inputs]
flat_path_list = [path for path, _ flat_path_list = [path for path, _
in _yield_flat_up_to(shallow_tree, inputs[0])] in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)]
results = [func(*args, **kwargs) for args in zip(flat_path_list, results = [func(*args, **kwargs) for args in zip(flat_path_list,
*flat_value_lists)] *flat_value_lists)]
return pack_sequence_as(structure=shallow_tree, flat_sequence=results) return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
expand_composites=expand_composites)
def get_traverse_shallow_structure(traverse_fn, structure): def get_traverse_shallow_structure(traverse_fn, structure,
expand_composites=False):
"""Generates a shallow structure from a `traverse_fn` and `structure`. """Generates a shallow structure from a `traverse_fn` and `structure`.
`traverse_fn` must accept any possible subtree of `structure` and return `traverse_fn` must accept any possible subtree of `structure` and return
@ -1050,6 +1073,8 @@ def get_traverse_shallow_structure(traverse_fn, structure):
shallow structure of the same type, describing which parts of the shallow structure of the same type, describing which parts of the
substructure to traverse. substructure to traverse.
structure: The structure to traverse. structure: The structure to traverse.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Returns: Returns:
A shallow structure containing python bools, which can be passed to A shallow structure containing python bools, which can be passed to
@ -1061,8 +1086,9 @@ def get_traverse_shallow_structure(traverse_fn, structure):
or if any leaf values in the returned structure or scalar are not type or if any leaf values in the returned structure or scalar are not type
`bool`. `bool`.
""" """
is_seq = is_sequence_or_composite if expand_composites else is_sequence
to_traverse = traverse_fn(structure) to_traverse = traverse_fn(structure)
if not is_sequence(structure): if not is_seq(structure):
if not isinstance(to_traverse, bool): if not isinstance(to_traverse, bool):
raise TypeError("traverse_fn returned structure: %s for non-structure: %s" raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
% (to_traverse, structure)) % (to_traverse, structure))
@ -1076,14 +1102,17 @@ def get_traverse_shallow_structure(traverse_fn, structure):
# Traverse the entire substructure. # Traverse the entire substructure.
for branch in _yield_value(structure): for branch in _yield_value(structure):
level_traverse.append( level_traverse.append(
get_traverse_shallow_structure(traverse_fn, branch)) get_traverse_shallow_structure(traverse_fn, branch,
elif not is_sequence(to_traverse): expand_composites=expand_composites))
elif not is_seq(to_traverse):
raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
% (to_traverse, structure)) % (to_traverse, structure))
else: else:
# Traverse some subset of this substructure. # Traverse some subset of this substructure.
assert_shallow_structure(to_traverse, structure) assert_shallow_structure(to_traverse, structure,
for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): expand_composites=expand_composites)
for t, branch in zip(_yield_value(to_traverse),
_yield_value(structure)):
if not isinstance(t, bool): if not isinstance(t, bool):
raise TypeError( raise TypeError(
"traverse_fn didn't return a depth=1 structure of bools. saw: %s " "traverse_fn didn't return a depth=1 structure of bools. saw: %s "
@ -1096,7 +1125,7 @@ def get_traverse_shallow_structure(traverse_fn, structure):
return _sequence_like(structure, level_traverse) return _sequence_like(structure, level_traverse)
def yield_flat_paths(nest): def yield_flat_paths(nest, expand_composites=False):
"""Yields paths for some nested structure. """Yields paths for some nested structure.
Paths are lists of objects which can be str-converted, which may include Paths are lists of objects which can be str-converted, which may include
@ -1124,16 +1153,20 @@ def yield_flat_paths(nest):
Args: Args:
nest: the value to produce a flattened paths list for. nest: the value to produce a flattened paths list for.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Yields: Yields:
Tuples containing index or key values which form the path to a specific Tuples containing index or key values which form the path to a specific
leaf value in the nested structure. leaf value in the nested structure.
""" """
for k, _ in _yield_flat_up_to(nest, nest): is_seq = is_sequence_or_composite if expand_composites else is_sequence
for k, _ in _yield_flat_up_to(nest, nest, is_seq):
yield k yield k
def flatten_with_joined_string_paths(structure, separator="/"): def flatten_with_joined_string_paths(structure, separator="/",
expand_composites=False):
"""Returns a list of (string path, data element) tuples. """Returns a list of (string path, data element) tuples.
The order of tuples produced matches that of `nest.flatten`. This allows you The order of tuples produced matches that of `nest.flatten`. This allows you
@ -1145,18 +1178,21 @@ def flatten_with_joined_string_paths(structure, separator="/"):
structure: the nested structure to flatten. structure: the nested structure to flatten.
separator: string to separate levels of hierarchy in the results, defaults separator: string to separate levels of hierarchy in the results, defaults
to '/'. to '/'.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Returns: Returns:
A list of (string, data element) tuples. A list of (string, data element) tuples.
""" """
flat_paths = yield_flat_paths(structure) flat_paths = yield_flat_paths(structure, expand_composites=expand_composites)
def stringify_and_join(path_elements): def stringify_and_join(path_elements):
return separator.join(str(path_element) for path_element in path_elements) return separator.join(str(path_element) for path_element in path_elements)
flat_string_paths = [stringify_and_join(path) for path in flat_paths] flat_string_paths = [stringify_and_join(path) for path in flat_paths]
return list(zip(flat_string_paths, flatten(structure))) return list(zip(flat_string_paths,
flatten(structure, expand_composites=expand_composites)))
def flatten_with_tuple_paths(structure): def flatten_with_tuple_paths(structure, expand_composites=False):
"""Returns a list of `(tuple_path, leaf_element)` tuples. """Returns a list of `(tuple_path, leaf_element)` tuples.
The order of pairs produced matches that of `nest.flatten`. This allows you The order of pairs produced matches that of `nest.flatten`. This allows you
@ -1166,13 +1202,17 @@ def flatten_with_tuple_paths(structure):
Args: Args:
structure: the nested structure to flatten. structure: the nested structure to flatten.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Returns: Returns:
A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple
of indices and/or dictionary keys that uniquely specify the path to of indices and/or dictionary keys that uniquely specify the path to
`leaf_element` within `structure`. `leaf_element` within `structure`.
""" """
return list(zip(yield_flat_paths(structure), flatten(structure))) return list(zip(yield_flat_paths(structure,
expand_composites=expand_composites),
flatten(structure, expand_composites=expand_composites)))
_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping) _pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)