Add expand_composites argument to all nest.* methods.
PiperOrigin-RevId: 242187912
This commit is contained in:
parent
f8dfa51551
commit
dc20e15f62
@ -25,6 +25,7 @@ from tensorflow.python.platform import googletest
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class TestCompositeTensor(composite_tensor.CompositeTensor):
|
class TestCompositeTensor(composite_tensor.CompositeTensor):
|
||||||
|
|
||||||
def __init__(self, *components):
|
def __init__(self, *components):
|
||||||
@ -41,12 +42,19 @@ class TestCompositeTensor(composite_tensor.CompositeTensor):
|
|||||||
raise NotImplementedError('CompositeTensor._shape_invariant_to_components')
|
raise NotImplementedError('CompositeTensor._shape_invariant_to_components')
|
||||||
|
|
||||||
def _is_graph_tensor(self):
|
def _is_graph_tensor(self):
|
||||||
return True
|
return False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'TestCompositeTensor%r' % (self._components,)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (isinstance(other, TestCompositeTensor) and
|
||||||
|
self._components == other._components)
|
||||||
|
|
||||||
|
|
||||||
class CompositeTensorTest(test_util.TensorFlowTestCase):
|
class CompositeTensorTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def assertNestEqual(self, a, b, expand_composites=False):
|
def assertNestEqual(self, a, b):
|
||||||
if isinstance(a, dict):
|
if isinstance(a, dict):
|
||||||
self.assertIsInstance(b, dict)
|
self.assertIsInstance(b, dict)
|
||||||
self.assertEqual(set(a), set(b))
|
self.assertEqual(set(a), set(b))
|
||||||
@ -57,36 +65,35 @@ class CompositeTensorTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(len(a), len(b))
|
self.assertEqual(len(a), len(b))
|
||||||
for a_val, b_val in zip(a, b):
|
for a_val, b_val in zip(a, b):
|
||||||
self.assertNestEqual(a_val, b_val)
|
self.assertNestEqual(a_val, b_val)
|
||||||
elif expand_composites and isinstance(a, composite_tensor.CompositeTensor):
|
elif isinstance(a, composite_tensor.CompositeTensor):
|
||||||
self.assertIsInstance(b, composite_tensor.CompositeTensor)
|
self.assertIsInstance(b, composite_tensor.CompositeTensor)
|
||||||
self.assertNestEqual(a._to_components(),
|
self.assertNestEqual(a._to_components(), b._to_components())
|
||||||
b._to_components())
|
else:
|
||||||
|
self.assertAllEqual(a, b)
|
||||||
|
|
||||||
def testNestFlatten(self):
|
def testNestFlatten(self):
|
||||||
st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10])
|
st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10])
|
||||||
st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10])
|
st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10])
|
||||||
structure = [[st1], 'foo', {'y': [st2]}]
|
structure = [[st1], 'foo', {'y': [st2]}]
|
||||||
x = nest.flatten(structure, expand_composites=True)
|
x = nest.flatten(structure, expand_composites=True)
|
||||||
self.assertEqual(len(x), 7)
|
self.assertNestEqual(x, [
|
||||||
self.assertIs(x[0], st1.indices)
|
st1.indices, st1.values, st1.dense_shape, 'foo', st2.indices,
|
||||||
self.assertIs(x[1], st1.values)
|
st2.values, st2.dense_shape
|
||||||
self.assertIs(x[2], st1.dense_shape)
|
])
|
||||||
self.assertEqual(x[3], 'foo')
|
|
||||||
self.assertIs(x[4], st2.indices)
|
|
||||||
self.assertIs(x[5], st2.values)
|
|
||||||
self.assertIs(x[6], st2.dense_shape)
|
|
||||||
|
|
||||||
def testNestPackSequenceAs(self):
|
def testNestPackSequenceAs(self):
|
||||||
st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10])
|
st1 = sparse_tensor.SparseTensor([[0, 3], [7, 2]], [1, 2], [10, 10])
|
||||||
st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10])
|
st2 = sparse_tensor.SparseTensor([[1, 2, 3]], ['a'], [10, 10, 10])
|
||||||
structure1 = [[st1], 'foo', {'y': [st2]}]
|
structure1 = [[st1], 'foo', {'y': [st2]}]
|
||||||
flat = [st2.indices, st2.values, st2.dense_shape, 'bar',
|
flat = [
|
||||||
st1.indices, st1.values, st1.dense_shape]
|
st2.indices, st2.values, st2.dense_shape, 'bar', st1.indices,
|
||||||
|
st1.values, st1.dense_shape
|
||||||
|
]
|
||||||
result = nest.pack_sequence_as(structure1, flat, expand_composites=True)
|
result = nest.pack_sequence_as(structure1, flat, expand_composites=True)
|
||||||
expected = [[st2], 'bar', {'y': [st1]}]
|
expected = [[st2], 'bar', {'y': [st1]}]
|
||||||
self.assertNestEqual(expected, result)
|
self.assertNestEqual(expected, result)
|
||||||
|
|
||||||
def testAssertSameStructure(self):
|
def testNestAssertSameStructure(self):
|
||||||
st1 = sparse_tensor.SparseTensor([[0]], [0], [100])
|
st1 = sparse_tensor.SparseTensor([[0]], [0], [100])
|
||||||
st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100])
|
st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100])
|
||||||
test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape)
|
test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape)
|
||||||
@ -96,6 +103,184 @@ class CompositeTensorTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
nest.assert_same_structure(st1, test, expand_composites=True)
|
nest.assert_same_structure(st1, test, expand_composites=True)
|
||||||
|
|
||||||
|
def testNestMapStructure(self):
|
||||||
|
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
return x + 10
|
||||||
|
|
||||||
|
result = nest.map_structure(func, structure, expand_composites=True)
|
||||||
|
expected = [[TestCompositeTensor(11, 12, 13)], 110, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(14, 15), 16)
|
||||||
|
}]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def testNestMapStructureWithPaths(self):
|
||||||
|
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
|
||||||
|
def func(path, x):
|
||||||
|
return '%s:%s' % (path, x)
|
||||||
|
|
||||||
|
result = nest.map_structure_with_paths(
|
||||||
|
func, structure, expand_composites=True)
|
||||||
|
expected = [[TestCompositeTensor('0/0/0:1', '0/0/1:2', '0/0/2:3')], '1:100',
|
||||||
|
{
|
||||||
|
'y':
|
||||||
|
TestCompositeTensor(
|
||||||
|
TestCompositeTensor('2/y/0/0:4', '2/y/0/1:5'),
|
||||||
|
'2/y/1:6')
|
||||||
|
}]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def testNestMapStructureWithTuplePaths(self):
|
||||||
|
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
|
||||||
|
def func(path, x):
|
||||||
|
return (path, x)
|
||||||
|
|
||||||
|
result = nest.map_structure_with_tuple_paths(
|
||||||
|
func, structure, expand_composites=True)
|
||||||
|
expected = [[
|
||||||
|
TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
|
||||||
|
], ((1,), 100), {
|
||||||
|
'y':
|
||||||
|
TestCompositeTensor(
|
||||||
|
TestCompositeTensor(((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5)),
|
||||||
|
((2, 'y', 1), 6))
|
||||||
|
}]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def testNestAssertShallowStructure(self):
|
||||||
|
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
|
||||||
|
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
nest.assert_shallow_structure(s1, s2, expand_composites=False)
|
||||||
|
nest.assert_shallow_structure(s1, s2, expand_composites=True)
|
||||||
|
nest.assert_shallow_structure(s2, s1, expand_composites=False)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
nest.assert_shallow_structure(s2, s1, expand_composites=True)
|
||||||
|
|
||||||
|
def testNestFlattenUpTo(self):
|
||||||
|
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
|
||||||
|
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
result1 = nest.flatten_up_to(s1, s2, expand_composites=True)
|
||||||
|
expected1 = [1, 2, 3, 100, TestCompositeTensor(4, 5), 6]
|
||||||
|
self.assertEqual(result1, expected1)
|
||||||
|
|
||||||
|
result2 = nest.flatten_up_to(s1, s2, expand_composites=False)
|
||||||
|
expected2 = [
|
||||||
|
TestCompositeTensor(1, 2, 3), 100,
|
||||||
|
TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
]
|
||||||
|
self.assertEqual(result2, expected2)
|
||||||
|
|
||||||
|
def testNestFlattenWithTuplePathsUpTo(self):
|
||||||
|
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
|
||||||
|
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
result1 = nest.flatten_with_tuple_paths_up_to(
|
||||||
|
s1, s2, expand_composites=True)
|
||||||
|
expected1 = [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3), ((1,), 100),
|
||||||
|
((2, 'y', 0), TestCompositeTensor(4, 5)), ((2, 'y', 1), 6)]
|
||||||
|
self.assertEqual(result1, expected1)
|
||||||
|
|
||||||
|
result2 = nest.flatten_with_tuple_paths_up_to(
|
||||||
|
s1, s2, expand_composites=False)
|
||||||
|
expected2 = [((0, 0), TestCompositeTensor(1, 2, 3)), ((1,), 100),
|
||||||
|
((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
|
||||||
|
self.assertEqual(result2, expected2)
|
||||||
|
|
||||||
|
def testNestMapStructureUpTo(self):
|
||||||
|
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
|
||||||
|
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
return x + 10 if isinstance(x, int) else x
|
||||||
|
|
||||||
|
result = nest.map_structure_up_to(s1, func, s2, expand_composites=True)
|
||||||
|
expected = [[TestCompositeTensor(11, 12, 13)], 110, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 16)
|
||||||
|
}]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def testNestMapStructureWithTuplePathsUpTo(self):
|
||||||
|
s1 = [[TestCompositeTensor(1, 2, 3)], 100, {'y': TestCompositeTensor(5, 6)}]
|
||||||
|
s2 = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
|
||||||
|
def func(path, x):
|
||||||
|
return (path, x)
|
||||||
|
|
||||||
|
result = nest.map_structure_with_tuple_paths_up_to(
|
||||||
|
s1, func, s2, expand_composites=True)
|
||||||
|
expected = [[
|
||||||
|
TestCompositeTensor(((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3))
|
||||||
|
], ((1,), 100), {
|
||||||
|
'y':
|
||||||
|
TestCompositeTensor(((2, 'y', 0), TestCompositeTensor(4, 5)),
|
||||||
|
((2, 'y', 1), 6))
|
||||||
|
}]
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def testNestGetTraverseShallowStructure(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def testNestYieldFlatPaths(self):
|
||||||
|
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
result1 = list(nest.yield_flat_paths(structure, expand_composites=True))
|
||||||
|
expected1 = [(0, 0, 0), (0, 0, 1), (0, 0, 2), (1,), (2, 'y', 0, 0),
|
||||||
|
(2, 'y', 0, 1), (2, 'y', 1)]
|
||||||
|
self.assertEqual(result1, expected1)
|
||||||
|
|
||||||
|
result2 = list(nest.yield_flat_paths(structure, expand_composites=False))
|
||||||
|
expected2 = [(0, 0), (1,), (2, 'y')]
|
||||||
|
self.assertEqual(result2, expected2)
|
||||||
|
|
||||||
|
def testNestFlattenWithJoinedStringPaths(self):
|
||||||
|
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
result1 = nest.flatten_with_joined_string_paths(
|
||||||
|
structure, expand_composites=True)
|
||||||
|
expected1 = [('0/0/0', 1), ('0/0/1', 2), ('0/0/2', 3), ('1', 100),
|
||||||
|
('2/y/0/0', 4), ('2/y/0/1', 5), ('2/y/1', 6)]
|
||||||
|
self.assertEqual(result1, expected1)
|
||||||
|
|
||||||
|
result2 = nest.flatten_with_joined_string_paths(
|
||||||
|
structure, expand_composites=False)
|
||||||
|
expected2 = [('0/0', TestCompositeTensor(1, 2, 3)), ('1', 100),
|
||||||
|
('2/y', TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
|
||||||
|
self.assertEqual(result2, expected2)
|
||||||
|
|
||||||
|
def testNestFlattenWithTuplePaths(self):
|
||||||
|
structure = [[TestCompositeTensor(1, 2, 3)], 100, {
|
||||||
|
'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
|
||||||
|
}]
|
||||||
|
result1 = nest.flatten_with_tuple_paths(structure, expand_composites=True)
|
||||||
|
expected1 = [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3), ((1,), 100),
|
||||||
|
((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5), ((2, 'y', 1), 6)]
|
||||||
|
self.assertEqual(result1, expected1)
|
||||||
|
|
||||||
|
result2 = nest.flatten_with_tuple_paths(structure, expand_composites=False)
|
||||||
|
expected2 = [((0, 0), TestCompositeTensor(1, 2, 3)), ((1,), 100),
|
||||||
|
((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
|
||||||
|
self.assertEqual(result2, expected2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -588,13 +588,14 @@ def map_structure_with_tuple_paths(func, *structure, **kwargs):
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
|
||||||
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.
|
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
|
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
|
||||||
input_tree: Nested structure. Return the paths and values from this tree.
|
input_tree: Nested structure. Return the paths and values from this tree.
|
||||||
Must have the same upper structure as shallow_tree.
|
Must have the same upper structure as shallow_tree.
|
||||||
|
is_seq: Function used to test if a value should be treated as a sequence.
|
||||||
path: Tuple. Optional argument, only used when recursing. The path from the
|
path: Tuple. Optional argument, only used when recursing. The path from the
|
||||||
root of the original shallow_tree, down to the root of the shallow_tree
|
root of the original shallow_tree, down to the root of the shallow_tree
|
||||||
arg of this recursive call.
|
arg of this recursive call.
|
||||||
@ -604,11 +605,7 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
|||||||
shallow_tree, and value is the value of the corresponding node in
|
shallow_tree, and value is the value of the corresponding node in
|
||||||
input_tree.
|
input_tree.
|
||||||
"""
|
"""
|
||||||
if (isinstance(shallow_tree, _six.string_types) or
|
if not is_seq(shallow_tree):
|
||||||
not any([isinstance(shallow_tree, _collections.Sequence),
|
|
||||||
isinstance(shallow_tree, _collections.Mapping),
|
|
||||||
_is_namedtuple(shallow_tree),
|
|
||||||
_is_attrs(shallow_tree)])):
|
|
||||||
yield (path, input_tree)
|
yield (path, input_tree)
|
||||||
else:
|
else:
|
||||||
input_tree = dict(_yield_sorted_items(input_tree))
|
input_tree = dict(_yield_sorted_items(input_tree))
|
||||||
@ -616,12 +613,13 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
|||||||
subpath = path + (shallow_key,)
|
subpath = path + (shallow_key,)
|
||||||
input_subtree = input_tree[shallow_key]
|
input_subtree = input_tree[shallow_key]
|
||||||
for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
|
for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
|
||||||
input_subtree,
|
input_subtree, is_seq,
|
||||||
path=subpath):
|
path=subpath):
|
||||||
yield (leaf_path, leaf_value)
|
yield (leaf_path, leaf_value)
|
||||||
|
|
||||||
|
|
||||||
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
def assert_shallow_structure(shallow_tree, input_tree, check_types=True,
|
||||||
|
expand_composites=False):
|
||||||
"""Asserts that `shallow_tree` is a shallow structure of `input_tree`.
|
"""Asserts that `shallow_tree` is a shallow structure of `input_tree`.
|
||||||
|
|
||||||
That is, this function tests if the `input_tree` structure can be created from
|
That is, this function tests if the `input_tree` structure can be created from
|
||||||
@ -651,6 +649,8 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
|||||||
`input_tree` have to be the same. Note that even with check_types==True,
|
`input_tree` have to be the same. Note that even with check_types==True,
|
||||||
this function will consider two different namedtuple classes with the same
|
this function will consider two different namedtuple classes with the same
|
||||||
name and _fields attribute to be the same class.
|
name and _fields attribute to be the same class.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
|
||||||
@ -659,8 +659,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
|||||||
ValueError: If the sequence lengths of `shallow_tree` are different from
|
ValueError: If the sequence lengths of `shallow_tree` are different from
|
||||||
`input_tree`.
|
`input_tree`.
|
||||||
"""
|
"""
|
||||||
if is_sequence(shallow_tree):
|
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||||
if not is_sequence(input_tree):
|
if is_seq(shallow_tree):
|
||||||
|
if not is_seq(input_tree):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"If shallow structure is a sequence, input must also be a sequence. "
|
"If shallow structure is a sequence, input must also be a sequence. "
|
||||||
"Input has type: %s." % type(input_tree))
|
"Input has type: %s." % type(input_tree))
|
||||||
@ -682,6 +683,11 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
|||||||
input_type=type(input_tree),
|
input_type=type(input_tree),
|
||||||
shallow_type=type(shallow_tree)))
|
shallow_type=type(shallow_tree)))
|
||||||
|
|
||||||
|
while _is_composite_tensor(shallow_tree):
|
||||||
|
shallow_tree = shallow_tree._to_components() # pylint: disable=protected-access
|
||||||
|
while _is_composite_tensor(input_tree):
|
||||||
|
input_tree = input_tree._to_components() # pylint: disable=protected-access
|
||||||
|
|
||||||
if len(input_tree) < len(shallow_tree):
|
if len(input_tree) < len(shallow_tree):
|
||||||
raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
|
raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
|
||||||
input_size=len(input_tree),
|
input_size=len(input_tree),
|
||||||
@ -696,10 +702,12 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
|
|||||||
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
|
||||||
_yield_value(input_tree)):
|
_yield_value(input_tree)):
|
||||||
assert_shallow_structure(shallow_branch, input_branch,
|
assert_shallow_structure(shallow_branch, input_branch,
|
||||||
check_types=check_types)
|
check_types=check_types,
|
||||||
|
expand_composites=expand_composites)
|
||||||
|
|
||||||
|
|
||||||
def flatten_up_to(shallow_tree, input_tree, check_types=True):
|
def flatten_up_to(shallow_tree, input_tree, check_types=True,
|
||||||
|
expand_composites=False):
|
||||||
"""Flattens `input_tree` up to `shallow_tree`.
|
"""Flattens `input_tree` up to `shallow_tree`.
|
||||||
|
|
||||||
Any further depth in structure in `input_tree` is retained as elements in the
|
Any further depth in structure in `input_tree` is retained as elements in the
|
||||||
@ -758,6 +766,8 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True):
|
|||||||
Note, numpy arrays are considered scalars.
|
Note, numpy arrays are considered scalars.
|
||||||
check_types: bool. If True, check that each node in shallow_tree has the
|
check_types: bool. If True, check that each node in shallow_tree has the
|
||||||
same type as the corresponding node in input_tree.
|
same type as the corresponding node in input_tree.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Python list, the partially flattened version of `input_tree` according to
|
A Python list, the partially flattened version of `input_tree` according to
|
||||||
@ -770,12 +780,15 @@ def flatten_up_to(shallow_tree, input_tree, check_types=True):
|
|||||||
ValueError: If the sequence lengths of `shallow_tree` are different from
|
ValueError: If the sequence lengths of `shallow_tree` are different from
|
||||||
`input_tree`.
|
`input_tree`.
|
||||||
"""
|
"""
|
||||||
assert_shallow_structure(shallow_tree, input_tree, check_types)
|
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||||
|
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types,
|
||||||
|
expand_composites=expand_composites)
|
||||||
# Discard paths returned by _yield_flat_up_to.
|
# Discard paths returned by _yield_flat_up_to.
|
||||||
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree))
|
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq))
|
||||||
|
|
||||||
|
|
||||||
def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
|
def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True,
|
||||||
|
expand_composites=False):
|
||||||
"""Flattens `input_tree` up to `shallow_tree`.
|
"""Flattens `input_tree` up to `shallow_tree`.
|
||||||
|
|
||||||
Any further depth in structure in `input_tree` is retained as elements in the
|
Any further depth in structure in `input_tree` is retained as elements in the
|
||||||
@ -853,6 +866,8 @@ def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
|
|||||||
Note, numpy arrays are considered scalars.
|
Note, numpy arrays are considered scalars.
|
||||||
check_types: bool. If True, check that each node in shallow_tree has the
|
check_types: bool. If True, check that each node in shallow_tree has the
|
||||||
same type as the corresponding node in input_tree.
|
same type as the corresponding node in input_tree.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Python list, the partially flattened version of `input_tree` according to
|
A Python list, the partially flattened version of `input_tree` according to
|
||||||
@ -865,8 +880,10 @@ def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
|
|||||||
ValueError: If the sequence lengths of `shallow_tree` are different from
|
ValueError: If the sequence lengths of `shallow_tree` are different from
|
||||||
`input_tree`.
|
`input_tree`.
|
||||||
"""
|
"""
|
||||||
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
|
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||||
return list(_yield_flat_up_to(shallow_tree, input_tree))
|
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types,
|
||||||
|
expand_composites=expand_composites)
|
||||||
|
return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
|
||||||
|
|
||||||
|
|
||||||
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
|
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
|
||||||
@ -1019,22 +1036,28 @@ def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
|
|||||||
raise ValueError("Cannot map over no sequences")
|
raise ValueError("Cannot map over no sequences")
|
||||||
|
|
||||||
check_types = kwargs.pop("check_types", True)
|
check_types = kwargs.pop("check_types", True)
|
||||||
|
expand_composites = kwargs.pop("expand_composites", False)
|
||||||
|
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||||
|
|
||||||
for input_tree in inputs:
|
for input_tree in inputs:
|
||||||
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
|
assert_shallow_structure(shallow_tree, input_tree, check_types=check_types,
|
||||||
|
expand_composites=expand_composites)
|
||||||
|
|
||||||
# Flatten each input separately, apply the function to corresponding elements,
|
# Flatten each input separately, apply the function to corresponding elements,
|
||||||
# then repack based on the structure of the first input.
|
# then repack based on the structure of the first input.
|
||||||
flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types)
|
flat_value_lists = [flatten_up_to(shallow_tree, input_tree, check_types,
|
||||||
|
expand_composites=expand_composites)
|
||||||
for input_tree in inputs]
|
for input_tree in inputs]
|
||||||
flat_path_list = [path for path, _
|
flat_path_list = [path for path, _
|
||||||
in _yield_flat_up_to(shallow_tree, inputs[0])]
|
in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)]
|
||||||
results = [func(*args, **kwargs) for args in zip(flat_path_list,
|
results = [func(*args, **kwargs) for args in zip(flat_path_list,
|
||||||
*flat_value_lists)]
|
*flat_value_lists)]
|
||||||
return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
|
return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
|
||||||
|
expand_composites=expand_composites)
|
||||||
|
|
||||||
|
|
||||||
def get_traverse_shallow_structure(traverse_fn, structure):
|
def get_traverse_shallow_structure(traverse_fn, structure,
|
||||||
|
expand_composites=False):
|
||||||
"""Generates a shallow structure from a `traverse_fn` and `structure`.
|
"""Generates a shallow structure from a `traverse_fn` and `structure`.
|
||||||
|
|
||||||
`traverse_fn` must accept any possible subtree of `structure` and return
|
`traverse_fn` must accept any possible subtree of `structure` and return
|
||||||
@ -1050,6 +1073,8 @@ def get_traverse_shallow_structure(traverse_fn, structure):
|
|||||||
shallow structure of the same type, describing which parts of the
|
shallow structure of the same type, describing which parts of the
|
||||||
substructure to traverse.
|
substructure to traverse.
|
||||||
structure: The structure to traverse.
|
structure: The structure to traverse.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A shallow structure containing python bools, which can be passed to
|
A shallow structure containing python bools, which can be passed to
|
||||||
@ -1061,8 +1086,9 @@ def get_traverse_shallow_structure(traverse_fn, structure):
|
|||||||
or if any leaf values in the returned structure or scalar are not type
|
or if any leaf values in the returned structure or scalar are not type
|
||||||
`bool`.
|
`bool`.
|
||||||
"""
|
"""
|
||||||
|
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||||
to_traverse = traverse_fn(structure)
|
to_traverse = traverse_fn(structure)
|
||||||
if not is_sequence(structure):
|
if not is_seq(structure):
|
||||||
if not isinstance(to_traverse, bool):
|
if not isinstance(to_traverse, bool):
|
||||||
raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
|
raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
|
||||||
% (to_traverse, structure))
|
% (to_traverse, structure))
|
||||||
@ -1076,14 +1102,17 @@ def get_traverse_shallow_structure(traverse_fn, structure):
|
|||||||
# Traverse the entire substructure.
|
# Traverse the entire substructure.
|
||||||
for branch in _yield_value(structure):
|
for branch in _yield_value(structure):
|
||||||
level_traverse.append(
|
level_traverse.append(
|
||||||
get_traverse_shallow_structure(traverse_fn, branch))
|
get_traverse_shallow_structure(traverse_fn, branch,
|
||||||
elif not is_sequence(to_traverse):
|
expand_composites=expand_composites))
|
||||||
|
elif not is_seq(to_traverse):
|
||||||
raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
|
raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
|
||||||
% (to_traverse, structure))
|
% (to_traverse, structure))
|
||||||
else:
|
else:
|
||||||
# Traverse some subset of this substructure.
|
# Traverse some subset of this substructure.
|
||||||
assert_shallow_structure(to_traverse, structure)
|
assert_shallow_structure(to_traverse, structure,
|
||||||
for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)):
|
expand_composites=expand_composites)
|
||||||
|
for t, branch in zip(_yield_value(to_traverse),
|
||||||
|
_yield_value(structure)):
|
||||||
if not isinstance(t, bool):
|
if not isinstance(t, bool):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"traverse_fn didn't return a depth=1 structure of bools. saw: %s "
|
"traverse_fn didn't return a depth=1 structure of bools. saw: %s "
|
||||||
@ -1096,7 +1125,7 @@ def get_traverse_shallow_structure(traverse_fn, structure):
|
|||||||
return _sequence_like(structure, level_traverse)
|
return _sequence_like(structure, level_traverse)
|
||||||
|
|
||||||
|
|
||||||
def yield_flat_paths(nest):
|
def yield_flat_paths(nest, expand_composites=False):
|
||||||
"""Yields paths for some nested structure.
|
"""Yields paths for some nested structure.
|
||||||
|
|
||||||
Paths are lists of objects which can be str-converted, which may include
|
Paths are lists of objects which can be str-converted, which may include
|
||||||
@ -1124,16 +1153,20 @@ def yield_flat_paths(nest):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
nest: the value to produce a flattened paths list for.
|
nest: the value to produce a flattened paths list for.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuples containing index or key values which form the path to a specific
|
Tuples containing index or key values which form the path to a specific
|
||||||
leaf value in the nested structure.
|
leaf value in the nested structure.
|
||||||
"""
|
"""
|
||||||
for k, _ in _yield_flat_up_to(nest, nest):
|
is_seq = is_sequence_or_composite if expand_composites else is_sequence
|
||||||
|
for k, _ in _yield_flat_up_to(nest, nest, is_seq):
|
||||||
yield k
|
yield k
|
||||||
|
|
||||||
|
|
||||||
def flatten_with_joined_string_paths(structure, separator="/"):
|
def flatten_with_joined_string_paths(structure, separator="/",
|
||||||
|
expand_composites=False):
|
||||||
"""Returns a list of (string path, data element) tuples.
|
"""Returns a list of (string path, data element) tuples.
|
||||||
|
|
||||||
The order of tuples produced matches that of `nest.flatten`. This allows you
|
The order of tuples produced matches that of `nest.flatten`. This allows you
|
||||||
@ -1145,18 +1178,21 @@ def flatten_with_joined_string_paths(structure, separator="/"):
|
|||||||
structure: the nested structure to flatten.
|
structure: the nested structure to flatten.
|
||||||
separator: string to separate levels of hierarchy in the results, defaults
|
separator: string to separate levels of hierarchy in the results, defaults
|
||||||
to '/'.
|
to '/'.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of (string, data element) tuples.
|
A list of (string, data element) tuples.
|
||||||
"""
|
"""
|
||||||
flat_paths = yield_flat_paths(structure)
|
flat_paths = yield_flat_paths(structure, expand_composites=expand_composites)
|
||||||
def stringify_and_join(path_elements):
|
def stringify_and_join(path_elements):
|
||||||
return separator.join(str(path_element) for path_element in path_elements)
|
return separator.join(str(path_element) for path_element in path_elements)
|
||||||
flat_string_paths = [stringify_and_join(path) for path in flat_paths]
|
flat_string_paths = [stringify_and_join(path) for path in flat_paths]
|
||||||
return list(zip(flat_string_paths, flatten(structure)))
|
return list(zip(flat_string_paths,
|
||||||
|
flatten(structure, expand_composites=expand_composites)))
|
||||||
|
|
||||||
|
|
||||||
def flatten_with_tuple_paths(structure):
|
def flatten_with_tuple_paths(structure, expand_composites=False):
|
||||||
"""Returns a list of `(tuple_path, leaf_element)` tuples.
|
"""Returns a list of `(tuple_path, leaf_element)` tuples.
|
||||||
|
|
||||||
The order of pairs produced matches that of `nest.flatten`. This allows you
|
The order of pairs produced matches that of `nest.flatten`. This allows you
|
||||||
@ -1166,13 +1202,17 @@ def flatten_with_tuple_paths(structure):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
structure: the nested structure to flatten.
|
structure: the nested structure to flatten.
|
||||||
|
expand_composites: If true, then composite tensors such as tf.SparseTensor
|
||||||
|
and tf.RaggedTensor are expanded into their component tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple
|
A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple
|
||||||
of indices and/or dictionary keys that uniquely specify the path to
|
of indices and/or dictionary keys that uniquely specify the path to
|
||||||
`leaf_element` within `structure`.
|
`leaf_element` within `structure`.
|
||||||
"""
|
"""
|
||||||
return list(zip(yield_flat_paths(structure), flatten(structure)))
|
return list(zip(yield_flat_paths(structure,
|
||||||
|
expand_composites=expand_composites),
|
||||||
|
flatten(structure, expand_composites=expand_composites)))
|
||||||
|
|
||||||
|
|
||||||
_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
|
_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
|
||||||
|
Loading…
Reference in New Issue
Block a user