[tf.data] Adding support for None
as a component value.
PiperOrigin-RevId: 295257787 Change-Id: I1a008f7340953142156994e154be7d65299cb31e
This commit is contained in:
parent
a755fe8236
commit
a31daa50d0
@ -223,6 +223,12 @@ class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
self.assertDatasetProduces(dataset, expected_output=expected_output)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testNoneComponent(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(10).map(lambda x: (x, None)).batch(
|
||||||
|
10).map(lambda x, y: x)
|
||||||
|
self.assertDatasetProduces(dataset, expected_output=[list(range(10))])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -515,6 +515,34 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(([], ([], []), []),
|
self.assertEqual(([], ([], []), []),
|
||||||
dataset_ops.get_legacy_output_shapes(dataset))
|
dataset_ops.get_legacy_output_shapes(dataset))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testNoneComponent(self):
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors((42, None))
|
||||||
|
if context.executing_eagerly():
|
||||||
|
self.assertDatasetProduces(dataset, expected_output=[(42, None)])
|
||||||
|
else:
|
||||||
|
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||||
|
next_first, next_second = iterator.get_next()
|
||||||
|
self.assertEqual(next_second, None)
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
self.assertEqual(sess.run(next_first), 42)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testNoneComponentInFunction(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def fn(ds):
|
||||||
|
total = 0
|
||||||
|
it = iter(ds)
|
||||||
|
for elem in it:
|
||||||
|
x, _ = elem
|
||||||
|
total += x
|
||||||
|
return total
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(
|
||||||
|
10, output_type=dtypes.int32).map(lambda x: (x, None))
|
||||||
|
self.assertEqual(self.evaluate(fn(dataset)), 45)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -140,16 +140,23 @@ def _make_coordinated_sloppy_dataset(apply_map, num_elements,
|
|||||||
coordination_events[x].clear()
|
coordination_events[x].clear()
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
def map_fn(x):
|
def fn(x):
|
||||||
return script_ops.py_func(map_py_fn, [x], x.dtype)
|
return script_ops.py_func(map_py_fn, [x], x.dtype)
|
||||||
|
|
||||||
options = dataset_ops.Options()
|
options = dataset_ops.Options()
|
||||||
options.experimental_deterministic = False
|
options.experimental_deterministic = False
|
||||||
dataset = dataset_ops.Dataset.range(num_elements)
|
dataset = dataset_ops.Dataset.range(num_elements)
|
||||||
dataset = apply_map(dataset, map_fn, num_parallel_calls).with_options(options)
|
dataset = apply_map(dataset, fn, num_parallel_calls).with_options(options)
|
||||||
return dataset, coordination_events
|
return dataset, coordination_events
|
||||||
|
|
||||||
|
|
||||||
|
class Foo(object):
|
||||||
|
"""Dummy class used for invalid return value tests."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
def _map_dataset_factory(self, components, apply_map, count):
|
def _map_dataset_factory(self, components, apply_map, count):
|
||||||
@ -1007,8 +1014,8 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
|
dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
TypeError, r"Unsupported return value from function passed to "
|
TypeError, r"Unsupported return value from function passed to "
|
||||||
r"Dataset.map\(\): None."):
|
r"Dataset.map\(\)"):
|
||||||
_ = apply_map(dataset, lambda x: None)
|
_ = apply_map(dataset, lambda x: Foo)
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testBrokenFunctionErrorOnInitialization(self):
|
def testBrokenFunctionErrorOnInitialization(self):
|
||||||
@ -1361,6 +1368,18 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.checkDeterminism(
|
self.checkDeterminism(
|
||||||
dataset_fn, expect_determinism, expected_elements=elements)
|
dataset_fn, expect_determinism, expected_elements=elements)
|
||||||
|
|
||||||
|
@combinations.generate(_test_combinations())
|
||||||
|
def testNoneComponent(self, apply_map):
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors((42, None))
|
||||||
|
|
||||||
|
def map_function(x, y):
|
||||||
|
if y is None:
|
||||||
|
return x / 2
|
||||||
|
return x
|
||||||
|
|
||||||
|
dataset = apply_map(dataset, map_function)
|
||||||
|
self.assertDatasetProduces(dataset, expected_output=[21])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -217,6 +217,12 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
|
self.assertEqual(expected_types, dataset_ops.get_legacy_output_types(data))
|
||||||
self.assertDatasetProduces(data, expected_output)
|
self.assertDatasetProduces(data, expected_output)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testNoneComponent(self):
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
(list(range(10)), None)).unbatch().map(lambda x, y: x)
|
||||||
|
self.assertDatasetProduces(dataset, expected_output=range(10))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -106,6 +106,8 @@ def normalize_element(element):
|
|||||||
elif isinstance(
|
elif isinstance(
|
||||||
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
||||||
normalized_components.append(t)
|
normalized_components.append(t)
|
||||||
|
elif isinstance(spec, NoneTensorSpec):
|
||||||
|
normalized_components.append(NoneTensor())
|
||||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||||
normalized_components.append(t)
|
normalized_components.append(t)
|
||||||
else:
|
else:
|
||||||
@ -462,3 +464,65 @@ def type_spec_from_value(element, use_fallback=True):
|
|||||||
|
|
||||||
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
||||||
(element, type(element).__name__))
|
(element, type(element).__name__))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
||||||
|
# functionality.
|
||||||
|
class NoneTensor(composite_tensor.CompositeTensor):
|
||||||
|
"""Composite tensor representation for `None` value."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type_spec(self):
|
||||||
|
return NoneTensorSpec()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/149584798): Move this to framework and add tests for non-tf.data
|
||||||
|
# functionality.
|
||||||
|
class NoneTensorSpec(type_spec.BatchableTypeSpec):
|
||||||
|
"""Type specification for `None` value."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value_type(self):
|
||||||
|
return NoneTensor
|
||||||
|
|
||||||
|
def _serialize(self):
|
||||||
|
return ()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _component_specs(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _to_components(self, value):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _from_components(self, components):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _to_tensor_list(self, value):
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_value(value):
|
||||||
|
return NoneTensorSpec()
|
||||||
|
|
||||||
|
def _batch(self, batch_size):
|
||||||
|
return NoneTensorSpec()
|
||||||
|
|
||||||
|
def _unbatch(self):
|
||||||
|
return NoneTensorSpec()
|
||||||
|
|
||||||
|
def _to_batched_tensor_list(self, value):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _to_legacy_output_types(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _to_legacy_output_shapes(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _to_legacy_output_classes(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
type_spec.register_type_spec_from_value_converter(type(None),
|
||||||
|
NoneTensorSpec.from_value)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user