[tf.contrib.data] Add support for dicts and remove lists from nested structures.

This changes the behavior of constructors like
`tf.contrib.data.Dataset.from_tensors()` when passed a list. Previously, the
`nest` utility would recurse into each element of such a list and create a
separate Dataset component. Now the list will be converted to a tensor, allowing code like:

```python
dataset = tf.contrib.data.Dataset.from_tensor_slices(([1, 2, 3], [4, 5, 6]))
```

...to define a dataset with two components (each of shape `()`).

This change also adds support for dictionaries as nested structures, which
simplifies integration with dictionary-returning ops like `tf.parse_example()`.

Fixes #10151.

RELNOTES: Breaking change to `tf.contrib.data.Dataset` APIs that expect a
nested structure. Lists are now converted to tf.Tensor implicitly. You may need
to change uses of lists to tuples in existing code. In addition, dicts are now
supported as a nested structure.
PiperOrigin-RevId: 158139467
This commit is contained in:
Derek Murray 2017-06-06 07:47:14 -07:00 committed by TensorFlower Gardener
parent b6a8848c17
commit f3f53e8b39
18 changed files with 967 additions and 58 deletions

View File

@ -231,6 +231,7 @@ filegroup(
"//tensorflow/contrib/data/python/framework:all_files",
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/data/python/util:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/factorization:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",

View File

@ -276,6 +276,7 @@ add_python_module("tensorflow/contrib/data/python")
add_python_module("tensorflow/contrib/data/python/framework")
add_python_module("tensorflow/contrib/data/python/kernel_tests")
add_python_module("tensorflow/contrib/data/python/ops")
add_python_module("tensorflow/contrib/data/python/util")
add_python_module("tensorflow/contrib/deprecated")
add_python_module("tensorflow/contrib/distributions")
add_python_module("tensorflow/contrib/distributions/python")

View File

@ -40,9 +40,9 @@ class BatchDatasetTest(test.TestCase):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count) -> BatchDataset(batch_size).
components = [np.arange(7),
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
np.array(37.0) * np.arange(7))
count = array_ops.placeholder(dtypes.int64, shape=[])
batch_size = array_ops.placeholder(dtypes.int64, shape=[])

View File

@ -33,7 +33,7 @@ class DatasetConstructorTest(test.TestCase):
def testTensorDataset(self):
"""Test an dataset that represents a single tuple of tensors."""
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
iterator = (dataset_ops.Dataset.from_tensors(components)
.make_initializable_iterator())
@ -53,11 +53,11 @@ class DatasetConstructorTest(test.TestCase):
def testTensorSliceDataset(self):
"""Test an dataset that represents the slices from a tuple of tensors."""
components = [
components = (
np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0])
]
)
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.make_initializable_iterator())
@ -76,6 +76,27 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testTensorSliceDatasetWithDict(self):
components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual(dtypes.int32, iterator.output_types["foo"])
self.assertEqual(dtypes.float32, iterator.output_types["bar"])
self.assertEqual((), iterator.output_shapes["foo"])
self.assertEqual((1,), iterator.output_shapes["bar"])
with self.test_session() as sess:
sess.run(init_op)
for i in range(3):
results = sess.run(get_next)
self.assertEqual(components["foo"][i], results["foo"])
self.assertEqual(components["bar"][i], results["bar"])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSparseTensorSliceDataset(self):
"""Test a dataset based on slices of a `tf.SparseTensor`."""
st = array_ops.sparse_placeholder(dtypes.float64)

View File

@ -30,12 +30,12 @@ from tensorflow.python.platform import test
class FilterDatasetTest(test.TestCase):
def testFilterDataset(self):
components = [
components = (
np.arange(7, dtype=np.int64),
np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
7, dtype=np.int64)[:, np.newaxis],
np.array(37.0, dtype=np.float64) * np.arange(7)
]
)
count = array_ops.placeholder(dtypes.int64, shape=[])
modulus = array_ops.placeholder(dtypes.int64)

View File

@ -33,13 +33,13 @@ class FlatMapDatasetTest(test.TestCase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
repeats = [1, 2, 3, 4, 5, 0, 1]
components = [np.array(repeats, dtype=np.int64)]
components = np.array(repeats, dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
.make_initializable_iterator())
init_op = iterator.initializer
get_next, = iterator.get_next()
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
@ -51,14 +51,14 @@ class FlatMapDatasetTest(test.TestCase):
def testNestedFlatMapDataset(self):
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
components = [np.array(repeats, dtype=np.int64)]
components = np.array(repeats, dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
.flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
.repeat(y))).make_initializable_iterator())
init_op = iterator.initializer
get_next, = iterator.get_next()
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
@ -72,15 +72,15 @@ class FlatMapDatasetTest(test.TestCase):
def testSharedResourceNestedFlatMapDataset(self):
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
components = [np.array(repeats, dtype=np.int64)]
components = np.array(repeats, dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
.flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
.repeat(y))).make_initializable_iterator(
shared_name="shared_flat_map_iterator"))
init_op = iterator.initializer
get_next, = iterator.get_next()
get_next = iterator.get_next()
# Create two concurrent sessions that share the same iterator
# resource on the same server, and verify that a random

View File

@ -47,9 +47,9 @@ class IteratorTest(test.TestCase):
gradients_impl.gradients(value, [component, side])
def testOneShotIterator(self):
components = [np.arange(7),
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
np.array(37.0) * np.arange(7))
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
@ -71,10 +71,10 @@ class IteratorTest(test.TestCase):
sess.run(get_next)
def testOneShotIteratorCaptureByValue(self):
components = [np.arange(7),
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
tensor_components = [ops.convert_to_tensor(c) for c in components]
np.array(37.0) * np.arange(7))
tensor_components = tuple([ops.convert_to_tensor(c) for c in components])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
@ -96,9 +96,9 @@ class IteratorTest(test.TestCase):
sess.run(get_next)
def testOneShotIteratorInsideContainer(self):
components = [np.arange(7),
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
np.array(37.0) * np.arange(7))
def within_container():
def _map_fn(x, y, z):
@ -129,11 +129,11 @@ class IteratorTest(test.TestCase):
sess.run(get_next)
def testSimpleSharedResource(self):
components = [
components = (
np.array(1, dtype=np.int64),
np.array([1, 2, 3], dtype=np.int64),
np.array(37.0, dtype=np.float64)
]
)
server = server_lib.Server.create_local_server()
@ -166,8 +166,8 @@ class IteratorTest(test.TestCase):
# new graph.
iterator = dataset_ops.Iterator.from_structure(
shared_name="shared_iterator",
output_types=[dtypes.int64, dtypes.int64, dtypes.float64],
output_shapes=[[], [3], []])
output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
output_shapes=([], [3], []))
get_next = iterator.get_next()
with session.Session(server.target) as sess:
@ -179,7 +179,7 @@ class IteratorTest(test.TestCase):
sess.run(get_next)
def testNotInitializedError(self):
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
iterator = (dataset_ops.Dataset.from_tensors(components)
.make_initializable_iterator())
get_next = iterator.get_next()

View File

@ -45,9 +45,9 @@ class MapDatasetTest(test.TestCase):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count).
components = [np.arange(7),
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
np.array(37.0) * np.arange(7))
count = array_ops.placeholder(dtypes.int64, shape=[])
dataset = self._buildMapDataset(components, count)
@ -107,9 +107,9 @@ class MapDatasetTest(test.TestCase):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
# RepeatDataset(count).
components = [np.arange(7),
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
np.array(37.0) * np.arange(7))
count = array_ops.placeholder(dtypes.int64, shape=[])
num_threads = array_ops.placeholder(dtypes.int32, shape=[])
output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
@ -175,9 +175,9 @@ class MapDatasetTest(test.TestCase):
def _testDisposeParallelMapDataset(self, explicit_dispose):
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(1000).
components = [np.arange(1000),
components = (np.arange(1000),
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
np.array(37.0) * np.arange(1000)]
np.array(37.0) * np.arange(1000))
dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
iterator = dataset.make_initializable_iterator()
@ -200,7 +200,7 @@ class MapDatasetTest(test.TestCase):
self._testDisposeParallelMapDataset(False)
def testParallelMapError(self):
components = [np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)]
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.check_numerics(x, "message")))
@ -230,10 +230,7 @@ class MapDatasetTest(test.TestCase):
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
input_sentences = dataset_ops.Dataset.from_tensor_slices(
constant_op.constant([
"brain brain tank salad surgery",
"surgery brain",
]))
["brain brain tank salad surgery", "surgery brain"])
iterator = (input_sentences
.map(lambda x: string_ops.string_split([x]).values)

View File

@ -17,8 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -156,7 +154,7 @@ class RangeDatasetTest(test.TestCase):
sess.run(get_next)
def testEnumerateDataset(self):
components = [np.array(["a", "b"]), np.array([1, 2]), np.array([37.0, 38])]
components = (["a", "b"], [1, 2], [37.0, 38])
start = constant_op.constant(20, dtype=dtypes.int64)
iterator = (dataset_ops.Dataset.from_tensor_slices(components).enumerate(
@ -171,8 +169,8 @@ class RangeDatasetTest(test.TestCase):
with self.test_session() as sess:
sess.run(init_op)
self.assertEqual((20, [b"a", 1, 37.0]), sess.run(get_next))
self.assertEqual((21, [b"b", 2, 38.0]), sess.run(get_next))
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -496,5 +496,30 @@ class ReadBatchFeaturesTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
def testReadWithEquivalentDataset(self):
# TODO(mrry): Add support for tf.SparseTensor as a Dataset component.
features = {
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
}
dataset = (dataset_ops.TFRecordDataset(self.test_filenames)
.map(lambda x: parsing_ops.parse_single_example(x, features))
.repeat(10)
.batch(2))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
next_element = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for file_batch, _, _, _, record_batch in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
self.assertAllEqual(record_batch, actual_batch["record"])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
if __name__ == "__main__":
test.main()

View File

@ -30,7 +30,7 @@ class SequenceDatasetTest(test.TestCase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
# This placeholder can be fed when dataset-definition subgraph
# runs (i.e. `init_op` below) to configure the number of
# repetitions used in a particular iterator.
@ -79,7 +79,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertAllEqual(component, result_component)
def testTakeTensorDataset(self):
components = [np.arange(10)]
components = (np.arange(10),)
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
@ -125,7 +125,7 @@ class SequenceDatasetTest(test.TestCase):
sess.run(get_next)
def testSkipTensorDataset(self):
components = [np.arange(10)]
components = (np.arange(10),)
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
@ -171,7 +171,7 @@ class SequenceDatasetTest(test.TestCase):
def testRepeatRepeatTensorDataset(self):
"""Test the composition of repeat datasets."""
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
inner_count = array_ops.placeholder(dtypes.int64, shape=[])
outer_count = array_ops.placeholder(dtypes.int64, shape=[])

View File

@ -32,10 +32,10 @@ from tensorflow.python.platform import test
class ShuffleDatasetTest(test.TestCase):
def testShuffleDataset(self):
components = [
components = (
np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
np.array([9.0, 10.0, 11.0, 12.0])
]
)
count_placeholder = array_ops.placeholder_with_default(
constant_op.constant(5, dtypes.int64), shape=[])
buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
@ -47,7 +47,7 @@ class ShuffleDatasetTest(test.TestCase):
shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
seed_placeholder)
self.assertEqual([c.shape[1:] for c in components],
self.assertEqual(tuple([c.shape[1:] for c in components]),
shuffle_dataset.output_shapes)
# Create initialization ops for iterators without and with
@ -132,7 +132,7 @@ class ShuffleDatasetTest(test.TestCase):
sess.run(get_next)
def testDefaultArguments(self):
components = np.array([0, 1, 2, 3, 4])
components = [0, 1, 2, 3, 4]
iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
.repeat().make_one_shot_iterator())

View File

@ -35,10 +35,10 @@ class ZipDatasetTest(test.TestCase):
array_ops.placeholder(dtypes.float64)
]
datasets = [
datasets = tuple([
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
for component_placeholder in component_placeholders
]
])
zipped = dataset_ops.Dataset.zip(datasets)
iterator = zipped.make_initializable_iterator()

View File

@ -10,11 +10,11 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/framework:function",
"//tensorflow/contrib/data/python/util:nest",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
],
)

View File

@ -22,6 +22,7 @@ import abc
import numpy as np
from tensorflow.contrib.data.python.framework import function
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -38,7 +39,6 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import nest
class Iterator(object):
@ -1869,7 +1869,7 @@ def _parse_example(serialized, features):
result.extend([val.indices, val.values, val.dense_shape])
else:
result.append(val)
return result
return tuple(result)
def _get_file_names(file_pattern, randomize_input):

View File

@ -0,0 +1,44 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "nest",
srcs = ["nest.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
],
)
py_test(
name = "nest_test",
size = "small",
srcs = ["nest_test.py"],
srcs_version = "PY2AND3",
deps = [
":nest",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,513 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""## Functions for working with arbitrarily nested sequences of elements.
NOTE(mrry): This fork of the `tensorflow.python.util.nest` module
makes two changes:
1. It adds support for dictionaries as a level of nesting in nested structures.
2. It removes support for lists as a level of nesting in nested structures.
The motivation for this change is twofold:
1. Many input-processing functions (e.g. `tf.parse_example()`) return
dictionaries, and we would like to support them natively in datasets.
2. It seems more natural for lists to be treated (e.g. in Dataset constructors)
as tensors, rather than lists of (lists of...) tensors.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as _collections
import six as _six
from tensorflow.python.util.all_util import remove_undocumented
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
Args:
instance: an instance of `tuple`, `list`, or a `namedtuple` class.
args: elements to be converted to a sequence.
Returns:
`args` with the type of `instance`.
"""
if isinstance(instance, dict):
# This is a dict. Iterate over the keys in sorted order to make
# this deterministic.
return {k: v for k, v in zip(sorted(instance.keys()), args)}
elif (isinstance(instance, tuple) and
hasattr(instance, "_fields") and
isinstance(instance._fields, _collections.Sequence) and
all(isinstance(f, _six.string_types) for f in instance._fields)):
# This is a namedtuple
return type(instance)(*args)
else:
# Not a namedtuple
return type(instance)(args)
def _elements_of(nest):
if isinstance(nest, dict):
# Iterate over dict keys in sorted order to make this deterministic.
return [v for _, v in sorted(nest.items())]
else:
return nest
def _yield_flat_nest(nest):
for n in _elements_of(nest):
if is_sequence(n):
for ni in _yield_flat_nest(n):
yield ni
else:
yield n
def is_sequence(seq):
"""Returns a true if `seq` is a Sequence or dict (except strings/lists).
NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
which *does* treat a Python list as a sequence. For ergonomic
reasons, `tf.contrib.data` users would prefer to treat lists as
implict `tf.Tensor` objects, and dicts as (nested) sequences.
Args:
seq: an input sequence.
Returns:
True if the sequence is a not a string or list and is a
collections.Sequence.
"""
return (isinstance(seq, (_collections.Sequence, dict))
and not isinstance(seq, (list, _six.string_types)))
def flatten(nest):
"""Returns a flat sequence from a given nested structure.
If `nest` is not a sequence, this returns a single-element list: `[nest]`.
Args:
nest: an arbitrarily nested structure or a scalar object.
Note, numpy arrays are considered scalars.
Returns:
A Python list, the flattened version of the input.
"""
return list(_yield_flat_nest(nest)) if is_sequence(nest) else [nest]
def _recursive_assert_same_structure(nest1, nest2, check_types):
is_sequence_nest1 = is_sequence(nest1)
if is_sequence_nest1 != is_sequence(nest2):
raise ValueError(
"The two structures don't have the same nested structure. "
"First structure: %s, second structure: %s." % (nest1, nest2))
if is_sequence_nest1:
type_nest1 = type(nest1)
type_nest2 = type(nest2)
if check_types and type_nest1 != type_nest2:
raise TypeError(
"The two structures don't have the same sequence type. First "
"structure has type %s, while second structure has type %s."
% (type_nest1, type_nest2))
for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)):
_recursive_assert_same_structure(n1, n2, check_types)
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
check_types: if `True` (default) types of sequences are checked as
well. If set to `False`, for example a list and a tuple of objects will
look same if they have the same size.
Raises:
ValueError: If the two structures do not have the same number of elements or
if the two structures are not nested in the same way.
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
if len_nest1 != len_nest2:
raise ValueError("The two structures don't have the same number of "
"elements. First structure: %s, second structure: %s."
% (nest1, nest2))
_recursive_assert_same_structure(nest1, nest2, check_types)
def _packed_nest_with_indices(structure, flat, index):
"""Helper function for pack_nest_as.
Args:
structure: Substructure (tuple of elements and/or tuples) to mimic
flat: Flattened values to output substructure for.
index: Index at which to start reading from flat.
Returns:
The tuple (new_index, child), where:
* new_index - the updated index into `flat` having processed `structure`.
* packed - the subset of `flat` corresponding to `structure`,
having started at `index`, and packed into the same nested
format.
Raises:
ValueError: if `structure` contains more elements than `flat`
(assuming indexing starts from `index`).
"""
packed = []
for s in structure:
if is_sequence(s):
new_index, child = _packed_nest_with_indices(s, flat, index)
packed.append(_sequence_like(s, child))
index = new_index
else:
packed.append(flat[index])
index += 1
return index, packed
def pack_sequence_as(structure, flat_sequence):
"""Returns a given flattened sequence packed into a nest.
If `structure` is a scalar, `flat_sequence` must be a single-element list;
in this case the return value is `flat_sequence[0]`.
Args:
structure: tuple or list constructed of scalars and/or other tuples/lists,
or a scalar. Note: numpy arrays are considered scalars.
flat_sequence: flat sequence to pack.
Returns:
packed: `flat_sequence` converted to have the same recursive structure as
`structure`.
Raises:
ValueError: If nest and structure have different element counts.
"""
if not (is_sequence(flat_sequence) or isinstance(flat_sequence, list)):
raise TypeError("flat_sequence must be a sequence")
if not is_sequence(structure):
if len(flat_sequence) != 1:
raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1"
% len(flat_sequence))
return flat_sequence[0]
flat_structure = flatten(structure)
if len(flat_structure) != len(flat_sequence):
raise ValueError(
"Could not pack sequence. Structure had %d elements, but flat_sequence "
"had %d elements. Structure: %s, flat_sequence: %s."
% (len(flat_structure), len(flat_sequence), structure, flat_sequence))
_, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
return _sequence_like(structure, packed)
def map_structure(func, *structure, **check_types_dict):
"""Applies `func` to each entry in `structure` and returns a new structure.
Applies `func(x[0], x[1], ...)` where x[i] is an entry in
`structure[i]`. All structures in `structure` must have the same arity,
and the return value will contain the results in the same structure.
Args:
func: A callable that acceps as many arguments are there are structures.
*structure: scalar, or tuple or list of constructed scalars and/or other
tuples/lists, or scalars. Note: numpy arrays are considered scalars.
**check_types_dict: only valid keyword argument is `check_types`. If set to
`True` (default) the types of iterables within the structures have to be
same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
exception). To allow this set this argument to `False`.
Returns:
A new structure with the same arity as `structure`, whose values correspond
to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
location in `structure[i]`. If there are different sequence types and
`check_types` is `False` the sequence types of the first structure will be
used.
Raises:
TypeError: If `func` is not callable or if the structures do not match
each other by depth tree.
ValueError: If no structure is provided or if the structures do not match
each other by type.
ValueError: If wrong keyword arguments are provided.
"""
if not callable(func):
raise TypeError("func must be callable, got: %s" % func)
if not structure:
raise ValueError("Must provide at least one structure")
if check_types_dict:
if "check_types" not in check_types_dict or len(check_types_dict) > 1:
raise ValueError("Only valid keyword argument is check_types")
check_types = check_types_dict["check_types"]
else:
check_types = True
for other in structure[1:]:
assert_same_structure(structure[0], other, check_types=check_types)
flat_structure = [flatten(s) for s in structure]
entries = zip(*flat_structure)
return pack_sequence_as(
structure[0], [func(*x) for x in entries])
def _yield_flat_up_to(shallow_tree, input_tree):
"""Yields elements `input_tree` partially flattened up to `shallow_tree`."""
if is_sequence(shallow_tree):
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
yield input_leaf
else:
yield input_tree
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
"""Asserts that `shallow_tree` is a shallow structure of `input_tree`.
That is, this function tests if the `input_tree` structure can be created from
the `shallow_tree` structure by replacing its leaf nodes with deeper
tree structures.
Examples:
The following code will raise an exception:
```python
shallow_tree = ["a", "b"]
input_tree = ["c", ["d", "e"], "f"]
assert_shallow_structure(shallow_tree, input_tree)
```
The following code will not raise an exception:
```python
shallow_tree = ["a", "b"]
input_tree = ["c", ["d", "e"]]
assert_shallow_structure(shallow_tree, input_tree)
```
Args:
shallow_tree: an arbitrarily nested structure.
input_tree: an arbitrarily nested structure.
check_types: if `True` (default) the sequence types of `shallow_tree` and
`input_tree` have to be the same.
Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
TypeError: If the sequence types of `shallow_tree` are different from
`input_tree`. Only raised if `check_types` is `True`.
ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`.
"""
if is_sequence(shallow_tree):
if not is_sequence(input_tree):
raise TypeError(
"If shallow structure is a sequence, input must also be a sequence. "
"Input has type: %s." % type(input_tree))
if check_types and not isinstance(input_tree, type(shallow_tree)):
raise TypeError(
"The two structures don't have the same sequence type. Input "
"structure has type %s, while shallow structure has type %s."
% (type(input_tree), type(shallow_tree)))
if len(input_tree) != len(shallow_tree):
raise ValueError(
"The two structures don't have the same sequence length. Input "
"structure has length %s, while shallow structure has length %s."
% (len(input_tree), len(shallow_tree)))
for shallow_branch, input_branch in zip(shallow_tree, input_tree):
assert_shallow_structure(shallow_branch, input_branch,
check_types=check_types)
def flatten_up_to(shallow_tree, input_tree):
"""Flattens `input_tree` up to `shallow_tree`.
Any further depth in structure in `input_tree` is retained as elements in the
partially flatten output.
If `shallow_tree` and `input_tree` are not sequences, this returns a
single-element list: `[input_tree]`.
Use Case:
Sometimes we may wish to partially flatten a nested sequence, retaining some
of the nested structure. We achieve this by specifying a shallow structure,
`shallow_tree`, we wish to flatten up to.
The input, `input_tree`, can be thought of as having the same structure as
`shallow_tree`, but with leaf nodes that are themselves tree structures.
Examples:
```python
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)
# Output is:
# [[2, 2], [3, 3], [4, 9], [5, 5]]
# [True, True, False, True]
```
```python
input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
input_tree_flattened = flatten(input_tree)
# Output is:
# [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
# ['a', 1, 'b', 2, 'c', 3, 'd', 4]
```
Non-Sequence Edge Cases:
```python
flatten_up_to(0, 0) # Output: [0]
flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]]
flatten_up_to([0, 1, 2], 0) # Output: TypeError
flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2]
```
Args:
shallow_tree: a possibly pruned structure of input_tree.
input_tree: an arbitrarily nested structure or a scalar object.
Note, numpy arrays are considered scalars.
Returns:
A Python list, the partially flattened version of `input_tree` according to
the structure of `shallow_tree`.
Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
TypeError: If the sequence types of `shallow_tree` are different from
`input_tree`.
ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`.
"""
assert_shallow_structure(shallow_tree, input_tree)
return list(_yield_flat_up_to(shallow_tree, input_tree))
def map_structure_up_to(shallow_tree, func, *inputs):
"""Applies a function or op to a number of partially flattened inputs.
The `inputs` are flattened up to `shallow_tree` before being mapped.
Use Case:
Sometimes we wish to apply a function to a partially flattened
sequence (for example when the function itself takes sequence inputs). We
achieve this by specifying a shallow structure, `shallow_tree` we wish to
flatten up to.
The `inputs`, can be thought of as having the same structure as
`shallow_tree`, but with leaf nodes that are themselves tree structures.
This function therefore will return something with the same base structure as
`shallow_tree`.
Examples:
```python
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
inp_val, inp_ops)
# Output is: ab_tuple(a=6, b=15)
```
```python
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ['evens', ['odds', 'primes']]
out = map_structure_up_to(
name_list,
lambda name, sec: "first_{}_{}".format(len(sec), name),
name_list, data_list)
# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
```
Args:
shallow_tree: a shallow tree, common to all the inputs.
func: callable which will be applied to each input individually.
*inputs: arbitrarily nested combination of objects that are compatible with
shallow_tree. The function `func` is applied to corresponding
partially flattened elements of each input, so the function must support
arity of `len(inputs)`.
Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
TypeError: If the sequence types of `shallow_tree` are different from
`input_tree`.
ValueError: If the sequence lengths of `shallow_tree` are different from
`input_tree`.
Returns:
result of repeatedly applying `func`, with same structure as
`shallow_tree`.
"""
if not inputs:
raise ValueError("Cannot map over no sequences")
for input_tree in inputs:
assert_shallow_structure(shallow_tree, input_tree)
# Flatten each input separately, apply the function to corresponding elements,
# then repack based on the structure of the first input.
all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
for input_tree in inputs]
results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
_allowed_symbols = [
"assert_same_structure",
"is_sequence",
"flatten",
"pack_sequence_as",
"map_structure",
"assert_shallow_structure",
"flatten_up_to",
"map_structure_up_to",
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,309 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class NestTest(test.TestCase):
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
self.assertEqual(
nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
("d", "e", ("f", "g"), "h")))
point = collections.namedtuple("Point", ["x", "y"])
structure = (point(x=4, y=2), ((point(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(nest.flatten(structure), flat)
restructured_from_flat = nest.pack_sequence_as(structure, flat)
self.assertEqual(restructured_from_flat, structure)
self.assertEqual(restructured_from_flat[0].x, 4)
self.assertEqual(restructured_from_flat[0].y, 2)
self.assertEqual(restructured_from_flat[1][0][0].x, 1)
self.assertEqual(restructured_from_flat[1][0][0].y, 0)
self.assertEqual([5], nest.flatten(5))
self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
self.assertEqual(
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
nest.pack_sequence_as("scalar", [4, 5])
with self.assertRaisesRegexp(TypeError, "flat_sequence"):
nest.pack_sequence_as([4, 5], "bad_sequence")
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
self.assertFalse(nest.is_sequence([]))
self.assertFalse(nest.is_sequence(set([1, 2])))
ones = array_ops.ones([2, 3])
self.assertFalse(nest.is_sequence(ones))
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
self.assertFalse(nest.is_sequence(np.ones((4, 5))))
self.assertTrue(nest.is_sequence({"foo": 1, "bar": 2}))
def testAssertSameStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
structure_different_num_elements = ("spam", "eggs")
structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
nest.assert_same_structure(structure1, structure2)
nest.assert_same_structure("abc", 1.0)
nest.assert_same_structure("abc", np.array([0, 1]))
nest.assert_same_structure("abc", constant_op.constant([0, 1]))
with self.assertRaisesRegexp(ValueError,
"don't have the same number of elements"):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(ValueError,
"don't have the same number of elements"):
nest.assert_same_structure((0, 1), np.array([0, 1]))
with self.assertRaisesRegexp(ValueError,
"don't have the same number of elements"):
nest.assert_same_structure(0, (0, 1))
with self.assertRaisesRegexp(ValueError,
"don't have the same nested structure"):
nest.assert_same_structure(structure1, structure_different_nesting)
named_type_0 = collections.namedtuple("named_0", ("a", "b"))
named_type_1 = collections.namedtuple("named_1", ("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
named_type_0("a", "b"))
nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure,
named_type_0(3, 4), named_type_1(3, 4))
with self.assertRaisesRegexp(ValueError,
"don't have the same nested structure"):
nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))
with self.assertRaisesRegexp(ValueError,
"don't have the same nested structure"):
nest.assert_same_structure(((3,), 4), (3, (4,)))
structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
with self.assertRaisesRegexp(TypeError,
"don't have the same sequence type"):
nest.assert_same_structure(structure1, structure1_list)
nest.assert_same_structure(structure1, structure2, check_types=False)
nest.assert_same_structure(structure1, structure1_list, check_types=False)
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
nest.assert_same_structure(structure1, structure1_plus1)
self.assertAllEqual(
[2, 3, 4, 5, 6, 7],
nest.flatten(structure1_plus1))
structure1_plus_structure2 = nest.map_structure(
lambda x, y: x + y, structure1, structure2)
self.assertEqual(
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
structure1_plus_structure2)
self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
with self.assertRaisesRegexp(TypeError, "callable"):
nest.map_structure("bad", structure1_plus1)
with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, 3, (3,))
with self.assertRaisesRegexp(TypeError, "same sequence type"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})
with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
check_types=False)
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, foo="a")
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
def testAssertShallowStructure(self):
inp_ab = ("a", "b")
inp_abc = ("a", "b", "c")
expected_message = (
"The two structures don't have the same sequence length. Input "
"structure has length 2, while shallow structure has length 3.")
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_abc, inp_ab)
inp_ab1 = ((1, 1), (2, 2))
inp_ab2 = {"a": (1, 1), "b": (2, 2)}
expected_message = (
"The two structures don't have the same sequence type. Input structure "
"has type <(type|class) 'tuple'>, while shallow structure has type "
"<(type|class) 'dict'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
def testFlattenUpTo(self):
input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
shallow_tree = ((True, True), (False, True))
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
self.assertEqual(flattened_shallow_tree, [True, True, False, True])
input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4))))))
shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4")))))
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
input_tree_flattened = nest.flatten(input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
## Shallow non-list edge-case.
# Using iterable elements.
input_tree = ["input_tree"]
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
input_tree = ("input_tree_0", "input_tree_1")
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
# Using non-iterable elements.
input_tree = (0,)
shallow_tree = 9
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
input_tree = (0, 1)
shallow_tree = 9
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
## Both non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
# Using non-iterable elements.
input_tree = 0
shallow_tree = 0
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
## Input non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = ("shallow_tree",)
expected_message = ("If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <(type|class) 'str'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
input_tree = "input_tree"
shallow_tree = ("shallow_tree_9", "shallow_tree_8")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
# Using non-iterable elements.
input_tree = 0
shallow_tree = (9,)
expected_message = ("If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <(type|class) 'int'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
input_tree = 0
shallow_tree = (9, 8)
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, list(shallow_tree))
def testMapStructureUpTo(self):
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
self.assertEqual(out.a, 6)
self.assertEqual(out.b, 15)
data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
name_list = ("evens", ("odds", "primes"))
out = nest.map_structure_up_to(
name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
name_list, data_list)
self.assertEqual(out, ("first_4_evens", ("first_5_odds", "first_3_primes")))
if __name__ == "__main__":
test.main()