From 9b11f458196f6f0528c9974238497a6c8b6da547 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 29 Jun 2017 11:10:56 -0700 Subject: [PATCH] [tf.contrib.data] Fix the handling of dict-typed elements in functions. Previously, we were treating a `dict` as a sequence, which led to incorrect behavior like passing all of the dict's keys rather than values as the arguments to a map or filter function. This change changes the behavior so that the dict is passed as a single argument to these functions. It additionally fixes the ported version of `nest.flatten_up_to()` so that `Dataset.padded_batch()` works with dict-typed elements. Fixes #11016. PiperOrigin-RevId: 160548475 --- .../python/kernel_tests/bucketing_test.py | 24 ++++++++++++------- .../kernel_tests/filter_dataset_op_test.py | 17 +++++++++++++ .../kernel_tests/flat_map_dataset_op_test.py | 17 +++++++++++++ .../kernel_tests/map_dataset_op_test.py | 16 +++++++++++++ .../contrib/data/python/ops/dataset_ops.py | 13 ++++++---- tensorflow/contrib/data/python/util/nest.py | 4 +++- .../contrib/data/python/util/nest_test.py | 8 +++++++ 7 files changed, 86 insertions(+), 13 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 20d66d7f231..71df1ee0a50 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -class BucketingTest(test.TestCase): +class GroupByWindowTest(test.TestCase): def testSimple(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -257,16 +257,24 @@ class BucketTest(test.TestCase): def testEvenOddBucketsFilterOutAllOdd(self): def _map_fn(v): - return (v, array_ops.fill([v], v), - array_ops.fill([3], string_ops.as_string(v))) + return {"x": v, + "y": array_ops.fill([v], v), + "z": array_ops.fill([3], string_ops.as_string(v))} + + def _dynamic_pad_fn(bucket, window, _): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch( + 32, {"x": tensor_shape.TensorShape([]), + "y": tensor_shape.TensorShape([None]), + "z": tensor_shape.TensorShape([3])}))) input_dataset = ( dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) - .filter(lambda x, y, z: math_ops.equal(x % 2, 0))) + .filter(lambda d: math_ops.equal(d["x"] % 2, 0))) bucketed_dataset = input_dataset.group_by_window( - lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), - lambda k, bucket: self._dynamicPad(k, bucket, 32), 32) + lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), + lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32) iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) init_op = iterator.initializer @@ -283,9 +291,9 @@ class BucketTest(test.TestCase): self.assertAllEqual(0, which_bucket0) self.assertAllEqual(0, which_bucket1) self.assertAllEqual( - np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0[0]) + np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"]) self.assertAllEqual( - np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1[0]) + np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 19be94e1742..e6d50dc1547 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -83,6 +83,23 @@ class FilterDatasetTest(test.TestCase): self.assertEqual(1, sess.run(get_next)) self.assertEqual(3, sess.run(get_next)) + def testFilterDict(self): + iterator = (dataset_ops.Dataset.range(10) + .map(lambda x: {"foo": x * 2, "bar": x ** 2}) + .filter(lambda d: math_ops.equal(d["bar"] % 2, 0)) + .map(lambda d: d["foo"] + d["bar"]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + if (i ** 2) % 2 == 0: + self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index 3c9c714bde4..ace0dd3668a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -101,6 +101,23 @@ class FlatMapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess = random.choice([sess1, sess2]) sess.run(get_next) + + def testMapDict(self): + iterator = (dataset_ops.Dataset.range(10) + .map(lambda x: {"foo": x * 2, "bar": x ** 2}) + .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"]) + .repeat(d["bar"])) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for _ in range(i ** 2): + self.assertEqual(i * 2, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) # pylint: enable=g-long-lambda diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index b5956ac49c3..2c07248c541 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -324,5 +324,21 @@ class MapDatasetTest(test.TestCase): # Randomness is repeatable given same seed self.assertAllClose(random_values, random_values_2) + def testMapDict(self): + iterator = (dataset_ops.Dataset.range(10) + .map(lambda x: {"foo": x * 2, "bar": x ** 2}) + .map(lambda d: d["foo"] + d["bar"]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 29f1209a58a..a689bfc9019 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -1357,6 +1357,11 @@ class DenseToSparseBatchDataset(Dataset): return (dtypes.int64, self._input_dataset.output_types, dtypes.int64) +def _should_unpack_args(args): + """Returns `True` if `args` should be `*args` when passed to a callable.""" + return nest.is_sequence(args) and not isinstance(args, dict) + + class _ResourceDataset(Dataset): """A Dataset wrapper for a tf.resource-typed function argument.""" @@ -1394,7 +1399,7 @@ class GroupByWindowDataset(Dataset): for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - if nest.is_sequence(nested_args): + if _should_unpack_args(nested_args): ret = key_func(*nested_args) else: ret = key_func(nested_args) @@ -1483,7 +1488,7 @@ class MapDataset(Dataset): nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - if nest.is_sequence(nested_args): + if _should_unpack_args(nested_args): ret = map_func(*nested_args) else: ret = map_func(nested_args) @@ -1559,7 +1564,7 @@ class FlatMapDataset(Dataset): nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - if nest.is_sequence(nested_args): + if _should_unpack_args(nested_args): dataset = map_func(*nested_args) else: dataset = map_func(nested_args) @@ -1609,7 +1614,7 @@ class FilterDataset(Dataset): nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - if nest.is_sequence(nested_args): + if _should_unpack_args(nested_args): ret = predicate(*nested_args) else: ret = predicate(nested_args) diff --git a/tensorflow/contrib/data/python/util/nest.py b/tensorflow/contrib/data/python/util/nest.py index 91c8416d5ae..a29c3c562bd 100644 --- a/tensorflow/contrib/data/python/util/nest.py +++ b/tensorflow/contrib/data/python/util/nest.py @@ -286,7 +286,8 @@ def map_structure(func, *structure, **check_types_dict): def _yield_flat_up_to(shallow_tree, input_tree): """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" if is_sequence(shallow_tree): - for shallow_branch, input_branch in zip(shallow_tree, input_tree): + for shallow_branch, input_branch in zip(_elements_of(shallow_tree), + _elements_of(input_tree)): for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): yield input_leaf else: @@ -495,6 +496,7 @@ def map_structure_up_to(shallow_tree, func, *inputs): # then repack based on the structure of the first input. all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) for input_tree in inputs] + results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] return pack_sequence_as(structure=shallow_tree, flat_sequence=results) diff --git a/tensorflow/contrib/data/python/util/nest_test.py b/tensorflow/contrib/data/python/util/nest_test.py index 7852e4f8617..5132881afb9 100644 --- a/tensorflow/contrib/data/python/util/nest_test.py +++ b/tensorflow/contrib/data/python/util/nest_test.py @@ -287,6 +287,14 @@ class NestTest(test.TestCase): flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree, list(shallow_tree)) + # Using dict. + input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))} + shallow_tree = {"a": (True, True), "b": (False, True)} + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)]) + self.assertEqual(flattened_shallow_tree, [True, True, False, True]) + def testMapStructureUpTo(self): ab_tuple = collections.namedtuple("ab_tuple", "a, b") op_tuple = collections.namedtuple("op_tuple", "add, mul")