From f3f53e8b394bdcaddc707f7bde8dcc98a73531e7 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 6 Jun 2017 07:47:14 -0700 Subject: [PATCH] [tf.contrib.data] Add support for dicts and remove lists from nested structures. This changes the behavior of constructors like `tf.contrib.data.Dataset.from_tensors()` when passed a list. Previously, the `nest` utility would recurse into each element of such a list and create a separate Dataset component. Now the list will be converted to a tensor, allowing code like: ```python dataset = tf.contrib.data.Dataset.from_tensor_slices(([1, 2, 3], [4, 5, 6])) ``` ...to define a dataset with two components (each of shape `()`). This change also adds support for dictionaries as nested structures, which simplifies integration with dictionary-returning ops like `tf.parse_example()`. Fixes #10151. RELNOTES: Breaking change to `tf.contrib.data.Dataset` APIs that expect a nested structure. Lists are now converted to tf.Tensor implicitly. You may need to change uses of lists to tuples in existing code. In addition, dicts are now supported as a nested structure. 
PiperOrigin-RevId: 158139467 --- tensorflow/BUILD | 1 + tensorflow/contrib/cmake/tf_python.cmake | 1 + .../kernel_tests/batch_dataset_op_test.py | 4 +- .../dataset_constructor_op_test.py | 27 +- .../kernel_tests/filter_dataset_op_test.py | 4 +- .../kernel_tests/flat_map_dataset_op_test.py | 20 +- .../python/kernel_tests/iterator_ops_test.py | 24 +- .../kernel_tests/map_dataset_op_test.py | 19 +- .../kernel_tests/range_dataset_op_test.py | 8 +- .../kernel_tests/reader_dataset_ops_test.py | 25 + .../kernel_tests/sequence_dataset_op_test.py | 8 +- .../kernel_tests/shuffle_dataset_op_test.py | 8 +- .../kernel_tests/zip_dataset_op_test.py | 4 +- tensorflow/contrib/data/python/ops/BUILD | 2 +- .../contrib/data/python/ops/dataset_ops.py | 4 +- tensorflow/contrib/data/python/util/BUILD | 44 ++ tensorflow/contrib/data/python/util/nest.py | 513 ++++++++++++++++++ .../contrib/data/python/util/nest_test.py | 309 +++++++++++ 18 files changed, 967 insertions(+), 58 deletions(-) create mode 100644 tensorflow/contrib/data/python/util/BUILD create mode 100644 tensorflow/contrib/data/python/util/nest.py create mode 100644 tensorflow/contrib/data/python/util/nest_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0eea54a6efc..42e5b921503 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -231,6 +231,7 @@ filegroup( "//tensorflow/contrib/data/python/framework:all_files", "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", + "//tensorflow/contrib/data/python/util:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 80522a18383..124eab17ccc 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -276,6 +276,7 @@ 
add_python_module("tensorflow/contrib/data/python") add_python_module("tensorflow/contrib/data/python/framework") add_python_module("tensorflow/contrib/data/python/kernel_tests") add_python_module("tensorflow/contrib/data/python/ops") +add_python_module("tensorflow/contrib/data/python/util") add_python_module("tensorflow/contrib/deprecated") add_python_module("tensorflow/contrib/distributions") add_python_module("tensorflow/contrib/distributions/python") diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index c78e1b412c4..c9412d949c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -40,9 +40,9 @@ class BatchDatasetTest(test.TestCase): """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> BatchDataset(batch_size). 
- components = [np.arange(7), + components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)] + np.array(37.0) * np.arange(7)) count = array_ops.placeholder(dtypes.int64, shape=[]) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index acff83c2396..6a7bc99fa88 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -33,7 +33,7 @@ class DatasetConstructorTest(test.TestCase): def testTensorDataset(self): """Test an dataset that represents a single tuple of tensors.""" - components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)] + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) iterator = (dataset_ops.Dataset.from_tensors(components) .make_initializable_iterator()) @@ -53,11 +53,11 @@ class DatasetConstructorTest(test.TestCase): def testTensorSliceDataset(self): """Test an dataset that represents the slices from a tuple of tensors.""" - components = [ + components = ( np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( np.array([[12], [13], [14], [15]]), 22), np.array([37.0, 38.0, 39.0, 40.0]) - ] + ) iterator = (dataset_ops.Dataset.from_tensor_slices(components) .make_initializable_iterator()) @@ -76,6 +76,27 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testTensorSliceDatasetWithDict(self): + components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + iterator = (dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual(dtypes.int32, iterator.output_types["foo"]) + self.assertEqual(dtypes.float32, 
iterator.output_types["bar"]) + self.assertEqual((), iterator.output_shapes["foo"]) + self.assertEqual((1,), iterator.output_shapes["bar"]) + + with self.test_session() as sess: + sess.run(init_op) + for i in range(3): + results = sess.run(get_next) + self.assertEqual(components["foo"][i], results["foo"]) + self.assertEqual(components["bar"][i], results["bar"]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testSparseTensorSliceDataset(self): """Test a dataset based on slices of a `tf.SparseTensor`.""" st = array_ops.sparse_placeholder(dtypes.float64) diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index e4e994479fb..3ea783ad899 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -30,12 +30,12 @@ from tensorflow.python.platform import test class FilterDatasetTest(test.TestCase): def testFilterDataset(self): - components = [ + components = ( np.arange(7, dtype=np.int64), np.array([[1, 2, 3]], dtype=np.int64) * np.arange( 7, dtype=np.int64)[:, np.newaxis], np.array(37.0, dtype=np.float64) * np.arange(7) - ] + ) count = array_ops.placeholder(dtypes.int64, shape=[]) modulus = array_ops.placeholder(dtypes.int64) diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index 705cfef9017..3c9c714bde4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -33,13 +33,13 @@ class FlatMapDatasetTest(test.TestCase): # pylint: disable=g-long-lambda def testFlatMapDataset(self): repeats = [1, 2, 3, 4, 5, 0, 1] - components = [np.array(repeats, dtype=np.int64)] + components = np.array(repeats, dtype=np.int64) iterator = ( 
dataset_ops.Dataset.from_tensor_slices(components) .flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x)) .make_initializable_iterator()) init_op = iterator.initializer - get_next, = iterator.get_next() + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -51,14 +51,14 @@ class FlatMapDatasetTest(test.TestCase): def testNestedFlatMapDataset(self): repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] - components = [np.array(repeats, dtype=np.int64)] + components = np.array(repeats, dtype=np.int64) iterator = ( dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]) - .flat_map(lambda y: dataset_ops.Dataset.from_tensors([y]) + .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) + .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) .repeat(y))).make_initializable_iterator()) init_op = iterator.initializer - get_next, = iterator.get_next() + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -72,15 +72,15 @@ class FlatMapDatasetTest(test.TestCase): def testSharedResourceNestedFlatMapDataset(self): repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] - components = [np.array(repeats, dtype=np.int64)] + components = np.array(repeats, dtype=np.int64) iterator = ( dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]) - .flat_map(lambda y: dataset_ops.Dataset.from_tensors([y]) + .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) + .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) .repeat(y))).make_initializable_iterator( shared_name="shared_flat_map_iterator")) init_op = iterator.initializer - get_next, = iterator.get_next() + get_next = iterator.get_next() # Create two concurrent sessions that share the same iterator # resource on the same server, and verify that a random diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py 
b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 650e1ff59e9..d6dd134a5b9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -47,9 +47,9 @@ class IteratorTest(test.TestCase): gradients_impl.gradients(value, [component, side]) def testOneShotIterator(self): - components = [np.arange(7), + components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)] + np.array(37.0) * np.arange(7)) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) @@ -71,10 +71,10 @@ class IteratorTest(test.TestCase): sess.run(get_next) def testOneShotIteratorCaptureByValue(self): - components = [np.arange(7), + components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)] - tensor_components = [ops.convert_to_tensor(c) for c in components] + np.array(37.0) * np.arange(7)) + tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) @@ -96,9 +96,9 @@ class IteratorTest(test.TestCase): sess.run(get_next) def testOneShotIteratorInsideContainer(self): - components = [np.arange(7), + components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)] + np.array(37.0) * np.arange(7)) def within_container(): def _map_fn(x, y, z): @@ -129,11 +129,11 @@ class IteratorTest(test.TestCase): sess.run(get_next) def testSimpleSharedResource(self): - components = [ + components = ( np.array(1, dtype=np.int64), np.array([1, 2, 3], dtype=np.int64), np.array(37.0, dtype=np.float64) - ] + ) server = server_lib.Server.create_local_server() @@ -166,8 +166,8 @@ class IteratorTest(test.TestCase): # new graph. 
iterator = dataset_ops.Iterator.from_structure( shared_name="shared_iterator", - output_types=[dtypes.int64, dtypes.int64, dtypes.float64], - output_shapes=[[], [3], []]) + output_types=(dtypes.int64, dtypes.int64, dtypes.float64), + output_shapes=([], [3], [])) get_next = iterator.get_next() with session.Session(server.target) as sess: @@ -179,7 +179,7 @@ class IteratorTest(test.TestCase): sess.run(get_next) def testNotInitializedError(self): - components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)] + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) iterator = (dataset_ops.Dataset.from_tensors(components) .make_initializable_iterator()) get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 68cd3623c00..b5956ac49c3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -45,9 +45,9 @@ class MapDatasetTest(test.TestCase): """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count). - components = [np.arange(7), + components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)] + np.array(37.0) * np.arange(7)) count = array_ops.placeholder(dtypes.int64, shape=[]) dataset = self._buildMapDataset(components, count) @@ -107,9 +107,9 @@ class MapDatasetTest(test.TestCase): """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) -> # RepeatDataset(count). 
- components = [np.arange(7), + components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)] + np.array(37.0) * np.arange(7)) count = array_ops.placeholder(dtypes.int64, shape=[]) num_threads = array_ops.placeholder(dtypes.int32, shape=[]) output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) @@ -175,9 +175,9 @@ class MapDatasetTest(test.TestCase): def _testDisposeParallelMapDataset(self, explicit_dispose): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(1000). - components = [np.arange(1000), + components = (np.arange(1000), np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], - np.array(37.0) * np.arange(1000)] + np.array(37.0) * np.arange(1000)) dataset = self._buildParallelMapDataset(components, 1000, 100, 100) iterator = dataset.make_initializable_iterator() @@ -200,7 +200,7 @@ class MapDatasetTest(test.TestCase): self._testDisposeParallelMapDataset(False) def testParallelMapError(self): - components = [np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)] + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) dataset = (dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.check_numerics(x, "message"))) @@ -230,10 +230,7 @@ class MapDatasetTest(test.TestCase): lookup_ops.KeyValueTensorInitializer(keys, values), default_val) input_sentences = dataset_ops.Dataset.from_tensor_slices( - constant_op.constant([ - "brain brain tank salad surgery", - "surgery brain", - ])) + ["brain brain tank salad surgery", "surgery brain"]) iterator = (input_sentences .map(lambda x: string_ops.string_split([x]).values) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 8e10e22c353..a8edbbd20c8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -17,8 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -156,7 +154,7 @@ class RangeDatasetTest(test.TestCase): sess.run(get_next) def testEnumerateDataset(self): - components = [np.array(["a", "b"]), np.array([1, 2]), np.array([37.0, 38])] + components = (["a", "b"], [1, 2], [37.0, 38]) start = constant_op.constant(20, dtype=dtypes.int64) iterator = (dataset_ops.Dataset.from_tensor_slices(components).enumerate( @@ -171,8 +169,8 @@ class RangeDatasetTest(test.TestCase): with self.test_session() as sess: sess.run(init_op) - self.assertEqual((20, [b"a", 1, 37.0]), sess.run(get_next)) - self.assertEqual((21, [b"b", 2, 38.0]), sess.run(get_next)) + self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next)) + self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 3d5b34b77b8..133165a1c25 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -496,5 +496,30 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) + def testReadWithEquivalentDataset(self): + # TODO(mrry): Add support for tf.SparseTensor as a Dataset component. 
+ features = { + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + } + dataset = (dataset_ops.TFRecordDataset(self.test_filenames) + .map(lambda x: parsing_ops.parse_single_example(x, features)) + .repeat(10) + .batch(2)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for file_batch, _, _, _, record_batch in self._next_expected_batch( + range(self._num_files), 2, 10): + actual_batch = sess.run(next_element) + self.assertAllEqual(file_batch, actual_batch["file"]) + self.assertAllEqual(record_batch, actual_batch["record"]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 6362b5e450a..91615e9f620 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -30,7 +30,7 @@ class SequenceDatasetTest(test.TestCase): def testRepeatTensorDataset(self): """Test a dataset that repeats its input multiple times.""" - components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)] + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) # This placeholder can be fed when dataset-definition subgraph # runs (i.e. `init_op` below) to configure the number of # repetitions used in a particular iterator. 
@@ -79,7 +79,7 @@ class SequenceDatasetTest(test.TestCase): self.assertAllEqual(component, result_component) def testTakeTensorDataset(self): - components = [np.arange(10)] + components = (np.arange(10),) count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) iterator = (dataset_ops.Dataset.from_tensor_slices(components) @@ -125,7 +125,7 @@ class SequenceDatasetTest(test.TestCase): sess.run(get_next) def testSkipTensorDataset(self): - components = [np.arange(10)] + components = (np.arange(10),) count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) iterator = (dataset_ops.Dataset.from_tensor_slices(components) @@ -171,7 +171,7 @@ class SequenceDatasetTest(test.TestCase): def testRepeatRepeatTensorDataset(self): """Test the composition of repeat datasets.""" - components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)] + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) inner_count = array_ops.placeholder(dtypes.int64, shape=[]) outer_count = array_ops.placeholder(dtypes.int64, shape=[]) diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 8048e4f87ed..d9bfca30bbf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -32,10 +32,10 @@ from tensorflow.python.platform import test class ShuffleDatasetTest(test.TestCase): def testShuffleDataset(self): - components = [ + components = ( np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), np.array([9.0, 10.0, 11.0, 12.0]) - ] + ) count_placeholder = array_ops.placeholder_with_default( constant_op.constant(5, dtypes.int64), shape=[]) buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) @@ -47,7 +47,7 @@ class ShuffleDatasetTest(test.TestCase): shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder, seed_placeholder) - 
self.assertEqual([c.shape[1:] for c in components], + self.assertEqual(tuple([c.shape[1:] for c in components]), shuffle_dataset.output_shapes) # Create initialization ops for iterators without and with @@ -132,7 +132,7 @@ class ShuffleDatasetTest(test.TestCase): sess.run(get_next) def testDefaultArguments(self): - components = np.array([0, 1, 2, 3, 4]) + components = [0, 1, 2, 3, 4] iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) .repeat().make_one_shot_iterator()) diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py index c47f072361c..b0e72183019 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py @@ -35,10 +35,10 @@ class ZipDatasetTest(test.TestCase): array_ops.placeholder(dtypes.float64) ] - datasets = [ + datasets = tuple([ dataset_ops.Dataset.from_tensor_slices(component_placeholder) for component_placeholder in component_placeholders - ] + ]) zipped = dataset_ops.Dataset.zip(datasets) iterator = zipped.make_initializable_iterator() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 489ec879a3e..08a2774ece2 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -10,11 +10,11 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/framework:function", + "//tensorflow/contrib/data/python/util:nest", "//tensorflow/contrib/util:util_py", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:parsing_ops", - "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 0c3b24430a4..65bba6c7442 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ 
b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -22,6 +22,7 @@ import abc import numpy as np from tensorflow.contrib.data.python.framework import function +from tensorflow.contrib.data.python.util import nest from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -38,7 +39,6 @@ from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import gfile -from tensorflow.python.util import nest class Iterator(object): @@ -1869,7 +1869,7 @@ def _parse_example(serialized, features): result.extend([val.indices, val.values, val.dense_shape]) else: result.append(val) - return result + return tuple(result) def _get_file_names(file_pattern, randomize_input): diff --git a/tensorflow/contrib/data/python/util/BUILD b/tensorflow/contrib/data/python/util/BUILD new file mode 100644 index 00000000000..b9691c8e491 --- /dev/null +++ b/tensorflow/contrib/data/python/util/BUILD @@ -0,0 +1,44 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "nest", + srcs = ["nest.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + ], +) + +py_test( + name = "nest_test", + size = "small", + srcs = ["nest_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":nest", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/data/python/util/nest.py 
b/tensorflow/contrib/data/python/util/nest.py new file mode 100644 index 00000000000..91c8416d5ae --- /dev/null +++ b/tensorflow/contrib/data/python/util/nest.py @@ -0,0 +1,513 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""## Functions for working with arbitrarily nested sequences of elements. + +NOTE(mrry): This fork of the `tensorflow.python.util.nest` module +makes two changes: + +1. It adds support for dictionaries as a level of nesting in nested structures. +2. It removes support for lists as a level of nesting in nested structures. + +The motivation for this change is twofold: + +1. Many input-processing functions (e.g. `tf.parse_example()`) return + dictionaries, and we would like to support them natively in datasets. +2. It seems more natural for lists to be treated (e.g. in Dataset constructors) + as tensors, rather than lists of (lists of...) tensors. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections as _collections + +import six as _six + +from tensorflow.python.util.all_util import remove_undocumented + + +def _sequence_like(instance, args): + """Converts the sequence `args` to the same type as `instance`. + + Args: + instance: an instance of `tuple`, `list`, or a `namedtuple` class. 
+ args: elements to be converted to a sequence. + + Returns: + `args` with the type of `instance`. + """ + if isinstance(instance, dict): + # This is a dict. Iterate over the keys in sorted order to make + # this deterministic. + return {k: v for k, v in zip(sorted(instance.keys()), args)} + elif (isinstance(instance, tuple) and + hasattr(instance, "_fields") and + isinstance(instance._fields, _collections.Sequence) and + all(isinstance(f, _six.string_types) for f in instance._fields)): + # This is a namedtuple + return type(instance)(*args) + else: + # Not a namedtuple + return type(instance)(args) + + +def _elements_of(nest): + if isinstance(nest, dict): + # Iterate over dict keys in sorted order to make this deterministic. + return [v for _, v in sorted(nest.items())] + else: + return nest + + +def _yield_flat_nest(nest): + for n in _elements_of(nest): + if is_sequence(n): + for ni in _yield_flat_nest(n): + yield ni + else: + yield n + + +def is_sequence(seq): + """Returns a true if `seq` is a Sequence or dict (except strings/lists). + + NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`, + which *does* treat a Python list as a sequence. For ergonomic + reasons, `tf.contrib.data` users would prefer to treat lists as + implict `tf.Tensor` objects, and dicts as (nested) sequences. + + Args: + seq: an input sequence. + + Returns: + True if the sequence is a not a string or list and is a + collections.Sequence. + """ + return (isinstance(seq, (_collections.Sequence, dict)) + and not isinstance(seq, (list, _six.string_types))) + + +def flatten(nest): + """Returns a flat sequence from a given nested structure. + + If `nest` is not a sequence, this returns a single-element list: `[nest]`. + + Args: + nest: an arbitrarily nested structure or a scalar object. + Note, numpy arrays are considered scalars. + + Returns: + A Python list, the flattened version of the input. 
+ """ + return list(_yield_flat_nest(nest)) if is_sequence(nest) else [nest] + + +def _recursive_assert_same_structure(nest1, nest2, check_types): + is_sequence_nest1 = is_sequence(nest1) + if is_sequence_nest1 != is_sequence(nest2): + raise ValueError( + "The two structures don't have the same nested structure. " + "First structure: %s, second structure: %s." % (nest1, nest2)) + + if is_sequence_nest1: + type_nest1 = type(nest1) + type_nest2 = type(nest2) + if check_types and type_nest1 != type_nest2: + raise TypeError( + "The two structures don't have the same sequence type. First " + "structure has type %s, while second structure has type %s." + % (type_nest1, type_nest2)) + + for n1, n2 in zip(_elements_of(nest1), _elements_of(nest2)): + _recursive_assert_same_structure(n1, n2, check_types) + + +def assert_same_structure(nest1, nest2, check_types=True): + """Asserts that two structures are nested in the same way. + + Args: + nest1: an arbitrarily nested structure. + nest2: an arbitrarily nested structure. + check_types: if `True` (default) types of sequences are checked as + well. If set to `False`, for example a list and a tuple of objects will + look same if they have the same size. + + Raises: + ValueError: If the two structures do not have the same number of elements or + if the two structures are not nested in the same way. + TypeError: If the two structures differ in the type of sequence in any of + their substructures. Only possible if `check_types` is `True`. + """ + len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1 + len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1 + if len_nest1 != len_nest2: + raise ValueError("The two structures don't have the same number of " + "elements. First structure: %s, second structure: %s." + % (nest1, nest2)) + _recursive_assert_same_structure(nest1, nest2, check_types) + + +def _packed_nest_with_indices(structure, flat, index): + """Helper function for pack_nest_as. 
+ + Args: + structure: Substructure (tuple of elements and/or tuples) to mimic + flat: Flattened values to output substructure for. + index: Index at which to start reading from flat. + + Returns: + The tuple (new_index, child), where: + * new_index - the updated index into `flat` having processed `structure`. + * packed - the subset of `flat` corresponding to `structure`, + having started at `index`, and packed into the same nested + format. + + Raises: + ValueError: if `structure` contains more elements than `flat` + (assuming indexing starts from `index`). + """ + packed = [] + for s in structure: + if is_sequence(s): + new_index, child = _packed_nest_with_indices(s, flat, index) + packed.append(_sequence_like(s, child)) + index = new_index + else: + packed.append(flat[index]) + index += 1 + return index, packed + + +def pack_sequence_as(structure, flat_sequence): + """Returns a given flattened sequence packed into a nest. + + If `structure` is a scalar, `flat_sequence` must be a single-element list; + in this case the return value is `flat_sequence[0]`. + + Args: + structure: tuple or list constructed of scalars and/or other tuples/lists, + or a scalar. Note: numpy arrays are considered scalars. + flat_sequence: flat sequence to pack. + + Returns: + packed: `flat_sequence` converted to have the same recursive structure as + `structure`. + + Raises: + ValueError: If nest and structure have different element counts. + """ + if not (is_sequence(flat_sequence) or isinstance(flat_sequence, list)): + raise TypeError("flat_sequence must be a sequence") + + if not is_sequence(structure): + if len(flat_sequence) != 1: + raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1" + % len(flat_sequence)) + return flat_sequence[0] + + flat_structure = flatten(structure) + if len(flat_structure) != len(flat_sequence): + raise ValueError( + "Could not pack sequence. Structure had %d elements, but flat_sequence " + "had %d elements. 
Structure: %s, flat_sequence: %s." + % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) + + _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) + return _sequence_like(structure, packed) + + +def map_structure(func, *structure, **check_types_dict): + """Applies `func` to each entry in `structure` and returns a new structure. + + Applies `func(x[0], x[1], ...)` where x[i] is an entry in + `structure[i]`. All structures in `structure` must have the same arity, + and the return value will contain the results in the same structure. + + Args: + func: A callable that accepts as many arguments as there are structures. + *structure: scalar, or tuple or list of constructed scalars and/or other + tuples/lists, or scalars. Note: numpy arrays are considered scalars. + **check_types_dict: only valid keyword argument is `check_types`. If set to + `True` (default) the types of iterables within the structures have to be + the same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` + exception). To allow this set this argument to `False`. + + Returns: + A new structure with the same arity as `structure`, whose values correspond + to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding + location in `structure[i]`. If there are different sequence types and + `check_types` is `False` the sequence types of the first structure will be + used. + + Raises: + TypeError: If `func` is not callable or if the structures do not match + each other by type. + ValueError: If no structure is provided or if the structures do not match + each other by depth tree. + ValueError: If wrong keyword arguments are provided. 
+ """ + if not callable(func): + raise TypeError("func must be callable, got: %s" % func) + + if not structure: + raise ValueError("Must provide at least one structure") + + if check_types_dict: + if "check_types" not in check_types_dict or len(check_types_dict) > 1: + raise ValueError("Only valid keyword argument is check_types") + check_types = check_types_dict["check_types"] + else: + check_types = True + + for other in structure[1:]: + assert_same_structure(structure[0], other, check_types=check_types) + + flat_structure = [flatten(s) for s in structure] + entries = zip(*flat_structure) + + return pack_sequence_as( + structure[0], [func(*x) for x in entries]) + + +def _yield_flat_up_to(shallow_tree, input_tree): + """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" + if is_sequence(shallow_tree): + for shallow_branch, input_branch in zip(shallow_tree, input_tree): + for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): + yield input_leaf + else: + yield input_tree + + +def assert_shallow_structure(shallow_tree, input_tree, check_types=True): + """Asserts that `shallow_tree` is a shallow structure of `input_tree`. + + That is, this function tests if the `input_tree` structure can be created from + the `shallow_tree` structure by replacing its leaf nodes with deeper + tree structures. + + Examples: + + The following code will raise an exception: + ```python + shallow_tree = ["a", "b"] + input_tree = ["c", ["d", "e"], "f"] + assert_shallow_structure(shallow_tree, input_tree) + ``` + + The following code will not raise an exception: + ```python + shallow_tree = ["a", "b"] + input_tree = ["c", ["d", "e"]] + assert_shallow_structure(shallow_tree, input_tree) + ``` + + Args: + shallow_tree: an arbitrarily nested structure. + input_tree: an arbitrarily nested structure. + check_types: if `True` (default) the sequence types of `shallow_tree` and + `input_tree` have to be the same. 
+ + Raises: + TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + TypeError: If the sequence types of `shallow_tree` are different from + `input_tree`. Only raised if `check_types` is `True`. + ValueError: If the sequence lengths of `shallow_tree` are different from + `input_tree`. + """ + if is_sequence(shallow_tree): + if not is_sequence(input_tree): + raise TypeError( + "If shallow structure is a sequence, input must also be a sequence. " + "Input has type: %s." % type(input_tree)) + + if check_types and not isinstance(input_tree, type(shallow_tree)): + raise TypeError( + "The two structures don't have the same sequence type. Input " + "structure has type %s, while shallow structure has type %s." + % (type(input_tree), type(shallow_tree))) + + if len(input_tree) != len(shallow_tree): + raise ValueError( + "The two structures don't have the same sequence length. Input " + "structure has length %s, while shallow structure has length %s." + % (len(input_tree), len(shallow_tree))) + + for shallow_branch, input_branch in zip(shallow_tree, input_tree): + assert_shallow_structure(shallow_branch, input_branch, + check_types=check_types) + + +def flatten_up_to(shallow_tree, input_tree): + """Flattens `input_tree` up to `shallow_tree`. + + Any further depth in structure in `input_tree` is retained as elements in the + partially flattened output. + + If `shallow_tree` and `input_tree` are not sequences, this returns a + single-element list: `[input_tree]`. + + Use Case: + + Sometimes we may wish to partially flatten a nested sequence, retaining some + of the nested structure. We achieve this by specifying a shallow structure, + `shallow_tree`, we wish to flatten up to. + + The input, `input_tree`, can be thought of as having the same structure as + `shallow_tree`, but with leaf nodes that are themselves tree structures. 
+ + Examples: + + ```python + input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + shallow_tree = [[True, True], [False, True]] + + flattened_input_tree = flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) + + # Output is: + # [[2, 2], [3, 3], [4, 9], [5, 5]] + # [True, True, False, True] + ``` + + ```python + input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] + shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] + + input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) + input_tree_flattened = flatten(input_tree) + + # Output is: + # [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + # ['a', 1, 'b', 2, 'c', 3, 'd', 4] + ``` + + Non-Sequence Edge Cases: + + ```python + flatten_up_to(0, 0) # Output: [0] + flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] + flatten_up_to([0, 1, 2], 0) # Output: TypeError + flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] + ``` + + Args: + shallow_tree: a possibly pruned structure of input_tree. + input_tree: an arbitrarily nested structure or a scalar object. + Note, numpy arrays are considered scalars. + + Returns: + A Python list, the partially flattened version of `input_tree` according to + the structure of `shallow_tree`. + + Raises: + TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + TypeError: If the sequence types of `shallow_tree` are different from + `input_tree`. + ValueError: If the sequence lengths of `shallow_tree` are different from + `input_tree`. + """ + assert_shallow_structure(shallow_tree, input_tree) + return list(_yield_flat_up_to(shallow_tree, input_tree)) + + +def map_structure_up_to(shallow_tree, func, *inputs): + """Applies a function or op to a number of partially flattened inputs. + + The `inputs` are flattened up to `shallow_tree` before being mapped. 
+ + Use Case: + + Sometimes we wish to apply a function to a partially flattened + sequence (for example when the function itself takes sequence inputs). We + achieve this by specifying a shallow structure, `shallow_tree` we wish to + flatten up to. + + The `inputs`, can be thought of as having the same structure as + `shallow_tree`, but with leaf nodes that are themselves tree structures. + + This function therefore will return something with the same base structure as + `shallow_tree`. + + Examples: + + ```python + ab_tuple = collections.namedtuple("ab_tuple", "a, b") + op_tuple = collections.namedtuple("op_tuple", "add, mul") + inp_val = ab_tuple(a=2, b=3) + inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) + out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, + inp_val, inp_ops) + + # Output is: ab_tuple(a=6, b=15) + ``` + + ```python + data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + name_list = ['evens', ['odds', 'primes']] + out = map_structure_up_to( + name_list, + lambda name, sec: "first_{}_{}".format(len(sec), name), + name_list, data_list) + + # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] + ``` + + Args: + shallow_tree: a shallow tree, common to all the inputs. + func: callable which will be applied to each input individually. + *inputs: arbitrarily nested combination of objects that are compatible with + shallow_tree. The function `func` is applied to corresponding + partially flattened elements of each input, so the function must support + arity of `len(inputs)`. + + Raises: + TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + TypeError: If the sequence types of `shallow_tree` are different from + `input_tree`. + ValueError: If the sequence lengths of `shallow_tree` are different from + `input_tree`. + + Returns: + result of repeatedly applying `func`, with same structure as + `shallow_tree`. 
+ """ + if not inputs: + raise ValueError("Cannot map over no sequences") + for input_tree in inputs: + assert_shallow_structure(shallow_tree, input_tree) + + # Flatten each input separately, apply the function to corresponding elements, + # then repack based on the structure of the first input. + all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) + for input_tree in inputs] + results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] + return pack_sequence_as(structure=shallow_tree, flat_sequence=results) + + +_allowed_symbols = [ + "assert_same_structure", + "is_sequence", + "flatten", + "pack_sequence_as", + "map_structure", + "assert_shallow_structure", + "flatten_up_to", + "map_structure_up_to", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/data/python/util/nest_test.py b/tensorflow/contrib/data/python/util/nest_test.py new file mode 100644 index 00000000000..7852e4f8617 --- /dev/null +++ b/tensorflow/contrib/data/python/util/nest_test.py @@ -0,0 +1,309 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for utilities working with arbitrarily nested structures.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.contrib.data.python.util import nest +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class NestTest(test.TestCase): + + def testFlattenAndPack(self): + structure = ((3, 4), 5, (6, 7, (9, 10), 8)) + flat = ["a", "b", "c", "d", "e", "f", "g", "h"] + self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) + self.assertEqual( + nest.pack_sequence_as(structure, flat), (("a", "b"), "c", + ("d", "e", ("f", "g"), "h"))) + point = collections.namedtuple("Point", ["x", "y"]) + structure = (point(x=4, y=2), ((point(x=1, y=0),),)) + flat = [4, 2, 1, 0] + self.assertEqual(nest.flatten(structure), flat) + restructured_from_flat = nest.pack_sequence_as(structure, flat) + self.assertEqual(restructured_from_flat, structure) + self.assertEqual(restructured_from_flat[0].x, 4) + self.assertEqual(restructured_from_flat[0].y, 2) + self.assertEqual(restructured_from_flat[1][0][0].x, 1) + self.assertEqual(restructured_from_flat[1][0][0].y, 0) + + self.assertEqual([5], nest.flatten(5)) + self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) + + self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) + self.assertEqual( + np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) + + with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): + nest.pack_sequence_as("scalar", [4, 5]) + + with self.assertRaisesRegexp(TypeError, "flat_sequence"): + nest.pack_sequence_as([4, 5], "bad_sequence") + + with self.assertRaises(ValueError): + nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + + 
def testIsSequence(self): + self.assertFalse(nest.is_sequence("1234")) + self.assertFalse(nest.is_sequence([1, 3, [4, 5]])) + self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) + self.assertFalse(nest.is_sequence([])) + self.assertFalse(nest.is_sequence(set([1, 2]))) + ones = array_ops.ones([2, 3]) + self.assertFalse(nest.is_sequence(ones)) + self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) + self.assertFalse(nest.is_sequence(np.ones((4, 5)))) + self.assertTrue(nest.is_sequence({"foo": 1, "bar": 2})) + + def testAssertSameStructure(self): + structure1 = (((1, 2), 3), 4, (5, 6)) + structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + structure_different_num_elements = ("spam", "eggs") + structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) + nest.assert_same_structure(structure1, structure2) + nest.assert_same_structure("abc", 1.0) + nest.assert_same_structure("abc", np.array([0, 1])) + nest.assert_same_structure("abc", constant_op.constant([0, 1])) + + with self.assertRaisesRegexp(ValueError, + "don't have the same number of elements"): + nest.assert_same_structure(structure1, structure_different_num_elements) + + with self.assertRaisesRegexp(ValueError, + "don't have the same number of elements"): + nest.assert_same_structure((0, 1), np.array([0, 1])) + + with self.assertRaisesRegexp(ValueError, + "don't have the same number of elements"): + nest.assert_same_structure(0, (0, 1)) + + with self.assertRaisesRegexp(ValueError, + "don't have the same nested structure"): + nest.assert_same_structure(structure1, structure_different_nesting) + + named_type_0 = collections.namedtuple("named_0", ("a", "b")) + named_type_1 = collections.namedtuple("named_1", ("a", "b")) + self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), + named_type_0("a", "b")) + + nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b")) + + self.assertRaises(TypeError, nest.assert_same_structure, + named_type_0(3, 4), named_type_1(3, 4)) + + with 
self.assertRaisesRegexp(ValueError, + "don't have the same nested structure"): + nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4)) + + with self.assertRaisesRegexp(ValueError, + "don't have the same nested structure"): + nest.assert_same_structure(((3,), 4), (3, (4,))) + + structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)} + with self.assertRaisesRegexp(TypeError, + "don't have the same sequence type"): + nest.assert_same_structure(structure1, structure1_list) + nest.assert_same_structure(structure1, structure2, check_types=False) + nest.assert_same_structure(structure1, structure1_list, check_types=False) + + def testMapStructure(self): + structure1 = (((1, 2), 3), 4, (5, 6)) + structure2 = (((7, 8), 9), 10, (11, 12)) + structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) + nest.assert_same_structure(structure1, structure1_plus1) + self.assertAllEqual( + [2, 3, 4, 5, 6, 7], + nest.flatten(structure1_plus1)) + structure1_plus_structure2 = nest.map_structure( + lambda x, y: x + y, structure1, structure2) + self.assertEqual( + (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), + structure1_plus_structure2) + + self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) + + self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) + + with self.assertRaisesRegexp(TypeError, "callable"): + nest.map_structure("bad", structure1_plus1) + + with self.assertRaisesRegexp(ValueError, "same nested structure"): + nest.map_structure(lambda x, y: None, 3, (3,)) + + with self.assertRaisesRegexp(TypeError, "same sequence type"): + nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5}) + + with self.assertRaisesRegexp(ValueError, "same nested structure"): + nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) + + with self.assertRaisesRegexp(ValueError, "same nested structure"): + nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), + check_types=False) + + with 
self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): + nest.map_structure(lambda x: None, structure1, foo="a") + + with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): + nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") + + def testAssertShallowStructure(self): + inp_ab = ("a", "b") + inp_abc = ("a", "b", "c") + expected_message = ( + "The two structures don't have the same sequence length. Input " + "structure has length 2, while shallow structure has length 3.") + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_shallow_structure(inp_abc, inp_ab) + + inp_ab1 = ((1, 1), (2, 2)) + inp_ab2 = {"a": (1, 1), "b": (2, 2)} + expected_message = ( + "The two structures don't have the same sequence type. Input structure " + "has type <(type|class) 'tuple'>, while shallow structure has type " + "<(type|class) 'dict'>.") + with self.assertRaisesRegexp(TypeError, expected_message): + nest.assert_shallow_structure(inp_ab2, inp_ab1) + nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) + + def testFlattenUpTo(self): + input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5))) + shallow_tree = ((True, True), (False, True)) + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)]) + self.assertEqual(flattened_shallow_tree, [True, True, False, True]) + + input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4)))))) + shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4"))))) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + input_tree_flattened = nest.flatten(input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) + self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) + + ## Shallow non-list edge-case. 
+ # Using iterable elements. + input_tree = ["input_tree"] + shallow_tree = "shallow_tree" + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + input_tree = ("input_tree_0", "input_tree_1") + shallow_tree = "shallow_tree" + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + # Using non-iterable elements. + input_tree = (0,) + shallow_tree = 9 + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + input_tree = (0, 1) + shallow_tree = 9 + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + ## Both non-list edge-case. + # Using iterable elements. + input_tree = "input_tree" + shallow_tree = "shallow_tree" + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + # Using non-iterable elements. 
+ input_tree = 0 + shallow_tree = 0 + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + ## Input non-list edge-case. + # Using iterable elements. + input_tree = "input_tree" + shallow_tree = ("shallow_tree",) + expected_message = ("If shallow structure is a sequence, input must also " + "be a sequence. Input has type: <(type|class) 'str'>.") + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, list(shallow_tree)) + + input_tree = "input_tree" + shallow_tree = ("shallow_tree_9", "shallow_tree_8") + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, list(shallow_tree)) + + # Using non-iterable elements. + input_tree = 0 + shallow_tree = (9,) + expected_message = ("If shallow structure is a sequence, input must also " + "be a sequence. 
Input has type: <(type|class) 'int'>.") + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, list(shallow_tree)) + + input_tree = 0 + shallow_tree = (9, 8) + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, list(shallow_tree)) + + def testMapStructureUpTo(self): + ab_tuple = collections.namedtuple("ab_tuple", "a, b") + op_tuple = collections.namedtuple("op_tuple", "add, mul") + inp_val = ab_tuple(a=2, b=3) + inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) + out = nest.map_structure_up_to( + inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) + self.assertEqual(out.a, 6) + self.assertEqual(out.b, 15) + + data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7))) + name_list = ("evens", ("odds", "primes")) + out = nest.map_structure_up_to( + name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), + name_list, data_list) + self.assertEqual(out, ("first_4_evens", ("first_5_odds", "first_3_primes"))) + + +if __name__ == "__main__": + test.main()