Add coverage to from_generator_test for different datastructures (i.e., dicts, tuples, lists).

PiperOrigin-RevId: 306269645
Change-Id: I1f1bf2f0b4dba2baf01b8a80640b0834a11bc143
This commit is contained in:
Adam Roberts 2020-04-13 11:15:06 -07:00 committed by TensorFlower Gardener
parent 35d23da927
commit c93a91e0e0

View File

@ -227,21 +227,22 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset, expected_output=[b"foo", b"bar", b"baz"])
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorDict(self):
def testFromGeneratorDatastructures(self):
# Tests multiple datastructures.
def generator():
yield {"a": "foo", "b": [1, 2]}
yield {"a": "bar", "b": [3, 4]}
yield {"a": "baz", "b": [5, 6]}
yield {"a": "foo", "b": [1, 2], "c": (9,)}
yield {"a": "bar", "b": [3], "c": (7, 6)}
yield {"a": "baz", "b": [5, 6], "c": (5, 4)}
dataset = dataset_ops.Dataset.from_generator(
generator,
output_types={"a": dtypes.string, "b": dtypes.int32},
output_shapes={"a": [], "b": [None]})
output_types={"a": dtypes.string, "b": dtypes.int32, "c": dtypes.int32},
output_shapes={"a": [], "b": [None], "c": [None]})
self.assertDatasetProduces(
dataset,
expected_output=[{"a": b"foo", "b": [1, 2]},
{"a": b"bar", "b": [3, 4]},
{"a": b"baz", "b": [5, 6]}])
expected_output=[{"a": b"foo", "b": [1, 2], "c": [9]},
{"a": b"bar", "b": [3], "c": [7, 6]},
{"a": b"baz", "b": [5, 6], "c": [5, 4]}])
@combinations.generate(test_base.default_test_combinations())
def testFromGeneratorTypeError(self):