Add test for tf.data.Dataset.from_generator with dictionary elements.

PiperOrigin-RevId: 302946391
Change-Id: I150cf9493052c163db8ab90374722bbc9a5bc9f0
This commit is contained in:
Adam Roberts 2020-03-25 12:33:10 -07:00 committed by TensorFlower Gardener
parent 4f39560183
commit 0349c8ddf5

View File

@ -226,6 +226,23 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[b"foo", b"bar", b"baz"])
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorDict(self):
def generator():
yield {"a": "foo", "b": [1, 2]}
yield {"a": "bar", "b": [3, 4]}
yield {"a": "baz", "b": [5, 6]}
dataset = dataset_ops.Dataset.from_generator(
generator,
output_types={"a": dtypes.string, "b": dtypes.int32},
output_shapes={"a": [], "b": [None]})
self.assertDatasetProduces(
dataset,
expected_output=[{"a": b"foo", "b": [1, 2]},
{"a": b"bar", "b": [3, 4]},
{"a": b"baz", "b": [5, 6]}])
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorTypeError(self):
def generator():