Remaining core kernel tests coverage.

PiperOrigin-RevId: 224865488
Authored by Shivani Agrawal on 2018-12-10 12:53:22 -08:00; committed by TensorFlower Gardener
parent 4bc66cd75a
commit 51a86aae7c
2 changed files with 458 additions and 673 deletions
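
The recurring change in the diff below swaps per-test iterator boilerplate (make_initializable_iterator, an explicit init_op, and sess.run loops inside cached_session) for the DatasetTestBase helpers assertDatasetProduces and getNext, which work in both graph and eager execution. A minimal sketch of the before/after pattern, assuming a DatasetTestBase test method with dataset and elem_sequence already in scope (simplified for illustration, not the literal test code):

# Before: graph-only boilerplate around an initializable iterator.
iterator = dataset_ops.make_initializable_iterator(dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
  sess.run(init_op)
  for expected in elem_sequence:
    self.assertAllEqual(expected, sess.run(get_next))
  with self.assertRaises(errors.OutOfRangeError):
    sess.run(get_next)

# After: mode-agnostic helpers from test_base.DatasetTestBase.
self.assertDatasetProduces(dataset, expected_output=elem_sequence)

# Or, when individual elements need inspection:
get_next = self.getNext(dataset)
for expected in elem_sequence:
  self.assertAllEqual(expected, self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
  self.evaluate(get_next())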


@@ -21,7 +21,6 @@ import threading
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -32,43 +31,27 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class FromGeneratorTest(test_base.DatasetTestBase):
@test_util.run_all_in_graph_and_eager_modes
class DatasetConstructorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
if output_types is None:
output_types = dtypes.int64
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(generator, output_types=output_types)
.repeat(num_repeats)
.prefetch(5))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
for _ in range(2): # Run twice to test reinitialization.
sess.run(init_op)
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=output_types).repeat(num_repeats).prefetch(5)
self.assertDatasetProduces(
dataset,
elem_sequence * num_repeats,
requires_initialization=True,
num_test_iterations=2)
def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
iterator = dataset_ops.make_one_shot_iterator(
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
.repeat(num_repeats)
.prefetch(5))
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
self.assertDatasetProduces(
dataset, elem_sequence * num_repeats, num_test_iterations=2)
with self.cached_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@test_util.run_deprecated_v1
def testFromGeneratorUsingFunction(self):
def generator():
for i in range(1, 100):
@@ -79,21 +62,18 @@ class FromGeneratorTest(test_base.DatasetTestBase):
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
self._testFromGeneratorOneShot(generator, elem_sequence, 5)
@test_util.run_deprecated_v1
def testFromGeneratorUsingList(self):
generator = lambda: [[i] * i for i in range(1, 100)]
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
@test_util.run_deprecated_v1
def testFromGeneratorUsingNdarray(self):
generator = lambda: np.arange(100, dtype=np.int64)
elem_sequence = list(generator())
self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
@test_util.run_deprecated_v1
def testFromGeneratorUsingGeneratorExpression(self):
# NOTE(mrry): Generator *expressions* are not repeatable (or in
# general reusable), because they eagerly evaluate the `for`
@@ -105,7 +85,6 @@ class FromGeneratorTest(test_base.DatasetTestBase):
self._testFromGenerator(generator, elem_sequence, 1)
self._testFromGenerator(generator, elem_sequence, 5)
@test_util.run_deprecated_v1
def testFromMultipleConcurrentGenerators(self):
num_inner_repeats = 5
num_outer_repeats = 100
@@ -128,22 +107,16 @@ class FromGeneratorTest(test_base.DatasetTestBase):
output_shapes=([None], [3]))
.repeat(num_inner_repeats).prefetch(5))
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.range(num_outer_repeats)
.interleave(interleave_fn, cycle_length=10,
block_length=len(input_list)))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
val0, val1 = sess.run(get_next)
self.assertAllEqual(elem[0], val0)
self.assertAllEqual(elem[1], val1)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
dataset = dataset_ops.Dataset.range(num_outer_repeats).interleave(
interleave_fn, cycle_length=10, block_length=len(input_list))
get_next = self.getNext(dataset)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
val0, val1 = self.evaluate(get_next())
self.assertAllEqual(elem[0], val0)
self.assertAllEqual(elem[1], val1)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
# TODO(b/67868766): Reenable this when the source of flakiness is discovered.
def _testFromGeneratorsRunningInParallel(self):
@@ -186,22 +159,16 @@ class FromGeneratorTest(test_base.DatasetTestBase):
return dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.range(num_parallel_iterators)
.interleave(
interleave_fn, cycle_length=num_parallel_iterators, block_length=1))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.range(num_parallel_iterators).interleave(
interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
get_next = self.getNext(dataset)
with self.cached_session() as sess:
sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
self.assertAllEqual(elem, self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@test_util.run_deprecated_v1
def testFromGeneratorImplicitConversion(self):
def generator():
yield [1]
@@ -209,45 +176,28 @@ class FromGeneratorTest(test_base.DatasetTestBase):
yield [3]
for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]:
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=dtype, output_shapes=[1]))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtype, output_shapes=[1])
get_next = self.getNext(dataset)
self.assertEqual(dtype, get_next.dtype)
for expected in [[1], [2], [3]]:
next_val = self.evaluate(get_next())
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.cached_session() as sess:
sess.run(init_op)
for expected in [[1], [2], [3]]:
next_val = sess.run(get_next)
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@test_util.run_deprecated_v1
def testFromGeneratorString(self):
def generator():
yield "foo"
yield b"bar"
yield u"baz"
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.string, output_shapes=[]))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.string, output_shapes=[])
self.assertDatasetProduces(
dataset, expected_output=[b"foo", b"bar", b"baz"])
with self.cached_session() as sess:
sess.run(init_op)
for expected in [b"foo", b"bar", b"baz"]:
next_val = sess.run(get_next)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@test_util.run_deprecated_v1
def testFromGeneratorTypeError(self):
def generator():
yield np.array([1, 2, 3], dtype=np.int64)
@@ -255,23 +205,19 @@ class FromGeneratorTest(test_base.DatasetTestBase):
yield "ERROR"
yield np.array([7, 8, 9], dtype=np.int64)
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3]))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3])
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError("The expected type was int64"):
sess.run(get_next)
self.assertAllEqual([7, 8, 9], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
get_next = self.getNext(dataset)
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
with self.assertRaisesOpError("The expected type was int64"):
self.evaluate(get_next())
self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@test_util.run_deprecated_v1
def testFromGeneratorShapeError(self):
def generator():
yield np.array([1, 2, 3], dtype=np.int64)
@@ -279,23 +225,18 @@ class FromGeneratorTest(test_base.DatasetTestBase):
yield np.array([7, 8, 9, 10], dtype=np.int64)
yield np.array([11, 12, 13], dtype=np.int64)
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3]))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64, output_shapes=[3])
get_next = self.getNext(dataset)
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
sess.run(get_next)
self.assertAllEqual([11, 12, 13], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
self.evaluate(get_next())
self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@test_util.run_deprecated_v1
def testFromGeneratorStructureError(self):
def generator():
yield 1, 2
@@ -304,46 +245,31 @@ class FromGeneratorTest(test_base.DatasetTestBase):
yield 6, 7, 8
yield 9, 10
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=(dtypes.int64, dtypes.int64)))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=(dtypes.int64, dtypes.int64))
get_next = self.getNext(dataset)
with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((1, 2), sess.run(get_next))
self.assertEqual((3, 4), sess.run(get_next))
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
sess.run(get_next)
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
sess.run(get_next)
self.assertEqual((9, 10), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual((1, 2), self.evaluate(get_next()))
self.assertEqual((3, 4), self.evaluate(get_next()))
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
self.evaluate(get_next())
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
self.evaluate(get_next())
self.assertEqual((9, 10), self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@test_util.run_deprecated_v1
def testFromGeneratorHeterogeneous(self):
def generator():
yield 1
yield [2, 3]
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64)
self.assertDatasetProduces(dataset, expected_output=[1, [2, 3]])
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(1, sess.run(get_next))
self.assertAllEqual([2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@test_util.run_deprecated_v1
def testFromGeneratorStopShort(self):
def generator():
@@ -351,18 +277,12 @@ class FromGeneratorTest(test_base.DatasetTestBase):
yield 1
yield 2
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int64)
get_next = self.getNext(dataset)
self.assertAllEqual(0, self.evaluate(get_next()))
self.assertAllEqual(1, self.evaluate(get_next()))
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(0, sess.run(get_next))
self.assertAllEqual(1, sess.run(get_next))
@test_util.run_deprecated_v1
def testFromGeneratorDestructorCalled(self):
# Use an `Event` to signal that the generator has been deleted.
event = threading.Event()
@@ -381,23 +301,18 @@ class FromGeneratorTest(test_base.DatasetTestBase):
def __del__(self):
event.set()
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_generator(
GeneratorWrapper, output_types=dtypes.int64).take(2))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.from_generator(
GeneratorWrapper, output_types=dtypes.int64).take(2)
get_next = self.getNext(dataset)
with session.Session() as sess:
sess.run(init_op)
self.assertAllEqual(42, sess.run(get_next))
self.assertAllEqual(42, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `GeneratorWrapper` object is destroyed when the
# iterator terminates (and the generator iterator is deleted).
self.assertTrue(event.is_set())
self.assertAllEqual(42, self.evaluate(get_next()))
self.assertAllEqual(42, self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
# Test that `GeneratorWrapper` object is destroyed when the
# iterator terminates (and the generator iterator is deleted).
self.assertTrue(event.is_set())
@test_util.run_deprecated_v1
def testFromGeneratorWithArgs(self):
def flat_map_fn(elem):
@@ -410,20 +325,10 @@ class FromGeneratorTest(test_base.DatasetTestBase):
generator_with_arg, output_types=dtypes.int64, output_shapes=(),
args=(elem,))
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.range(5).flat_map(flat_map_fn))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
self.assertDatasetProduces(
dataset, expected_output=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4])
with self.cached_session() as sess:
sess.run(init_op)
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
for x in expected:
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@test_util.run_deprecated_v1
def testFromGeneratorWithTwoArgs(self):
def flat_map_fn(elem, message):
@@ -436,26 +341,17 @@ class FromGeneratorTest(test_base.DatasetTestBase):
generator_with_arg, output_types=(dtypes.int64, dtypes.string),
output_shapes=((), ()), args=(elem, message))
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5),
dataset_ops.Dataset.from_tensors("Hi!").repeat(None)))
.flat_map(flat_map_fn))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5),
dataset_ops.Dataset.from_tensors("Hi!").repeat(None)
)).flat_map(flat_map_fn)
with self.cached_session() as sess:
sess.run(init_op)
expected = [(0, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
for x in expected:
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertDatasetProduces(
dataset,
expected_output=[(0, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"), (0, b"Hi!"),
(1, b"Hi!"), (2, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"),
(2, b"Hi!"), (3, b"Hi!")])
@test_util.run_deprecated_v1
def testGeneratorDatasetFinalizeFunctionCalled(self):
# NOTE(mrry): This test tests the internal `_GeneratorDataset`,
# which affords more control over what the finalize function can do than
@@ -472,19 +368,15 @@ class FromGeneratorTest(test_base.DatasetTestBase):
stateful=True)
dummy = constant_op.constant(37)
iterator = dataset_ops.make_initializable_iterator(
dataset_ops._GeneratorDataset(
dummy, lambda x: x, lambda x: x, finalize_fn).take(2))
init_op = iterator.initializer
get_next = iterator.get_next()
dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
finalize_fn).take(2)
get_next = self.getNext(dataset)
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(37, sess.run(get_next))
self.assertAllEqual(37, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertTrue(event.is_set())
self.assertAllEqual(37, self.evaluate(get_next()))
self.assertAllEqual(37, self.evaluate(get_next()))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
self.assertTrue(event.is_set())
if __name__ == "__main__":

File diff for the second changed file suppressed because it is too large.
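
Most of the migrated tests exercise the Dataset.from_generator API shown above. For orientation, a small standalone usage sketch, assuming a TensorFlow build of roughly this vintage with eager execution enabled (the generator and shapes here are illustrative, not taken from the tests):

import numpy as np
import tensorflow as tf

def gen():
  # Yield rows of increasing length; only the element dtype is fixed up front.
  for i in range(1, 4):
    yield np.arange(i, dtype=np.int64)

# Variable-length rows need output_shapes=[None]; repeat/prefetch mirror the
# pipelines built in the tests above.
dataset = tf.data.Dataset.from_generator(
    gen, output_types=tf.int64, output_shapes=[None]).repeat(2).prefetch(5)

for element in dataset:  # eager iteration over the produced elements
  print(element.numpy())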