From 0349c8ddf5b185ae102e62b5de2bddbe56773926 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 25 Mar 2020 12:33:10 -0700 Subject: [PATCH] Add test for tf.data.Dataset.from_generator with dictionary elements. PiperOrigin-RevId: 302946391 Change-Id: I150cf9493052c163db8ab90374722bbc9a5bc9f0 --- .../data/kernel_tests/from_generator_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tensorflow/python/data/kernel_tests/from_generator_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py index d320b281136..ecc022b58a5 100644 --- a/tensorflow/python/data/kernel_tests/from_generator_test.py +++ b/tensorflow/python/data/kernel_tests/from_generator_test.py @@ -226,6 +226,23 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[b"foo", b"bar", b"baz"]) + @combinations.generate(test_base.default_test_combinations()) + def testFromGeneratorDict(self): + def generator(): + yield {"a": "foo", "b": [1, 2]} + yield {"a": "bar", "b": [3, 4]} + yield {"a": "baz", "b": [5, 6]} + + dataset = dataset_ops.Dataset.from_generator( + generator, + output_types={"a": dtypes.string, "b": dtypes.int32}, + output_shapes={"a": [], "b": [None]}) + self.assertDatasetProduces( + dataset, + expected_output=[{"a": b"foo", "b": [1, 2]}, + {"a": b"bar", "b": [3, 4]}, + {"a": b"baz", "b": [5, 6]}]) + @combinations.generate(test_base.default_test_combinations()) def testFromGeneratorTypeError(self): def generator():