[tf.contrib.data] Fix the handling of dict-typed elements in functions.

Previously, we were treating a `dict` as a sequence, which led to incorrect behavior such as passing the dict's keys, rather than its values, as the arguments to a map or filter function. With this change, the dict is passed as a single argument to these functions. It additionally fixes the ported version of `nest.flatten_up_to()` so that `Dataset.padded_batch()` works with dict-typed elements.

Fixes #11016.
PiperOrigin-RevId: 160548475
commit 9b11f45819
parent c1087b3a0b
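In effect, the new contract is the following (a minimal sketch, not part of the diff; it assumes the TF 1.x-era `tf.contrib.data` API that the tests below exercise):

# Sketch of the behavior this change establishes (TF 1.x, tf.contrib.data).
# A dict-typed element now reaches map/filter functions as a single dict
# argument; previously it was *-unpacked, so the function received the keys.
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.ops import math_ops

dataset = (dataset_ops.Dataset.range(10)
           .map(lambda x: {"foo": x * 2, "bar": x ** 2})
           # `d` is the whole dict, so fields are addressed by key.
           .filter(lambda d: math_ops.equal(d["bar"] % 2, 0)))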
@@ -31,7 +31,7 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
 
-class BucketingTest(test.TestCase):
+class GroupByWindowTest(test.TestCase):
 
   def testSimple(self):
     components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -257,16 +257,24 @@ class BucketTest(test.TestCase):
 
   def testEvenOddBucketsFilterOutAllOdd(self):
 
     def _map_fn(v):
-      return (v, array_ops.fill([v], v),
-              array_ops.fill([3], string_ops.as_string(v)))
+      return {"x": v,
+              "y": array_ops.fill([v], v),
+              "z": array_ops.fill([3], string_ops.as_string(v))}
+
+    def _dynamic_pad_fn(bucket, window, _):
+      return dataset_ops.Dataset.zip(
+          (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
+              32, {"x": tensor_shape.TensorShape([]),
+                   "y": tensor_shape.TensorShape([None]),
+                   "z": tensor_shape.TensorShape([3])})))
 
     input_dataset = (
         dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
-        .filter(lambda x, y, z: math_ops.equal(x % 2, 0)))
+        .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
 
     bucketed_dataset = input_dataset.group_by_window(
-        lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
-        lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
+        lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
+        lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)
 
     iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
     init_op = iterator.initializer
@@ -283,9 +291,9 @@ class BucketTest(test.TestCase):
       self.assertAllEqual(0, which_bucket0)
       self.assertAllEqual(0, which_bucket1)
       self.assertAllEqual(
-          np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0[0])
+          np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
       self.assertAllEqual(
-          np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1[0])
+          np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
 
 
 if __name__ == "__main__":
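The bucketing test above also covers the `Dataset.padded_batch()` fix named in the commit message; a minimal standalone sketch of that usage (same assumed TF 1.x contrib API; the values here are illustrative):

# Sketch: padded_batch() over dict-typed elements (TF 1.x, tf.contrib.data).
# The padded_shapes argument mirrors the element structure, so it is a dict
# whose keys match the element's keys.
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

dataset = (dataset_ops.Dataset.from_tensor_slices(math_ops.range(8))
           .map(lambda v: {"x": v, "y": array_ops.fill([v], v)})
           # "x" stays scalar; "y" is padded to the longest row in the batch.
           .padded_batch(4, {"x": tensor_shape.TensorShape([]),
                             "y": tensor_shape.TensorShape([None])}))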
@@ -83,6 +83,23 @@ class FilterDatasetTest(test.TestCase):
       self.assertEqual(1, sess.run(get_next))
       self.assertEqual(3, sess.run(get_next))
 
+  def testFilterDict(self):
+    iterator = (dataset_ops.Dataset.range(10)
+                .map(lambda x: {"foo": x * 2, "bar": x ** 2})
+                .filter(lambda d: math_ops.equal(d["bar"] % 2, 0))
+                .map(lambda d: d["foo"] + d["bar"])
+                .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for i in range(10):
+        if (i ** 2) % 2 == 0:
+          self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
 
 if __name__ == "__main__":
   test.main()
@@ -101,6 +101,23 @@ class FlatMapDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess = random.choice([sess1, sess2])
         sess.run(get_next)
 
+  def testMapDict(self):
+    iterator = (dataset_ops.Dataset.range(10)
+                .map(lambda x: {"foo": x * 2, "bar": x ** 2})
+                .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"])
+                          .repeat(d["bar"]))
+                .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for i in range(10):
+        for _ in range(i ** 2):
+          self.assertEqual(i * 2, sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+  # pylint: enable=g-long-lambda
 
 
@@ -324,5 +324,21 @@ class MapDatasetTest(test.TestCase):
     # Randomness is repeatable given same seed
     self.assertAllClose(random_values, random_values_2)
 
+  def testMapDict(self):
+    iterator = (dataset_ops.Dataset.range(10)
+                .map(lambda x: {"foo": x * 2, "bar": x ** 2})
+                .map(lambda d: d["foo"] + d["bar"])
+                .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for i in range(10):
+        self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
 
 if __name__ == "__main__":
   test.main()
@@ -1357,6 +1357,11 @@ class DenseToSparseBatchDataset(Dataset):
     return (dtypes.int64, self._input_dataset.output_types, dtypes.int64)
 
 
+def _should_unpack_args(args):
+  """Returns `True` if `args` should be `*args` when passed to a callable."""
+  return nest.is_sequence(args) and not isinstance(args, dict)
+
+
 class _ResourceDataset(Dataset):
   """A Dataset wrapper for a tf.resource-typed function argument."""
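The new helper centralizes the unpack decision for the four call sites patched below, which all share the same dispatch pattern (shown here with `_call_user_func`, a hypothetical name used only for illustration):

# Illustration of the dispatch pattern the call sites below share.
# `_call_user_func` is a hypothetical helper, not part of the diff.
def _call_user_func(func, nested_args):
  if _should_unpack_args(nested_args):
    # Tuples and other non-dict sequences become positional arguments.
    return func(*nested_args)
  # Dicts (and single tensors) are passed through as one argument.
  return func(nested_args)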
@@ -1394,7 +1399,7 @@ class GroupByWindowDataset(Dataset):
       for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
         arg.set_shape(shape)
       nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
-      if nest.is_sequence(nested_args):
+      if _should_unpack_args(nested_args):
         ret = key_func(*nested_args)
       else:
         ret = key_func(nested_args)
@@ -1483,7 +1488,7 @@ class MapDataset(Dataset):
 
       nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
 
-      if nest.is_sequence(nested_args):
+      if _should_unpack_args(nested_args):
         ret = map_func(*nested_args)
       else:
         ret = map_func(nested_args)
@@ -1559,7 +1564,7 @@ class FlatMapDataset(Dataset):
 
       nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
 
-      if nest.is_sequence(nested_args):
+      if _should_unpack_args(nested_args):
         dataset = map_func(*nested_args)
       else:
         dataset = map_func(nested_args)
@@ -1609,7 +1614,7 @@ class FilterDataset(Dataset):
 
       nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
 
-      if nest.is_sequence(nested_args):
+      if _should_unpack_args(nested_args):
         ret = predicate(*nested_args)
       else:
         ret = predicate(nested_args)
@@ -286,7 +286,8 @@ def map_structure(func, *structure, **check_types_dict):
 def _yield_flat_up_to(shallow_tree, input_tree):
   """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
   if is_sequence(shallow_tree):
-    for shallow_branch, input_branch in zip(shallow_tree, input_tree):
+    for shallow_branch, input_branch in zip(_elements_of(shallow_tree),
+                                            _elements_of(input_tree)):
       for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
         yield input_leaf
   else:
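`_elements_of` is defined elsewhere in the ported `nest` module; judging from the dict test added below, it yields a dict's values (in sorted-key order, matching `flatten()`) and returns any other sequence unchanged. A hedged reconstruction of its shape:

# Hypothetical reconstruction of _elements_of, inferred from the tests below;
# the actual definition lives outside this hunk.
def _elements_of(sequence):
  if isinstance(sequence, dict):
    # Yield values in sorted-key order so dict flattening is deterministic.
    return [sequence[k] for k in sorted(sequence)]
  else:
    return sequence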
@@ -495,6 +496,7 @@ def map_structure_up_to(shallow_tree, func, *inputs):
   # then repack based on the structure of the first input.
   all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
                          for input_tree in inputs]
+
   results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
   return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
 
@@ -287,6 +287,14 @@ class NestTest(test.TestCase):
     flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
     self.assertEqual(flattened_shallow_tree, list(shallow_tree))
 
+    # Using dict.
+    input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
+    shallow_tree = {"a": (True, True), "b": (False, True)}
+    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
+    self.assertEqual(flattened_shallow_tree, [True, True, False, True])
+
   def testMapStructureUpTo(self):
     ab_tuple = collections.namedtuple("ab_tuple", "a, b")
     op_tuple = collections.namedtuple("op_tuple", "add, mul")