[tf.contrib.data] Add support for dicts and remove lists from nested structures.
This changes the behavior of constructors like `tf.contrib.data.Dataset.from_tensors()` when passed a list. Previously, the `nest` utility would recurse into each element of such a list and create a separate Dataset component. Now the list will be converted to a tensor, allowing code like:

```python
dataset = tf.contrib.data.Dataset.from_tensor_slices(([1, 2, 3], [4, 5, 6]))
```

...to define a dataset with two components (each of shape `()`).

This change also adds support for dictionaries as nested structures, which simplifies integration with dictionary-returning ops like `tf.parse_example()`.

Fixes #10151.

RELNOTES: Breaking change to `tf.contrib.data.Dataset` APIs that expect a nested structure. Lists are now converted to `tf.Tensor` implicitly. You may need to change uses of lists to tuples in existing code. In addition, dicts are now supported as a nested structure.

PiperOrigin-RevId: 158139467
This commit is contained in:
parent
b6a8848c17
commit
f3f53e8b39
@ -231,6 +231,7 @@ filegroup(
|
|||||||
"//tensorflow/contrib/data/python/framework:all_files",
|
"//tensorflow/contrib/data/python/framework:all_files",
|
||||||
"//tensorflow/contrib/data/python/kernel_tests:all_files",
|
"//tensorflow/contrib/data/python/kernel_tests:all_files",
|
||||||
"//tensorflow/contrib/data/python/ops:all_files",
|
"//tensorflow/contrib/data/python/ops:all_files",
|
||||||
|
"//tensorflow/contrib/data/python/util:all_files",
|
||||||
"//tensorflow/contrib/distributions:all_files",
|
"//tensorflow/contrib/distributions:all_files",
|
||||||
"//tensorflow/contrib/factorization:all_files",
|
"//tensorflow/contrib/factorization:all_files",
|
||||||
"//tensorflow/contrib/factorization/kernels:all_files",
|
"//tensorflow/contrib/factorization/kernels:all_files",
|
||||||
|
@ -276,6 +276,7 @@ add_python_module("tensorflow/contrib/data/python")
|
|||||||
add_python_module("tensorflow/contrib/data/python/framework")
|
add_python_module("tensorflow/contrib/data/python/framework")
|
||||||
add_python_module("tensorflow/contrib/data/python/kernel_tests")
|
add_python_module("tensorflow/contrib/data/python/kernel_tests")
|
||||||
add_python_module("tensorflow/contrib/data/python/ops")
|
add_python_module("tensorflow/contrib/data/python/ops")
|
||||||
|
add_python_module("tensorflow/contrib/data/python/util")
|
||||||
add_python_module("tensorflow/contrib/deprecated")
|
add_python_module("tensorflow/contrib/deprecated")
|
||||||
add_python_module("tensorflow/contrib/distributions")
|
add_python_module("tensorflow/contrib/distributions")
|
||||||
add_python_module("tensorflow/contrib/distributions/python")
|
add_python_module("tensorflow/contrib/distributions/python")
|
||||||
|
@ -40,9 +40,9 @@ class BatchDatasetTest(test.TestCase):
|
|||||||
"""Test an dataset that maps a TF function across its input elements."""
|
"""Test an dataset that maps a TF function across its input elements."""
|
||||||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
||||||
# RepeatDataset(count) -> BatchDataset(batch_size).
|
# RepeatDataset(count) -> BatchDataset(batch_size).
|
||||||
components = [np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(7)]
|
np.array(37.0) * np.arange(7))
|
||||||
|
|
||||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
|
@ -33,7 +33,7 @@ class DatasetConstructorTest(test.TestCase):
|
|||||||
|
|
||||||
def testTensorDataset(self):
|
def testTensorDataset(self):
|
||||||
"""Test an dataset that represents a single tuple of tensors."""
|
"""Test an dataset that represents a single tuple of tensors."""
|
||||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
||||||
|
|
||||||
iterator = (dataset_ops.Dataset.from_tensors(components)
|
iterator = (dataset_ops.Dataset.from_tensors(components)
|
||||||
.make_initializable_iterator())
|
.make_initializable_iterator())
|
||||||
@ -53,11 +53,11 @@ class DatasetConstructorTest(test.TestCase):
|
|||||||
|
|
||||||
def testTensorSliceDataset(self):
|
def testTensorSliceDataset(self):
|
||||||
"""Test an dataset that represents the slices from a tuple of tensors."""
|
"""Test an dataset that represents the slices from a tuple of tensors."""
|
||||||
components = [
|
components = (
|
||||||
np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
|
np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
|
||||||
np.array([[12], [13], [14], [15]]), 22),
|
np.array([[12], [13], [14], [15]]), 22),
|
||||||
np.array([37.0, 38.0, 39.0, 40.0])
|
np.array([37.0, 38.0, 39.0, 40.0])
|
||||||
]
|
)
|
||||||
|
|
||||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
.make_initializable_iterator())
|
.make_initializable_iterator())
|
||||||
@ -76,6 +76,27 @@ class DatasetConstructorTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
def testTensorSliceDatasetWithDict(self):
|
||||||
|
components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
|
||||||
|
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
|
.make_initializable_iterator())
|
||||||
|
init_op = iterator.initializer
|
||||||
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
|
self.assertEqual(dtypes.int32, iterator.output_types["foo"])
|
||||||
|
self.assertEqual(dtypes.float32, iterator.output_types["bar"])
|
||||||
|
self.assertEqual((), iterator.output_shapes["foo"])
|
||||||
|
self.assertEqual((1,), iterator.output_shapes["bar"])
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(init_op)
|
||||||
|
for i in range(3):
|
||||||
|
results = sess.run(get_next)
|
||||||
|
self.assertEqual(components["foo"][i], results["foo"])
|
||||||
|
self.assertEqual(components["bar"][i], results["bar"])
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
def testSparseTensorSliceDataset(self):
|
def testSparseTensorSliceDataset(self):
|
||||||
"""Test a dataset based on slices of a `tf.SparseTensor`."""
|
"""Test a dataset based on slices of a `tf.SparseTensor`."""
|
||||||
st = array_ops.sparse_placeholder(dtypes.float64)
|
st = array_ops.sparse_placeholder(dtypes.float64)
|
||||||
|
@ -30,12 +30,12 @@ from tensorflow.python.platform import test
|
|||||||
class FilterDatasetTest(test.TestCase):
|
class FilterDatasetTest(test.TestCase):
|
||||||
|
|
||||||
def testFilterDataset(self):
|
def testFilterDataset(self):
|
||||||
components = [
|
components = (
|
||||||
np.arange(7, dtype=np.int64),
|
np.arange(7, dtype=np.int64),
|
||||||
np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
|
np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
|
||||||
7, dtype=np.int64)[:, np.newaxis],
|
7, dtype=np.int64)[:, np.newaxis],
|
||||||
np.array(37.0, dtype=np.float64) * np.arange(7)
|
np.array(37.0, dtype=np.float64) * np.arange(7)
|
||||||
]
|
)
|
||||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
modulus = array_ops.placeholder(dtypes.int64)
|
modulus = array_ops.placeholder(dtypes.int64)
|
||||||
|
|
||||||
|
@ -33,13 +33,13 @@ class FlatMapDatasetTest(test.TestCase):
|
|||||||
# pylint: disable=g-long-lambda
|
# pylint: disable=g-long-lambda
|
||||||
def testFlatMapDataset(self):
|
def testFlatMapDataset(self):
|
||||||
repeats = [1, 2, 3, 4, 5, 0, 1]
|
repeats = [1, 2, 3, 4, 5, 0, 1]
|
||||||
components = [np.array(repeats, dtype=np.int64)]
|
components = np.array(repeats, dtype=np.int64)
|
||||||
iterator = (
|
iterator = (
|
||||||
dataset_ops.Dataset.from_tensor_slices(components)
|
dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
.flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
|
.flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
|
||||||
.make_initializable_iterator())
|
.make_initializable_iterator())
|
||||||
init_op = iterator.initializer
|
init_op = iterator.initializer
|
||||||
get_next, = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
@ -51,14 +51,14 @@ class FlatMapDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
def testNestedFlatMapDataset(self):
|
def testNestedFlatMapDataset(self):
|
||||||
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
||||||
components = [np.array(repeats, dtype=np.int64)]
|
components = np.array(repeats, dtype=np.int64)
|
||||||
iterator = (
|
iterator = (
|
||||||
dataset_ops.Dataset.from_tensor_slices(components)
|
dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
|
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
|
||||||
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
|
.flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
|
||||||
.repeat(y))).make_initializable_iterator())
|
.repeat(y))).make_initializable_iterator())
|
||||||
init_op = iterator.initializer
|
init_op = iterator.initializer
|
||||||
get_next, = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
@ -72,15 +72,15 @@ class FlatMapDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
def testSharedResourceNestedFlatMapDataset(self):
|
def testSharedResourceNestedFlatMapDataset(self):
|
||||||
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
||||||
components = [np.array(repeats, dtype=np.int64)]
|
components = np.array(repeats, dtype=np.int64)
|
||||||
iterator = (
|
iterator = (
|
||||||
dataset_ops.Dataset.from_tensor_slices(components)
|
dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
|
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
|
||||||
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
|
.flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
|
||||||
.repeat(y))).make_initializable_iterator(
|
.repeat(y))).make_initializable_iterator(
|
||||||
shared_name="shared_flat_map_iterator"))
|
shared_name="shared_flat_map_iterator"))
|
||||||
init_op = iterator.initializer
|
init_op = iterator.initializer
|
||||||
get_next, = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
# Create two concurrent sessions that share the same iterator
|
# Create two concurrent sessions that share the same iterator
|
||||||
# resource on the same server, and verify that a random
|
# resource on the same server, and verify that a random
|
||||||
|
@ -47,9 +47,9 @@ class IteratorTest(test.TestCase):
|
|||||||
gradients_impl.gradients(value, [component, side])
|
gradients_impl.gradients(value, [component, side])
|
||||||
|
|
||||||
def testOneShotIterator(self):
|
def testOneShotIterator(self):
|
||||||
components = [np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(7)]
|
np.array(37.0) * np.arange(7))
|
||||||
|
|
||||||
def _map_fn(x, y, z):
|
def _map_fn(x, y, z):
|
||||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||||
@ -71,10 +71,10 @@ class IteratorTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testOneShotIteratorCaptureByValue(self):
|
def testOneShotIteratorCaptureByValue(self):
|
||||||
components = [np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(7)]
|
np.array(37.0) * np.arange(7))
|
||||||
tensor_components = [ops.convert_to_tensor(c) for c in components]
|
tensor_components = tuple([ops.convert_to_tensor(c) for c in components])
|
||||||
|
|
||||||
def _map_fn(x, y, z):
|
def _map_fn(x, y, z):
|
||||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||||
@ -96,9 +96,9 @@ class IteratorTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testOneShotIteratorInsideContainer(self):
|
def testOneShotIteratorInsideContainer(self):
|
||||||
components = [np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(7)]
|
np.array(37.0) * np.arange(7))
|
||||||
|
|
||||||
def within_container():
|
def within_container():
|
||||||
def _map_fn(x, y, z):
|
def _map_fn(x, y, z):
|
||||||
@ -129,11 +129,11 @@ class IteratorTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testSimpleSharedResource(self):
|
def testSimpleSharedResource(self):
|
||||||
components = [
|
components = (
|
||||||
np.array(1, dtype=np.int64),
|
np.array(1, dtype=np.int64),
|
||||||
np.array([1, 2, 3], dtype=np.int64),
|
np.array([1, 2, 3], dtype=np.int64),
|
||||||
np.array(37.0, dtype=np.float64)
|
np.array(37.0, dtype=np.float64)
|
||||||
]
|
)
|
||||||
|
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
|
|
||||||
@ -166,8 +166,8 @@ class IteratorTest(test.TestCase):
|
|||||||
# new graph.
|
# new graph.
|
||||||
iterator = dataset_ops.Iterator.from_structure(
|
iterator = dataset_ops.Iterator.from_structure(
|
||||||
shared_name="shared_iterator",
|
shared_name="shared_iterator",
|
||||||
output_types=[dtypes.int64, dtypes.int64, dtypes.float64],
|
output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
|
||||||
output_shapes=[[], [3], []])
|
output_shapes=([], [3], []))
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
with session.Session(server.target) as sess:
|
with session.Session(server.target) as sess:
|
||||||
@ -179,7 +179,7 @@ class IteratorTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testNotInitializedError(self):
|
def testNotInitializedError(self):
|
||||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
||||||
iterator = (dataset_ops.Dataset.from_tensors(components)
|
iterator = (dataset_ops.Dataset.from_tensors(components)
|
||||||
.make_initializable_iterator())
|
.make_initializable_iterator())
|
||||||
get_next = iterator.get_next()
|
get_next = iterator.get_next()
|
||||||
|
@ -45,9 +45,9 @@ class MapDatasetTest(test.TestCase):
|
|||||||
"""Test an dataset that maps a TF function across its input elements."""
|
"""Test an dataset that maps a TF function across its input elements."""
|
||||||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
||||||
# RepeatDataset(count).
|
# RepeatDataset(count).
|
||||||
components = [np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(7)]
|
np.array(37.0) * np.arange(7))
|
||||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
|
|
||||||
dataset = self._buildMapDataset(components, count)
|
dataset = self._buildMapDataset(components, count)
|
||||||
@ -107,9 +107,9 @@ class MapDatasetTest(test.TestCase):
|
|||||||
"""Test an dataset that maps a TF function across its input elements."""
|
"""Test an dataset that maps a TF function across its input elements."""
|
||||||
# The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
|
# The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
|
||||||
# RepeatDataset(count).
|
# RepeatDataset(count).
|
||||||
components = [np.arange(7),
|
components = (np.arange(7),
|
||||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(7)]
|
np.array(37.0) * np.arange(7))
|
||||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
num_threads = array_ops.placeholder(dtypes.int32, shape=[])
|
num_threads = array_ops.placeholder(dtypes.int32, shape=[])
|
||||||
output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
|
output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
@ -175,9 +175,9 @@ class MapDatasetTest(test.TestCase):
|
|||||||
def _testDisposeParallelMapDataset(self, explicit_dispose):
|
def _testDisposeParallelMapDataset(self, explicit_dispose):
|
||||||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
||||||
# RepeatDataset(1000).
|
# RepeatDataset(1000).
|
||||||
components = [np.arange(1000),
|
components = (np.arange(1000),
|
||||||
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
|
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
|
||||||
np.array(37.0) * np.arange(1000)]
|
np.array(37.0) * np.arange(1000))
|
||||||
|
|
||||||
dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
|
dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
|
||||||
iterator = dataset.make_initializable_iterator()
|
iterator = dataset.make_initializable_iterator()
|
||||||
@ -200,7 +200,7 @@ class MapDatasetTest(test.TestCase):
|
|||||||
self._testDisposeParallelMapDataset(False)
|
self._testDisposeParallelMapDataset(False)
|
||||||
|
|
||||||
def testParallelMapError(self):
|
def testParallelMapError(self):
|
||||||
components = [np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)]
|
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||||
|
|
||||||
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
|
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
.map(lambda x: array_ops.check_numerics(x, "message")))
|
.map(lambda x: array_ops.check_numerics(x, "message")))
|
||||||
@ -230,10 +230,7 @@ class MapDatasetTest(test.TestCase):
|
|||||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||||
|
|
||||||
input_sentences = dataset_ops.Dataset.from_tensor_slices(
|
input_sentences = dataset_ops.Dataset.from_tensor_slices(
|
||||||
constant_op.constant([
|
["brain brain tank salad surgery", "surgery brain"])
|
||||||
"brain brain tank salad surgery",
|
|
||||||
"surgery brain",
|
|
||||||
]))
|
|
||||||
|
|
||||||
iterator = (input_sentences
|
iterator = (input_sentences
|
||||||
.map(lambda x: string_ops.string_split([x]).values)
|
.map(lambda x: string_ops.string_split([x]).values)
|
||||||
|
@ -17,8 +17,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -156,7 +154,7 @@ class RangeDatasetTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testEnumerateDataset(self):
|
def testEnumerateDataset(self):
|
||||||
components = [np.array(["a", "b"]), np.array([1, 2]), np.array([37.0, 38])]
|
components = (["a", "b"], [1, 2], [37.0, 38])
|
||||||
start = constant_op.constant(20, dtype=dtypes.int64)
|
start = constant_op.constant(20, dtype=dtypes.int64)
|
||||||
|
|
||||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components).enumerate(
|
iterator = (dataset_ops.Dataset.from_tensor_slices(components).enumerate(
|
||||||
@ -171,8 +169,8 @@ class RangeDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
self.assertEqual((20, [b"a", 1, 37.0]), sess.run(get_next))
|
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
|
||||||
self.assertEqual((21, [b"b", 2, 38.0]), sess.run(get_next))
|
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
|
||||||
|
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
@ -496,5 +496,30 @@ class ReadBatchFeaturesTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
self._next_actual_batch(sess)
|
self._next_actual_batch(sess)
|
||||||
|
|
||||||
|
def testReadWithEquivalentDataset(self):
|
||||||
|
# TODO(mrry): Add support for tf.SparseTensor as a Dataset component.
|
||||||
|
features = {
|
||||||
|
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
|
||||||
|
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
|
||||||
|
}
|
||||||
|
dataset = (dataset_ops.TFRecordDataset(self.test_filenames)
|
||||||
|
.map(lambda x: parsing_ops.parse_single_example(x, features))
|
||||||
|
.repeat(10)
|
||||||
|
.batch(2))
|
||||||
|
iterator = dataset.make_initializable_iterator()
|
||||||
|
init_op = iterator.initializer
|
||||||
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(init_op)
|
||||||
|
for file_batch, _, _, _, record_batch in self._next_expected_batch(
|
||||||
|
range(self._num_files), 2, 10):
|
||||||
|
actual_batch = sess.run(next_element)
|
||||||
|
self.assertAllEqual(file_batch, actual_batch["file"])
|
||||||
|
self.assertAllEqual(record_batch, actual_batch["record"])
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(next_element)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -30,7 +30,7 @@ class SequenceDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
def testRepeatTensorDataset(self):
|
def testRepeatTensorDataset(self):
|
||||||
"""Test a dataset that repeats its input multiple times."""
|
"""Test a dataset that repeats its input multiple times."""
|
||||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
||||||
# This placeholder can be fed when dataset-definition subgraph
|
# This placeholder can be fed when dataset-definition subgraph
|
||||||
# runs (i.e. `init_op` below) to configure the number of
|
# runs (i.e. `init_op` below) to configure the number of
|
||||||
# repetitions used in a particular iterator.
|
# repetitions used in a particular iterator.
|
||||||
@ -79,7 +79,7 @@ class SequenceDatasetTest(test.TestCase):
|
|||||||
self.assertAllEqual(component, result_component)
|
self.assertAllEqual(component, result_component)
|
||||||
|
|
||||||
def testTakeTensorDataset(self):
|
def testTakeTensorDataset(self):
|
||||||
components = [np.arange(10)]
|
components = (np.arange(10),)
|
||||||
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
|
|
||||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
@ -125,7 +125,7 @@ class SequenceDatasetTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testSkipTensorDataset(self):
|
def testSkipTensorDataset(self):
|
||||||
components = [np.arange(10)]
|
components = (np.arange(10),)
|
||||||
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
|
|
||||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||||
@ -171,7 +171,7 @@ class SequenceDatasetTest(test.TestCase):
|
|||||||
|
|
||||||
def testRepeatRepeatTensorDataset(self):
|
def testRepeatRepeatTensorDataset(self):
|
||||||
"""Test the composition of repeat datasets."""
|
"""Test the composition of repeat datasets."""
|
||||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
|
||||||
inner_count = array_ops.placeholder(dtypes.int64, shape=[])
|
inner_count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
outer_count = array_ops.placeholder(dtypes.int64, shape=[])
|
outer_count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
|
|
||||||
|
@ -32,10 +32,10 @@ from tensorflow.python.platform import test
|
|||||||
class ShuffleDatasetTest(test.TestCase):
|
class ShuffleDatasetTest(test.TestCase):
|
||||||
|
|
||||||
def testShuffleDataset(self):
|
def testShuffleDataset(self):
|
||||||
components = [
|
components = (
|
||||||
np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
|
np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
|
||||||
np.array([9.0, 10.0, 11.0, 12.0])
|
np.array([9.0, 10.0, 11.0, 12.0])
|
||||||
]
|
)
|
||||||
count_placeholder = array_ops.placeholder_with_default(
|
count_placeholder = array_ops.placeholder_with_default(
|
||||||
constant_op.constant(5, dtypes.int64), shape=[])
|
constant_op.constant(5, dtypes.int64), shape=[])
|
||||||
buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||||
@ -47,7 +47,7 @@ class ShuffleDatasetTest(test.TestCase):
|
|||||||
shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
|
shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
|
||||||
seed_placeholder)
|
seed_placeholder)
|
||||||
|
|
||||||
self.assertEqual([c.shape[1:] for c in components],
|
self.assertEqual(tuple([c.shape[1:] for c in components]),
|
||||||
shuffle_dataset.output_shapes)
|
shuffle_dataset.output_shapes)
|
||||||
|
|
||||||
# Create initialization ops for iterators without and with
|
# Create initialization ops for iterators without and with
|
||||||
@ -132,7 +132,7 @@ class ShuffleDatasetTest(test.TestCase):
|
|||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
def testDefaultArguments(self):
|
def testDefaultArguments(self):
|
||||||
components = np.array([0, 1, 2, 3, 4])
|
components = [0, 1, 2, 3, 4]
|
||||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
|
iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
|
||||||
.repeat().make_one_shot_iterator())
|
.repeat().make_one_shot_iterator())
|
||||||
|
|
||||||
|
@ -35,10 +35,10 @@ class ZipDatasetTest(test.TestCase):
|
|||||||
array_ops.placeholder(dtypes.float64)
|
array_ops.placeholder(dtypes.float64)
|
||||||
]
|
]
|
||||||
|
|
||||||
datasets = [
|
datasets = tuple([
|
||||||
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
|
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
|
||||||
for component_placeholder in component_placeholders
|
for component_placeholder in component_placeholders
|
||||||
]
|
])
|
||||||
zipped = dataset_ops.Dataset.zip(datasets)
|
zipped = dataset_ops.Dataset.zip(datasets)
|
||||||
|
|
||||||
iterator = zipped.make_initializable_iterator()
|
iterator = zipped.make_initializable_iterator()
|
||||||
|
@ -10,11 +10,11 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/contrib/data/python/framework:function",
|
"//tensorflow/contrib/data/python/framework:function",
|
||||||
|
"//tensorflow/contrib/data/python/util:nest",
|
||||||
"//tensorflow/contrib/util:util_py",
|
"//tensorflow/contrib/util:util_py",
|
||||||
"//tensorflow/python:dataset_ops_gen",
|
"//tensorflow/python:dataset_ops_gen",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:parsing_ops",
|
"//tensorflow/python:parsing_ops",
|
||||||
"//tensorflow/python:util",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ import abc
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.contrib.data.python.framework import function
|
from tensorflow.contrib.data.python.framework import function
|
||||||
|
from tensorflow.contrib.data.python.util import nest
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -38,7 +39,6 @@ from tensorflow.python.ops import parsing_ops
|
|||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.util import nest
|
|
||||||
|
|
||||||
|
|
||||||
class Iterator(object):
|
class Iterator(object):
|
||||||
@ -1869,7 +1869,7 @@ def _parse_example(serialized, features):
|
|||||||
result.extend([val.indices, val.values, val.dense_shape])
|
result.extend([val.indices, val.values, val.dense_shape])
|
||||||
else:
|
else:
|
||||||
result.append(val)
|
result.append(val)
|
||||||
return result
|
return tuple(result)
|
||||||
|
|
||||||
|
|
||||||
def _get_file_names(file_pattern, randomize_input):
|
def _get_file_names(file_pattern, randomize_input):
|
||||||
|
44
tensorflow/contrib/data/python/util/BUILD
Normal file
44
tensorflow/contrib/data/python/util/BUILD
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "nest",
|
||||||
|
srcs = ["nest.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "nest_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["nest_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":nest",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
513
tensorflow/contrib/data/python/util/nest.py
Normal file
513
tensorflow/contrib/data/python/util/nest.py
Normal file
@ -0,0 +1,513 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""## Functions for working with arbitrarily nested sequences of elements.
|
||||||
|
|
||||||
|
NOTE(mrry): This fork of the `tensorflow.python.util.nest` module
|
||||||
|
makes two changes:
|
||||||
|
|
||||||
|
1. It adds support for dictionaries as a level of nesting in nested structures.
|
||||||
|
2. It removes support for lists as a level of nesting in nested structures.
|
||||||
|
|
||||||
|
The motivation for this change is twofold:
|
||||||
|
|
||||||
|
1. Many input-processing functions (e.g. `tf.parse_example()`) return
|
||||||
|
dictionaries, and we would like to support them natively in datasets.
|
||||||
|
2. It seems more natural for lists to be treated (e.g. in Dataset constructors)
|
||||||
|
as tensors, rather than lists of (lists of...) tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as _collections

import six as _six

from tensorflow.python.util.all_util import remove_undocumented

# The container ABCs live in `collections.abc` on Python 3 (the aliases in
# `collections` were removed in Python 3.10) and in `collections` on Python 2.
try:
  from collections import abc as _collections_abc
except ImportError:  # Python 2
  _collections_abc = _collections
|
||||||
|
|
||||||
|
|
||||||
|
def _sequence_like(instance, args):
|
||||||
|
"""Converts the sequence `args` to the same type as `instance`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: an instance of `tuple`, `list`, or a `namedtuple` class.
|
||||||
|
args: elements to be converted to a sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`args` with the type of `instance`.
|
||||||
|
"""
|
||||||
|
if isinstance(instance, dict):
|
||||||
|
# This is a dict. Iterate over the keys in sorted order to make
|
||||||
|
# this deterministic.
|
||||||
|
return {k: v for k, v in zip(sorted(instance.keys()), args)}
|
||||||
|
elif (isinstance(instance, tuple) and
|
||||||
|
hasattr(instance, "_fields") and
|
||||||
|
isinstance(instance._fields, _collections.Sequence) and
|
||||||
|
all(isinstance(f, _six.string_types) for f in instance._fields)):
|
||||||
|
# This is a namedtuple
|
||||||
|
return type(instance)(*args)
|
||||||
|
else:
|
||||||
|
# Not a namedtuple
|
||||||
|
return type(instance)(args)
|
||||||
|
|
||||||
|
|
||||||
|
def _elements_of(nest):
|
||||||
|
if isinstance(nest, dict):
|
||||||
|
# Iterate over dict keys in sorted order to make this deterministic.
|
||||||
|
return [v for _, v in sorted(nest.items())]
|
||||||
|
else:
|
||||||
|
return nest
|
||||||
|
|
||||||
|
|
||||||
|
def _yield_flat_nest(nest):
  """Yields the leaves of `nest` depth-first, in `_elements_of()` order."""
  # Stack of partially consumed iterators, one per nesting level currently
  # being walked; the innermost level is on top.
  pending = [iter(_elements_of(nest))]
  while pending:
    for element in pending[-1]:
      if is_sequence(element):
        # Descend into the nested structure; the parent iterator stays on the
        # stack and resumes where it left off once the child is exhausted.
        pending.append(iter(_elements_of(element)))
        break
      yield element
    else:
      # The innermost level is exhausted.
      pending.pop()
|
||||||
|
|
||||||
|
|
||||||
|
def is_sequence(seq):
  """Returns `True` if `seq` is a Sequence or dict (except strings/lists).

  NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
  which *does* treat a Python list as a sequence. For ergonomic
  reasons, `tf.contrib.data` users would prefer to treat lists as
  implicit `tf.Tensor` objects, and dicts as (nested) sequences.

  Args:
    seq: the object to test.

  Returns:
    `True` if `seq` is a `dict`, or is a `collections` Sequence that is
    neither a `list` nor a string; `False` otherwise.
  """
  if isinstance(seq, list):
    # Lists are deliberately treated as leaves in this fork (see NOTE above).
    return False
  if isinstance(seq, dict):
    return True
  # `collections.Sequence` no longer exists on Python >= 3.10, so the ABC is
  # reached through the version-independent alias.
  if not isinstance(seq, _collections_abc.Sequence):
    return False
  return not isinstance(seq, _six.string_types)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(nest):
  """Returns a flat list of the leaves of a given nested structure.

  If `nest` is not a sequence (as defined by `is_sequence()`), this returns a
  single-element list: `[nest]`.

  Args:
    nest: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays and Python lists are considered scalars.

  Returns:
    A Python list, the flattened version of the input.
  """
  if not is_sequence(nest):
    return [nest]
  return list(_yield_flat_nest(nest))
|
||||||
|
|
||||||
|
|
||||||
|
def _recursive_assert_same_structure(nest1, nest2, check_types):
  """Recursively checks that `nest1` and `nest2` are nested the same way.

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if true, corresponding nested levels must have the same type.

  Raises:
    ValueError: if one argument is a nested structure and the other is not.
    TypeError: if `check_types` is true and two corresponding levels have
      different types.
  """
  nest1_is_nested = is_sequence(nest1)
  if is_sequence(nest2) != nest1_is_nested:
    raise ValueError(
        "The two structures don't have the same nested structure. "
        "First structure: %s, second structure: %s." % (nest1, nest2))

  if not nest1_is_nested:
    # Two leaves: nothing further to compare.
    return

  if check_types:
    first_type, second_type = type(nest1), type(nest2)
    if first_type != second_type:
      raise TypeError(
          "The two structures don't have the same sequence type. First "
          "structure has type %s, while second structure has type %s."
          % (first_type, second_type))

  for left, right in zip(_elements_of(nest1), _elements_of(nest2)):
    _recursive_assert_same_structure(left, right, check_types)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_same_structure(nest1, nest2, check_types=True):
  """Asserts that two structures are nested in the same way.

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if `True` (default) the types of the nested levels are
      checked as well. If set to `False`, for example a dict and a tuple of
      objects will look the same if they have the same size.

  Raises:
    ValueError: If the two structures do not have the same number of elements
      or if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  # `flatten()` returns a single-element list for a non-sequence, so a scalar
  # counts as one element.
  num_leaves1 = len(flatten(nest1))
  num_leaves2 = len(flatten(nest2))
  if num_leaves1 != num_leaves2:
    raise ValueError("The two structures don't have the same number of "
                     "elements. First structure: %s, second structure: %s."
                     % (nest1, nest2))
  _recursive_assert_same_structure(nest1, nest2, check_types)
|
||||||
|
|
||||||
|
|
||||||
|
def _packed_nest_with_indices(structure, flat, index):
  """Helper function for `pack_sequence_as`.

  Args:
    structure: Substructure (tuple or dict of elements and/or nested
      tuples/dicts) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from `flat`.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed
                    `structure`.
      * packed - a list with the subset of `flat` corresponding to
                 `structure`, having started at `index`, with every nested
                 level already packed into the same nested format.

  Raises:
    IndexError: if `flat` is a list or tuple and `structure` contains more
      elements than `flat` (assuming indexing starts from `index`).
  """
  packed = []
  # Walk the *values* of `structure`. Iterating a dict directly would visit
  # its keys, so nested dict values would never be recursed into;
  # `_elements_of()` yields dict values in sorted-key order, which is the
  # same order `flatten()` uses.
  for s in _elements_of(structure):
    if is_sequence(s):
      index, child = _packed_nest_with_indices(s, flat, index)
      packed.append(_sequence_like(s, child))
    else:
      packed.append(flat[index])
      index += 1
  return index, packed
|
||||||
|
|
||||||
|
|
||||||
|
def pack_sequence_as(structure, flat_sequence):
  """Returns a given flattened sequence packed into a nest.

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  Args:
    structure: tuple or dict constructed of scalars and/or other tuples/dicts,
      or a scalar. Note: numpy arrays and Python lists are considered
      scalars.
    flat_sequence: flat sequence to pack. A Python list is accepted here.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    TypeError: If `flat_sequence` is neither a list nor a nested structure.
    ValueError: If `flat_sequence` and `structure` have different element
      counts.
  """
  flat_is_valid = isinstance(flat_sequence, list) or is_sequence(flat_sequence)
  if not flat_is_valid:
    raise TypeError("flat_sequence must be a sequence")

  num_flat = len(flat_sequence)

  if not is_sequence(structure):
    # Scalar structure: the single flat element is the result.
    if num_flat != 1:
      raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1"
                       % num_flat)
    return flat_sequence[0]

  num_leaves = len(flatten(structure))
  if num_leaves != num_flat:
    raise ValueError(
        "Could not pack sequence. Structure had %d elements, but flat_sequence "
        "had %d elements. Structure: %s, flat_sequence: %s."
        % (num_leaves, num_flat, structure, flat_sequence))

  _, packed_children = _packed_nest_with_indices(structure, flat_sequence, 0)
  return _sequence_like(structure, packed_children)
|
||||||
|
|
||||||
|
|
||||||
|
def map_structure(func, *structure, **check_types_dict):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`. All structures in `structure` must have the same arity,
  and the return value will contain the results in the same structure.

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: scalar, or tuple or dict constructed of scalars and/or other
      tuples/dicts. Note: numpy arrays and Python lists are considered
      scalars.
    **check_types_dict: only valid keyword argument is `check_types`. If set to
      `True` (default) the types of the nested levels within the structures
      have to be the same (e.g. `map_structure(func, {"a": 1}, (1,))` raises
      a `TypeError` exception). To allow this set this argument to `False`.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`. If there are different sequence types and
    `check_types` is `False` the sequence types of the first structure will be
    used.

  Raises:
    TypeError: If `func` is not callable, or if `check_types` is `True` and
      the structures differ in the type of a nested level.
    ValueError: If no structure is provided, or if the structures do not have
      the same number of elements or the same nesting.
    ValueError: If wrong keyword arguments are provided.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  if not structure:
    raise ValueError("Must provide at least one structure")

  check_types = True
  if check_types_dict:
    # `check_types` must be the one and only keyword argument.
    if set(check_types_dict) != {"check_types"}:
      raise ValueError("Only valid keyword argument is check_types")
    check_types = check_types_dict["check_types"]

  reference = structure[0]
  for other in structure[1:]:
    assert_same_structure(reference, other, check_types=check_types)

  # Group the i-th leaf of every structure together, apply `func` to each
  # group, then rebuild the shape of the first structure.
  grouped_leaves = zip(*[flatten(s) for s in structure])
  mapped = [func(*leaves) for leaves in grouped_leaves]
  return pack_sequence_as(reference, mapped)
|
||||||
|
|
||||||
|
|
||||||
|
def _yield_flat_up_to(shallow_tree, input_tree):
  """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
  if is_sequence(shallow_tree):
    # Pair up the *values* of both trees. Zipping two dicts directly would
    # pair (and yield) their keys; `_elements_of()` returns dict values in
    # sorted-key order and leaves other sequences untouched.
    for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
                                            _elements_of(input_tree)):
      for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
        yield input_leaf
  else:
    # `shallow_tree` is a leaf: whatever sits at this position in the input
    # is emitted as-is, including any deeper structure.
    yield input_tree
|
||||||
|
|
||||||
|
|
||||||
|
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Note that Python lists are leaves in this module, so the examples use
  tuples.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = ("a", "b")
    input_tree = ("c", ("d", "e"), "f")
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will not raise an exception:
  ```python
    shallow_tree = ("a", "b")
    input_tree = ("c", ("d", "e"))
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  if is_sequence(shallow_tree):
    if not is_sequence(input_tree):
      raise TypeError(
          "If shallow structure is a sequence, input must also be a sequence. "
          "Input has type: %s." % type(input_tree))

    if check_types and not isinstance(input_tree, type(shallow_tree)):
      raise TypeError(
          "The two structures don't have the same sequence type. Input "
          "structure has type %s, while shallow structure has type %s."
          % (type(input_tree), type(shallow_tree)))

    if len(input_tree) != len(shallow_tree):
      raise ValueError(
          "The two structures don't have the same sequence length. Input "
          "structure has length %s, while shallow structure has length %s."
          % (len(input_tree), len(shallow_tree)))

    # Recurse into the *values* of both trees. Zipping two dicts directly
    # would compare their keys and never validate the nested values;
    # `_elements_of()` returns dict values in sorted-key order. Dict keys
    # themselves are not compared here.
    for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
                                            _elements_of(input_tree)):
      assert_shallow_structure(shallow_branch, input_branch,
                               check_types=check_types)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_up_to(shallow_tree, input_tree):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Note that Python lists are leaves in this module, so the examples use
  tuples.

  Examples:

  ```python
  input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
  shallow_tree = ((True, True), (False, True))

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [(2, 2), (3, 3), (4, 9), (5, 5)]
  # [True, True, False, True]
  ```

  ```python
  input_tree = (("a", 1), (("b", 2), (("c", 3), ("d", 4))))
  shallow_tree = ("level_1", ("level_2", ("level_3", "level_4")))

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [("a", 1), ("b", 2), ("c", 3), ("d", 4)]
  # ["a", 1, "b", 2, "c", 3, "d", 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, (0, 1, 2))  # Output: [(0, 1, 2)]
  flatten_up_to((0, 1, 2), 0)  # Output: TypeError
  flatten_up_to((0, 1, 2), (0, 1, 2))  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  # Validate first so a mismatch raises instead of silently truncating.
  assert_shallow_structure(shallow_tree, input_tree)
  partially_flat = _yield_flat_up_to(shallow_tree, input_tree)
  return list(partially_flat)
|
||||||
|
|
||||||
|
|
||||||
|
def map_structure_up_to(shallow_tree, func, *inputs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree`, we wish to
  flatten up to.

  The `inputs` can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  Note that Python lists are leaves in this module, so nesting in the
  examples is expressed with (named) tuples.

  Examples:

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data = ([2, 4, 6, 8], ([1, 3, 5, 7, 9], [3, 5, 7]))
  names = ("evens", ("odds", "primes"))
  out = map_structure_up_to(
      names,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      names, data)

  # Output is: ("first_4_evens", ("first_5_odds", "first_3_primes"))
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
      shallow_tree. The function `func` is applied to corresponding
      partially flattened elements of each input, so the function must support
      arity of `len(inputs)`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but an input is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of an input.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of an input, or if no inputs are given.

  Returns:
    result of repeatedly applying `func`, with same structure as
    `shallow_tree`.
  """
  if not inputs:
    raise ValueError("Cannot map over no sequences")

  # Validate every input before doing any work on any of them.
  for tree in inputs:
    assert_shallow_structure(shallow_tree, tree)

  # Flatten each input separately, apply the function to corresponding
  # elements, then repack based on the structure of `shallow_tree`.
  flattened_inputs = [flatten_up_to(shallow_tree, tree) for tree in inputs]
  mapped = [func(*leaves) for leaves in zip(*flattened_inputs)]
  return pack_sequence_as(structure=shallow_tree, flat_sequence=mapped)
|
||||||
|
|
||||||
|
|
||||||
|
# Names this module exposes; the list is handed to `remove_undocumented()`
# below.
_allowed_symbols = [
    "assert_same_structure",
    "is_sequence",
    "flatten",
    "pack_sequence_as",
    "map_structure",
    "assert_shallow_structure",
    "flatten_up_to",
    "map_structure_up_to",
]

remove_undocumented(__name__, _allowed_symbols)
|
309
tensorflow/contrib/data/python/util/nest_test.py
Normal file
309
tensorflow/contrib/data/python/util/nest_test.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for utilities working with arbitrarily nested structures."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.data.python.util import nest
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class NestTest(test.TestCase):
|
||||||
|
|
||||||
|
def testFlattenAndPack(self):
|
||||||
|
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
|
||||||
|
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
|
||||||
|
self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
|
||||||
|
self.assertEqual(
|
||||||
|
nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
|
||||||
|
("d", "e", ("f", "g"), "h")))
|
||||||
|
point = collections.namedtuple("Point", ["x", "y"])
|
||||||
|
structure = (point(x=4, y=2), ((point(x=1, y=0),),))
|
||||||
|
flat = [4, 2, 1, 0]
|
||||||
|
self.assertEqual(nest.flatten(structure), flat)
|
||||||
|
restructured_from_flat = nest.pack_sequence_as(structure, flat)
|
||||||
|
self.assertEqual(restructured_from_flat, structure)
|
||||||
|
self.assertEqual(restructured_from_flat[0].x, 4)
|
||||||
|
self.assertEqual(restructured_from_flat[0].y, 2)
|
||||||
|
self.assertEqual(restructured_from_flat[1][0][0].x, 1)
|
||||||
|
self.assertEqual(restructured_from_flat[1][0][0].y, 0)
|
||||||
|
|
||||||
|
self.assertEqual([5], nest.flatten(5))
|
||||||
|
self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
|
||||||
|
|
||||||
|
self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
|
||||||
|
self.assertEqual(
|
||||||
|
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
|
||||||
|
nest.pack_sequence_as("scalar", [4, 5])
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, "flat_sequence"):
|
||||||
|
nest.pack_sequence_as([4, 5], "bad_sequence")
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
|
||||||
|
|
||||||
|
def testIsSequence(self):
|
||||||
|
self.assertFalse(nest.is_sequence("1234"))
|
||||||
|
self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
|
||||||
|
self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
|
||||||
|
self.assertFalse(nest.is_sequence([]))
|
||||||
|
self.assertFalse(nest.is_sequence(set([1, 2])))
|
||||||
|
ones = array_ops.ones([2, 3])
|
||||||
|
self.assertFalse(nest.is_sequence(ones))
|
||||||
|
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
|
||||||
|
self.assertFalse(nest.is_sequence(np.ones((4, 5))))
|
||||||
|
self.assertTrue(nest.is_sequence({"foo": 1, "bar": 2}))
|
||||||
|
|
||||||
|
def testAssertSameStructure(self):
|
||||||
|
structure1 = (((1, 2), 3), 4, (5, 6))
|
||||||
|
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
|
||||||
|
structure_different_num_elements = ("spam", "eggs")
|
||||||
|
structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
|
||||||
|
nest.assert_same_structure(structure1, structure2)
|
||||||
|
nest.assert_same_structure("abc", 1.0)
|
||||||
|
nest.assert_same_structure("abc", np.array([0, 1]))
|
||||||
|
nest.assert_same_structure("abc", constant_op.constant([0, 1]))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"don't have the same number of elements"):
|
||||||
|
nest.assert_same_structure(structure1, structure_different_num_elements)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"don't have the same number of elements"):
|
||||||
|
nest.assert_same_structure((0, 1), np.array([0, 1]))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"don't have the same number of elements"):
|
||||||
|
nest.assert_same_structure(0, (0, 1))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"don't have the same nested structure"):
|
||||||
|
nest.assert_same_structure(structure1, structure_different_nesting)
|
||||||
|
|
||||||
|
named_type_0 = collections.namedtuple("named_0", ("a", "b"))
|
||||||
|
named_type_1 = collections.namedtuple("named_1", ("a", "b"))
|
||||||
|
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
|
||||||
|
named_type_0("a", "b"))
|
||||||
|
|
||||||
|
nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
|
||||||
|
|
||||||
|
self.assertRaises(TypeError, nest.assert_same_structure,
|
||||||
|
named_type_0(3, 4), named_type_1(3, 4))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"don't have the same nested structure"):
|
||||||
|
nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
"don't have the same nested structure"):
|
||||||
|
nest.assert_same_structure(((3,), 4), (3, (4,)))
|
||||||
|
|
||||||
|
structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
|
||||||
|
with self.assertRaisesRegexp(TypeError,
|
||||||
|
"don't have the same sequence type"):
|
||||||
|
nest.assert_same_structure(structure1, structure1_list)
|
||||||
|
nest.assert_same_structure(structure1, structure2, check_types=False)
|
||||||
|
nest.assert_same_structure(structure1, structure1_list, check_types=False)
|
||||||
|
|
||||||
|
def testMapStructure(self):
|
||||||
|
structure1 = (((1, 2), 3), 4, (5, 6))
|
||||||
|
structure2 = (((7, 8), 9), 10, (11, 12))
|
||||||
|
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
|
||||||
|
nest.assert_same_structure(structure1, structure1_plus1)
|
||||||
|
self.assertAllEqual(
|
||||||
|
[2, 3, 4, 5, 6, 7],
|
||||||
|
nest.flatten(structure1_plus1))
|
||||||
|
structure1_plus_structure2 = nest.map_structure(
|
||||||
|
lambda x, y: x + y, structure1, structure2)
|
||||||
|
self.assertEqual(
|
||||||
|
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
|
||||||
|
structure1_plus_structure2)
|
||||||
|
|
||||||
|
self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
|
||||||
|
|
||||||
|
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, "callable"):
|
||||||
|
nest.map_structure("bad", structure1_plus1)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||||
|
nest.map_structure(lambda x, y: None, 3, (3,))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, "same sequence type"):
|
||||||
|
nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||||
|
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "same nested structure"):
|
||||||
|
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
|
||||||
|
check_types=False)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
|
||||||
|
nest.map_structure(lambda x: None, structure1, foo="a")
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
|
||||||
|
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
|
||||||
|
|
||||||
|
def testAssertShallowStructure(self):
|
||||||
|
inp_ab = ("a", "b")
|
||||||
|
inp_abc = ("a", "b", "c")
|
||||||
|
expected_message = (
|
||||||
|
"The two structures don't have the same sequence length. Input "
|
||||||
|
"structure has length 2, while shallow structure has length 3.")
|
||||||
|
with self.assertRaisesRegexp(ValueError, expected_message):
|
||||||
|
nest.assert_shallow_structure(inp_abc, inp_ab)
|
||||||
|
|
||||||
|
inp_ab1 = ((1, 1), (2, 2))
|
||||||
|
inp_ab2 = {"a": (1, 1), "b": (2, 2)}
|
||||||
|
expected_message = (
|
||||||
|
"The two structures don't have the same sequence type. Input structure "
|
||||||
|
"has type <(type|class) 'tuple'>, while shallow structure has type "
|
||||||
|
"<(type|class) 'dict'>.")
|
||||||
|
with self.assertRaisesRegexp(TypeError, expected_message):
|
||||||
|
nest.assert_shallow_structure(inp_ab2, inp_ab1)
|
||||||
|
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
|
||||||
|
|
||||||
|
def testFlattenUpTo(self):
  # Exercises `nest.flatten_up_to` on nested tuples, then on the edge cases
  # where the shallow tree, the input tree, or both are not sequences.

  # Flattening follows the shallow tree only, so the innermost pairs of the
  # input come back whole as leaves.
  deep = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
  mask = ((True, True), (False, True))
  deep_leaves = nest.flatten_up_to(mask, deep)
  mask_leaves = nest.flatten_up_to(mask, mask)
  self.assertEqual(deep_leaves, [(2, 2), (3, 3), (4, 9), (5, 5)])
  self.assertEqual(mask_leaves, [True, True, False, True])

  # Right-nested pairs: a partial flatten keeps each pair intact, while a
  # full `nest.flatten` reduces everything to scalars.
  pairs = (("a", 1), (("b", 2), (("c", 3), ("d", 4))))
  levels = ("level_1", ("level_2", ("level_3", "level_4")))
  partial = nest.flatten_up_to(levels, pairs)
  complete = nest.flatten(pairs)
  self.assertEqual(partial, [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(complete, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Shallow tree is not a sequence (first four cases) or neither tree is
  # (last two cases), with both iterable (str) and non-iterable (int)
  # elements: each tree is returned whole inside a one-element list.
  whole_tree_cases = (
      (["input_tree"], "shallow_tree"),
      (("input_tree_0", "input_tree_1"), "shallow_tree"),
      ((0,), 9),
      ((0, 1), 9),
      ("input_tree", "shallow_tree"),
      (0, 0),
  )
  for input_tree, shallow_tree in whole_tree_cases:
    from_input = nest.flatten_up_to(shallow_tree, input_tree)
    from_shallow = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(from_input, [input_tree])
    self.assertEqual(from_shallow, [shallow_tree])

  # Input is not a sequence while the shallow tree is: a TypeError naming the
  # input's type is expected, and flattening the shallow tree against itself
  # still yields its elements.
  str_input_error = ("If shallow structure is a sequence, input must also "
                     "be a sequence. Input has type: <(type|class) 'str'>.")
  int_input_error = ("If shallow structure is a sequence, input must also "
                     "be a sequence. Input has type: <(type|class) 'int'>.")
  rejected_cases = (
      ("input_tree", ("shallow_tree",), str_input_error),
      ("input_tree", ("shallow_tree_9", "shallow_tree_8"), str_input_error),
      (0, (9,), int_input_error),
      (0, (9, 8), int_input_error),
  )
  for input_tree, shallow_tree, expected_message in rejected_cases:
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    from_shallow = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(from_shallow, list(shallow_tree))
|
||||||
|
|
||||||
|
def testMapStructureUpTo(self):
  # Checks `nest.map_structure_up_to` when one of the mapped arguments nests
  # deeper than the shallow structure that drives the mapping.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")

  def apply_ops(val, ops):
    # `ops` arrives as a whole `op_tuple`; it is not flattened further.
    return (val + ops.add) * ops.mul

  values = ab_tuple(a=2, b=3)
  operations = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  result = nest.map_structure_up_to(values, apply_ops, values, operations)
  # (2 + 1) * 2 == 6 and (3 + 2) * 3 == 15.
  self.assertEqual(result.a, 6)
  self.assertEqual(result.b, 15)

  def describe(name, sec):
    # `sec` is the full tuple of numbers paired with `name`.
    return "first_{}_{}".format(len(sec), name)

  names = ("evens", ("odds", "primes"))
  sections = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
  labels = nest.map_structure_up_to(names, describe, names, sections)
  self.assertEqual(
      labels, ("first_4_evens", ("first_5_odds", "first_3_primes")))
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
  test.main()
|
Loading…
Reference in New Issue
Block a user