Add test for tf.data.Dataset.from_generator with dictionary elements.
PiperOrigin-RevId: 302946391 Change-Id: I150cf9493052c163db8ab90374722bbc9a5bc9f0
This commit is contained in:
parent
4f39560183
commit
0349c8ddf5
@ -226,6 +226,23 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_output=[b"foo", b"bar", b"baz"])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromGeneratorDict(self):
|
||||
def generator():
|
||||
yield {"a": "foo", "b": [1, 2]}
|
||||
yield {"a": "bar", "b": [3, 4]}
|
||||
yield {"a": "baz", "b": [5, 6]}
|
||||
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator,
|
||||
output_types={"a": dtypes.string, "b": dtypes.int32},
|
||||
output_shapes={"a": [], "b": [None]})
|
||||
self.assertDatasetProduces(
|
||||
dataset,
|
||||
expected_output=[{"a": b"foo", "b": [1, 2]},
|
||||
{"a": b"bar", "b": [3, 4]},
|
||||
{"a": b"baz", "b": [5, 6]}])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromGeneratorTypeError(self):
|
||||
def generator():
|
||||
|
Loading…
Reference in New Issue
Block a user