[tf.data] Adding support for None
as a component value.
PiperOrigin-RevId: 295257787 Change-Id: I1a008f7340953142156994e154be7d65299cb31e
This commit is contained in:
parent
a755fe8236
commit
a31daa50d0
@ -223,6 +223,12 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
]
|
||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoneComponent(self):
|
||||
dataset = dataset_ops.Dataset.range(10).map(lambda x: (x, None)).batch(
|
||||
10).map(lambda x, y: x)
|
||||
self.assertDatasetProduces(dataset, expected_output=[list(range(10))])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -515,6 +515,34 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(([], ([], []), []),
|
||||
dataset_ops.get_legacy_output_shapes(dataset))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoneComponent(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors((42, None))
|
||||
if context.executing_eagerly():
|
||||
self.assertDatasetProduces(dataset, expected_output=[(42, None)])
|
||||
else:
|
||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||
next_first, next_second = iterator.get_next()
|
||||
self.assertEqual(next_second, None)
|
||||
with self.cached_session() as sess:
|
||||
self.assertEqual(sess.run(next_first), 42)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoneComponentInFunction(self):
|
||||
|
||||
@def_function.function
|
||||
def fn(ds):
|
||||
total = 0
|
||||
it = iter(ds)
|
||||
for elem in it:
|
||||
x, _ = elem
|
||||
total += x
|
||||
return total
|
||||
|
||||
dataset = dataset_ops.Dataset.range(
|
||||
10, output_type=dtypes.int32).map(lambda x: (x, None))
|
||||
self.assertEqual(self.evaluate(fn(dataset)), 45)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -140,16 +140,23 @@ def _make_coordinated_sloppy_dataset(apply_map, num_elements,
|
||||
coordination_events[x].clear()
|
||||
return x * x
|
||||
|
||||
def map_fn(x):
|
||||
def fn(x):
|
||||
return script_ops.py_func(map_py_fn, [x], x.dtype)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_deterministic = False
|
||||
dataset = dataset_ops.Dataset.range(num_elements)
|
||||
dataset = apply_map(dataset, map_fn, num_parallel_calls).with_options(options)
|
||||
dataset = apply_map(dataset, fn, num_parallel_calls).with_options(options)
|
||||
return dataset, coordination_events
|
||||
|
||||
|
||||
class Foo(object):
|
||||
"""Dummy class used for invalid return value tests."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def _map_dataset_factory(self, components, apply_map, count):
|
||||
@ -1007,8 +1014,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r"Unsupported return value from function passed to "
|
||||
r"Dataset.map\(\): None."):
|
||||
_ = apply_map(dataset, lambda x: None)
|
||||
r"Dataset.map\(\)"):
|
||||
_ = apply_map(dataset, lambda x: Foo)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testBrokenFunctionErrorOnInitialization(self):
|
||||
@ -1361,6 +1368,18 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.checkDeterminism(
|
||||
dataset_fn, expect_determinism, expected_elements=elements)
|
||||
|
||||
@combinations.generate(_test_combinations())
|
||||
def testNoneComponent(self, apply_map):
|
||||
dataset = dataset_ops.Dataset.from_tensors((42, None))
|
||||
|
||||
def map_function(x, y):
|
||||
if y is None:
|
||||
return x / 2
|
||||
return x
|
||||
|
||||
dataset = apply_map(dataset, map_function)
|
||||
self.assertDatasetProduces(dataset, expected_output=[21])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -217,6 +217,12 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
|
||||
self.assertDatasetProduces(data, expected_output)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoneComponent(self):
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
(list(range(10)), None)).unbatch().map(lambda x, y: x)
|
||||
self.assertDatasetProduces(dataset, expected_output=range(10))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -106,6 +106,8 @@ def normalize_element(element):
|
||||
elif isinstance(
|
||||
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
||||
normalized_components.append(t)
|
||||
elif isinstance(spec, NoneTensorSpec):
|
||||
normalized_components.append(NoneTensor())
|
||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||
normalized_components.append(t)
|
||||
else:
|
||||
@ -462,3 +464,65 @@ def type_spec_from_value(element, use_fallback=True):
|
||||
|
||||
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
||||
(element, type(element).__name__))
|
||||
|
||||
|
||||
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
||||
# functionality.
|
||||
class NoneTensor(composite_tensor.CompositeTensor):
|
||||
"""Composite tensor representation for `None` value."""
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return NoneTensorSpec()
|
||||
|
||||
|
||||
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
||||
# functionality.
|
||||
class NoneTensorSpec(type_spec.BatchableTypeSpec):
|
||||
"""Type specification for `None` value."""
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return NoneTensor
|
||||
|
||||
def _serialize(self):
|
||||
return ()
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
return []
|
||||
|
||||
def _to_components(self, value):
|
||||
return []
|
||||
|
||||
def _from_components(self, components):
|
||||
return
|
||||
|
||||
def _to_tensor_list(self, value):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
return NoneTensorSpec()
|
||||
|
||||
def _batch(self, batch_size):
|
||||
return NoneTensorSpec()
|
||||
|
||||
def _unbatch(self):
|
||||
return NoneTensorSpec()
|
||||
|
||||
def _to_batched_tensor_list(self, value):
|
||||
return []
|
||||
|
||||
def _to_legacy_output_types(self):
|
||||
return self
|
||||
|
||||
def _to_legacy_output_shapes(self):
|
||||
return self
|
||||
|
||||
def _to_legacy_output_classes(self):
|
||||
return self
|
||||
|
||||
|
||||
type_spec.register_type_spec_from_value_converter(type(None),
|
||||
NoneTensorSpec.from_value)
|
||||
|
Loading…
Reference in New Issue
Block a user