[tf.data] Adding support for None as a component value.

PiperOrigin-RevId: 295257787
Change-Id: I1a008f7340953142156994e154be7d65299cb31e
This commit is contained in:
Jiri Simsa 2020-02-14 17:02:20 -08:00 committed by TensorFlower Gardener
parent a755fe8236
commit a31daa50d0
5 changed files with 127 additions and 4 deletions

View File

@@ -223,6 +223,12 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
]
self.assertDatasetProduces(dataset, expected_output=expected_output)
@combinations.generate(test_base.default_test_combinations())
def testNoneComponent(self):
  """Batching elements that carry a `None` component."""
  # Pair each value with `None`, batch, then project the `None` back out;
  # only the batched integer component should remain.
  base = dataset_ops.Dataset.range(10)
  with_none = base.map(lambda value: (value, None))
  batched = with_none.batch(10)
  dataset = batched.map(lambda value, none_part: value)
  self.assertDatasetProduces(dataset, expected_output=[list(range(10))])
# Standard test entry point: run every test case defined in this module.
if __name__ == '__main__':
  test.main()

View File

@@ -515,6 +515,34 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(([], ([], []), []),
dataset_ops.get_legacy_output_shapes(dataset))
@combinations.generate(test_base.default_test_combinations())
def testNoneComponent(self):
  """A `None` component round-trips through `from_tensors`."""
  dataset = dataset_ops.Dataset.from_tensors((42, None))
  if not context.executing_eagerly():
    # Graph mode: pull the element apart via a one-shot iterator and check
    # that the second component materializes as plain Python `None`.
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_first, next_second = iterator.get_next()
    self.assertEqual(next_second, None)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(next_first), 42)
  else:
    self.assertDatasetProduces(dataset, expected_output=[(42, None)])
@combinations.generate(test_base.default_test_combinations())
def testNoneComponentInFunction(self):
  """Elements with a `None` component iterate cleanly inside tf.function."""

  @def_function.function
  def sum_first_components(ds):
    # Accumulate the integer component, discarding the `None` part.
    acc = 0
    it = iter(ds)
    for elem in it:
      first, _ = elem
      acc += first
    return acc

  dataset = dataset_ops.Dataset.range(
      10, output_type=dtypes.int32).map(lambda x: (x, None))
  # sum(range(10)) == 45.
  self.assertEqual(self.evaluate(sum_first_components(dataset)), 45)
# Standard test entry point: run every test case defined in this module.
if __name__ == "__main__":
  test.main()

View File

@@ -140,16 +140,23 @@ def _make_coordinated_sloppy_dataset(apply_map, num_elements,
coordination_events[x].clear()
return x * x
def map_fn(x):
def fn(x):
return script_ops.py_func(map_py_fn, [x], x.dtype)
options = dataset_ops.Options()
options.experimental_deterministic = False
dataset = dataset_ops.Dataset.range(num_elements)
dataset = apply_map(dataset, map_fn, num_parallel_calls).with_options(options)
dataset = apply_map(dataset, fn, num_parallel_calls).with_options(options)
return dataset, coordination_events
class Foo(object):
  """Placeholder type that tf.data cannot convert to a tensor.

  Instances (and the class itself) serve as deliberately unsupported map
  return values in the error-reporting tests below.
  """

  def __init__(self):
    pass
class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
def _map_dataset_factory(self, components, apply_map, count):
@@ -1007,8 +1014,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
with self.assertRaisesRegexp(
TypeError, r"Unsupported return value from function passed to "
r"Dataset.map\(\): None."):
_ = apply_map(dataset, lambda x: None)
r"Dataset.map\(\)"):
_ = apply_map(dataset, lambda x: Foo)
@combinations.generate(test_base.default_test_combinations())
def testBrokenFunctionErrorOnInitialization(self):
@@ -1361,6 +1368,18 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
self.checkDeterminism(
dataset_fn, expect_determinism, expected_elements=elements)
@combinations.generate(_test_combinations())
def testNoneComponent(self, apply_map):
  """Mapping over elements whose second component is `None`."""
  dataset = dataset_ops.Dataset.from_tensors((42, None))

  def halve_when_second_is_none(x, y):
    # The `None` component arrives in the map function as Python `None`.
    return x / 2 if y is None else x

  dataset = apply_map(dataset, halve_when_second_is_none)
  self.assertDatasetProduces(dataset, expected_output=[21])
# Standard test entry point: run every test case defined in this module.
if __name__ == "__main__":
  test.main()

View File

@@ -217,6 +217,12 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
self.assertDatasetProduces(data, expected_output)
@combinations.generate(test_base.default_test_combinations())
def testNoneComponent(self):
  """Unbatching elements that carry a `None` component."""
  source = dataset_ops.Dataset.from_tensors((list(range(10)), None))
  # After unbatching, drop the `None` part so only integers are produced.
  dataset = source.unbatch().map(lambda value, none_part: value)
  self.assertDatasetProduces(dataset, expected_output=range(10))
# Standard test entry point: run every test case defined in this module.
if __name__ == "__main__":
  test.main()

View File

@@ -106,6 +106,8 @@ def normalize_element(element):
elif isinstance(
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
normalized_components.append(t)
elif isinstance(spec, NoneTensorSpec):
normalized_components.append(NoneTensor())
elif isinstance(t, composite_tensor.CompositeTensor):
normalized_components.append(t)
else:
@@ -462,3 +464,65 @@ def type_spec_from_value(element, use_fallback=True):
raise TypeError("Could not build a TypeSpec for %r with type %s" %
(element, type(element).__name__))
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor representation for `None` value.

  Carries no data at all; it exists only so that `None` can flow through
  tf.data pipelines as an element component.
  """

  @property
  def _type_spec(self):
    # All `NoneTensor`s share the same stateless spec.
    return NoneTensorSpec()
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `None` value.

  `None` has no state, so serialization, components, and tensor lists are
  all empty, and batching/unbatching are identity operations on the spec.
  """

  @property
  def value_type(self):
    return NoneTensor

  @staticmethod
  def from_value(value):
    # Every `None` maps to the same stateless spec.
    return NoneTensorSpec()

  def _serialize(self):
    # Nothing to serialize.
    return ()

  @property
  def _component_specs(self):
    return []

  def _to_components(self, value):
    return []

  def _from_components(self, components):
    # Rebuilding from zero components yields Python `None`.
    return None

  def _to_tensor_list(self, value):
    return []

  def _to_batched_tensor_list(self, value):
    return []

  def _batch(self, batch_size):
    # A batch of `None`s is still just `None`.
    return NoneTensorSpec()

  def _unbatch(self):
    return NoneTensorSpec()

  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self
# Teach the type-spec machinery about Python `None`: converting a `None`
# value to a spec yields `NoneTensorSpec`, enabling `None` dataset components.
type_spec.register_type_spec_from_value_converter(
    type(None), NoneTensorSpec.from_value)