Remaining core kernel tests coverage.
PiperOrigin-RevId: 224865488
This commit is contained in:
parent
4bc66cd75a
commit
51a86aae7c
@ -21,7 +21,6 @@ import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -32,43 +31,27 @@ from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
|
||||
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
|
||||
output_types=None):
|
||||
if output_types is None:
|
||||
output_types = dtypes.int64
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(generator, output_types=output_types)
|
||||
.repeat(num_repeats)
|
||||
.prefetch(5))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(2): # Run twice to test reinitialization.
|
||||
sess.run(init_op)
|
||||
for _ in range(num_repeats):
|
||||
for elem in elem_sequence:
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=output_types).repeat(num_repeats).prefetch(5)
|
||||
self.assertDatasetProduces(
|
||||
dataset,
|
||||
elem_sequence * num_repeats,
|
||||
requires_initialization=True,
|
||||
num_test_iterations=2)
|
||||
|
||||
def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
|
||||
iterator = dataset_ops.make_one_shot_iterator(
|
||||
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
|
||||
.repeat(num_repeats)
|
||||
.prefetch(5))
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
|
||||
self.assertDatasetProduces(
|
||||
dataset, elem_sequence * num_repeats, num_test_iterations=2)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(num_repeats):
|
||||
for elem in elem_sequence:
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorUsingFunction(self):
|
||||
def generator():
|
||||
for i in range(1, 100):
|
||||
@ -79,21 +62,18 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
self._testFromGeneratorOneShot(generator, elem_sequence, 1)
|
||||
self._testFromGeneratorOneShot(generator, elem_sequence, 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorUsingList(self):
|
||||
generator = lambda: [[i] * i for i in range(1, 100)]
|
||||
elem_sequence = list(generator())
|
||||
self._testFromGenerator(generator, elem_sequence, 1)
|
||||
self._testFromGenerator(generator, elem_sequence, 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorUsingNdarray(self):
|
||||
generator = lambda: np.arange(100, dtype=np.int64)
|
||||
elem_sequence = list(generator())
|
||||
self._testFromGenerator(generator, elem_sequence, 1, output_types=np.int64)
|
||||
self._testFromGenerator(generator, elem_sequence, 5, output_types=np.int64)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorUsingGeneratorExpression(self):
|
||||
# NOTE(mrry): Generator *expressions* are not repeatable (or in
|
||||
# general reusable), because they eagerly evaluate the `for`
|
||||
@ -105,7 +85,6 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
self._testFromGenerator(generator, elem_sequence, 1)
|
||||
self._testFromGenerator(generator, elem_sequence, 5)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromMultipleConcurrentGenerators(self):
|
||||
num_inner_repeats = 5
|
||||
num_outer_repeats = 100
|
||||
@ -128,22 +107,16 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
output_shapes=([None], [3]))
|
||||
.repeat(num_inner_repeats).prefetch(5))
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.range(num_outer_repeats)
|
||||
.interleave(interleave_fn, cycle_length=10,
|
||||
block_length=len(input_list)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for _ in range(num_inner_repeats * num_outer_repeats):
|
||||
for elem in input_list:
|
||||
val0, val1 = sess.run(get_next)
|
||||
self.assertAllEqual(elem[0], val0)
|
||||
self.assertAllEqual(elem[1], val1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
dataset = dataset_ops.Dataset.range(num_outer_repeats).interleave(
|
||||
interleave_fn, cycle_length=10, block_length=len(input_list))
|
||||
get_next = self.getNext(dataset)
|
||||
for _ in range(num_inner_repeats * num_outer_repeats):
|
||||
for elem in input_list:
|
||||
val0, val1 = self.evaluate(get_next())
|
||||
self.assertAllEqual(elem[0], val0)
|
||||
self.assertAllEqual(elem[1], val1)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
# TODO(b/67868766): Reenable this when the source of flakiness is discovered.
|
||||
def _testFromGeneratorsRunningInParallel(self):
|
||||
@ -186,22 +159,16 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
return dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.range(num_parallel_iterators)
|
||||
.interleave(
|
||||
interleave_fn, cycle_length=num_parallel_iterators, block_length=1))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.range(num_parallel_iterators).interleave(
|
||||
interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for elem in [0, 1]:
|
||||
for _ in range(num_parallel_iterators):
|
||||
self.assertAllEqual(elem, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
for elem in [0, 1]:
|
||||
for _ in range(num_parallel_iterators):
|
||||
self.assertAllEqual(elem, self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorImplicitConversion(self):
|
||||
def generator():
|
||||
yield [1]
|
||||
@ -209,45 +176,28 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
yield [3]
|
||||
|
||||
for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]:
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtype, output_shapes=[1]))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtype, output_shapes=[1])
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
self.assertEqual(dtype, get_next.dtype)
|
||||
for expected in [[1], [2], [3]]:
|
||||
next_val = self.evaluate(get_next())
|
||||
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for expected in [[1], [2], [3]]:
|
||||
next_val = sess.run(get_next)
|
||||
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorString(self):
|
||||
def generator():
|
||||
yield "foo"
|
||||
yield b"bar"
|
||||
yield u"baz"
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.string, output_shapes=[]))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.string, output_shapes=[])
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_output=[b"foo", b"bar", b"baz"])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for expected in [b"foo", b"bar", b"baz"]:
|
||||
next_val = sess.run(get_next)
|
||||
self.assertAllEqual(expected, next_val)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorTypeError(self):
|
||||
def generator():
|
||||
yield np.array([1, 2, 3], dtype=np.int64)
|
||||
@ -255,23 +205,19 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
yield "ERROR"
|
||||
yield np.array([7, 8, 9], dtype=np.int64)
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64, output_shapes=[3]))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64, output_shapes=[3])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||
with self.assertRaisesOpError("The expected type was int64"):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([7, 8, 9], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
|
||||
with self.assertRaisesOpError("The expected type was int64"):
|
||||
self.evaluate(get_next())
|
||||
self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorShapeError(self):
|
||||
def generator():
|
||||
yield np.array([1, 2, 3], dtype=np.int64)
|
||||
@ -279,23 +225,18 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
yield np.array([7, 8, 9, 10], dtype=np.int64)
|
||||
yield np.array([11, 12, 13], dtype=np.int64)
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64, output_shapes=[3]))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64, output_shapes=[3])
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([11, 12, 13], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
|
||||
self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
|
||||
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
|
||||
self.evaluate(get_next())
|
||||
self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorStructureError(self):
|
||||
def generator():
|
||||
yield 1, 2
|
||||
@ -304,46 +245,31 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
yield 6, 7, 8
|
||||
yield 9, 10
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=(dtypes.int64, dtypes.int64)))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=(dtypes.int64, dtypes.int64))
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertEqual((1, 2), sess.run(get_next))
|
||||
self.assertEqual((3, 4), sess.run(get_next))
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
sess.run(get_next)
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
sess.run(get_next)
|
||||
self.assertEqual((9, 10), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual((1, 2), self.evaluate(get_next()))
|
||||
self.assertEqual((3, 4), self.evaluate(get_next()))
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
self.evaluate(get_next())
|
||||
with self.assertRaisesOpError(
|
||||
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||
self.evaluate(get_next())
|
||||
self.assertEqual((9, 10), self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorHeterogeneous(self):
|
||||
def generator():
|
||||
yield 1
|
||||
yield [2, 3]
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64)
|
||||
self.assertDatasetProduces(dataset, expected_output=[1, [2, 3]])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(1, sess.run(get_next))
|
||||
self.assertAllEqual([2, 3], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorStopShort(self):
|
||||
|
||||
def generator():
|
||||
@ -351,18 +277,12 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
yield 1
|
||||
yield 2
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int64)
|
||||
get_next = self.getNext(dataset)
|
||||
self.assertAllEqual(0, self.evaluate(get_next()))
|
||||
self.assertAllEqual(1, self.evaluate(get_next()))
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(0, sess.run(get_next))
|
||||
self.assertAllEqual(1, sess.run(get_next))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorDestructorCalled(self):
|
||||
# Use an `Event` to signal that the generator has been deleted.
|
||||
event = threading.Event()
|
||||
@ -381,23 +301,18 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
def __del__(self):
|
||||
event.set()
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.from_generator(
|
||||
GeneratorWrapper, output_types=dtypes.int64).take(2))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
GeneratorWrapper, output_types=dtypes.int64).take(2)
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(42, sess.run(get_next))
|
||||
self.assertAllEqual(42, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
# Test that `GeneratorWrapper` object is destroyed when the
|
||||
# iterator terminates (and the generator iterator is deleted).
|
||||
self.assertTrue(event.is_set())
|
||||
self.assertAllEqual(42, self.evaluate(get_next()))
|
||||
self.assertAllEqual(42, self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
# Test that `GeneratorWrapper` object is destroyed when the
|
||||
# iterator terminates (and the generator iterator is deleted).
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorWithArgs(self):
|
||||
|
||||
def flat_map_fn(elem):
|
||||
@ -410,20 +325,10 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
generator_with_arg, output_types=dtypes.int64, output_shapes=(),
|
||||
args=(elem,))
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.range(5).flat_map(flat_map_fn))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_output=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
|
||||
for x in expected:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testFromGeneratorWithTwoArgs(self):
|
||||
|
||||
def flat_map_fn(elem, message):
|
||||
@ -436,26 +341,17 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
generator_with_arg, output_types=(dtypes.int64, dtypes.string),
|
||||
output_shapes=((), ()), args=(elem, message))
|
||||
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops.Dataset.zip(
|
||||
(dataset_ops.Dataset.range(5),
|
||||
dataset_ops.Dataset.from_tensors("Hi!").repeat(None)))
|
||||
.flat_map(flat_map_fn))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops.Dataset.zip(
|
||||
(dataset_ops.Dataset.range(5),
|
||||
dataset_ops.Dataset.from_tensors("Hi!").repeat(None)
|
||||
)).flat_map(flat_map_fn)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
expected = [(0, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
|
||||
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
|
||||
for x in expected:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertDatasetProduces(
|
||||
dataset,
|
||||
expected_output=[(0, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"), (0, b"Hi!"),
|
||||
(1, b"Hi!"), (2, b"Hi!"), (0, b"Hi!"), (1, b"Hi!"),
|
||||
(2, b"Hi!"), (3, b"Hi!")])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGeneratorDatasetFinalizeFunctionCalled(self):
|
||||
# NOTE(mrry): This test tests the internal `_GeneratorDataset`,
|
||||
# which affords more control over what the finalize function can do than
|
||||
@ -472,19 +368,15 @@ class FromGeneratorTest(test_base.DatasetTestBase):
|
||||
stateful=True)
|
||||
|
||||
dummy = constant_op.constant(37)
|
||||
iterator = dataset_ops.make_initializable_iterator(
|
||||
dataset_ops._GeneratorDataset(
|
||||
dummy, lambda x: x, lambda x: x, finalize_fn).take(2))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x, lambda x: x,
|
||||
finalize_fn).take(2)
|
||||
get_next = self.getNext(dataset)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual(37, sess.run(get_next))
|
||||
self.assertAllEqual(37, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertTrue(event.is_set())
|
||||
self.assertAllEqual(37, self.evaluate(get_next()))
|
||||
self.assertAllEqual(37, self.evaluate(get_next()))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
self.assertTrue(event.is_set())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user