Add nest.map_structure.
Useful for simplifying a number of flatten -> map -> recreate structure calls. Change: 144388829
This commit is contained in:
parent
9914eaba62
commit
1013cee707
tensorflow/python/util
@ -26,6 +26,7 @@ should be recursive.
|
||||
@@flatten
|
||||
@@flatten_dict_items
|
||||
@@pack_sequence_as
|
||||
@@map_structure
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -260,3 +261,42 @@ def pack_sequence_as(structure, flat_sequence):
|
||||
|
||||
_, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
|
||||
return _sequence_like(structure, packed)
|
||||
|
||||
|
||||
def map_structure(func, *structure):
|
||||
"""Applies `func` to each entry in `structure` and returns a new structure.
|
||||
|
||||
Applies `func(x[0], x[1], ...)` where x[i] is an entry in
|
||||
`structure[i]`. All structures in `structure` must have the same arity,
|
||||
and the return value will contain the results in the same structure.
|
||||
|
||||
Args:
|
||||
func: A callable that acceps as many arguments are there are structures.
|
||||
*structure: scalar, or tuple or list of constructed scalars and/or other
|
||||
tuples/lists, or scalars. Note: numpy arrays are considered scalars.
|
||||
|
||||
Returns:
|
||||
A new structure with the same arity as `structure`, whose values correspond
|
||||
to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
|
||||
location in `structure[i]`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `func` is not callable or if the structures do not match
|
||||
each other by depth tree.
|
||||
ValueError: If no structure is provided or if the structures do not match
|
||||
each other by type.
|
||||
"""
|
||||
if not callable(func):
|
||||
raise TypeError("func must be callable, got: %s" % func)
|
||||
|
||||
if not structure:
|
||||
raise ValueError("Must provide at least one structure")
|
||||
|
||||
for other in structure[1:]:
|
||||
assert_same_structure(structure[0], other)
|
||||
|
||||
flat_structure = [flatten(s) for s in structure]
|
||||
entries = zip(*flat_structure)
|
||||
|
||||
return pack_sequence_as(
|
||||
structure[0], [func(*x) for x in entries])
|
||||
|
@ -139,6 +139,36 @@ class NestTest(test.TestCase):
|
||||
"don't have the same nested structure"):
|
||||
nest.assert_same_structure([[3], 4], [3, [4]])
|
||||
|
||||
def testMapStructure(self):
|
||||
structure1 = (((1, 2), 3), 4, (5, 6))
|
||||
structure2 = (((7, 8), 9), 10, (11, 12))
|
||||
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
|
||||
nest.assert_same_structure(structure1, structure1_plus1)
|
||||
self.assertAllEqual(
|
||||
[2, 3, 4, 5, 6, 7],
|
||||
nest.flatten(structure1_plus1))
|
||||
structure1_plus_structure2 = nest.map_structure(
|
||||
lambda x, y: x + y, structure1, structure2)
|
||||
self.assertEqual(
|
||||
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
|
||||
structure1_plus_structure2)
|
||||
|
||||
self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
|
||||
|
||||
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "callable"):
|
||||
nest.map_structure("bad", structure1_plus1)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||
nest.map_structure(lambda x, y: None, 3, (3,))
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "same sequence type"):
|
||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user