Add nest.map_structure.

Useful for simplifying a number of flatten -> map -> recreate structure calls.
Change: 144388829
This commit is contained in:
Eugene Brevdo 2017-01-12 16:48:05 -08:00 committed by TensorFlower Gardener
parent 9914eaba62
commit 1013cee707
2 changed files with 70 additions and 0 deletions
tensorflow/python/util

View File

@ -26,6 +26,7 @@ should be recursive.
@@flatten
@@flatten_dict_items
@@pack_sequence_as
@@map_structure
"""
from __future__ import absolute_import
@ -260,3 +261,42 @@ def pack_sequence_as(structure, flat_sequence):
_, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
return _sequence_like(structure, packed)
def map_structure(func, *structure):
"""Applies `func` to each entry in `structure` and returns a new structure.
Applies `func(x[0], x[1], ...)` where x[i] is an entry in
`structure[i]`. All structures in `structure` must have the same arity,
and the return value will contain the results in the same structure.
Args:
func: A callable that acceps as many arguments are there are structures.
*structure: scalar, or tuple or list of constructed scalars and/or other
tuples/lists, or scalars. Note: numpy arrays are considered scalars.
Returns:
A new structure with the same arity as `structure`, whose values correspond
to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
location in `structure[i]`.
Raises:
TypeError: If `func` is not callable or if the structures do not match
each other by depth tree.
ValueError: If no structure is provided or if the structures do not match
each other by type.
"""
if not callable(func):
raise TypeError("func must be callable, got: %s" % func)
if not structure:
raise ValueError("Must provide at least one structure")
for other in structure[1:]:
assert_same_structure(structure[0], other)
flat_structure = [flatten(s) for s in structure]
entries = zip(*flat_structure)
return pack_sequence_as(
structure[0], [func(*x) for x in entries])

View File

@ -139,6 +139,36 @@ class NestTest(test.TestCase):
"don't have the same nested structure"):
nest.assert_same_structure([[3], 4], [3, [4]])
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
nest.assert_same_structure(structure1, structure1_plus1)
self.assertAllEqual(
[2, 3, 4, 5, 6, 7],
nest.flatten(structure1_plus1))
structure1_plus_structure2 = nest.map_structure(
lambda x, y: x + y, structure1, structure2)
self.assertEqual(
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
structure1_plus_structure2)
self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
with self.assertRaisesRegexp(TypeError, "callable"):
nest.map_structure("bad", structure1_plus1)
with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, 3, (3,))
with self.assertRaisesRegexp(TypeError, "same sequence type"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
if __name__ == "__main__":
test.main()