Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:

* the fact that the session may be reused.
* the fact that the session is not closed even when doing a "with self.test_session()" statement.

PiperOrigin-RevId: 212338134
parent 6d3af1df20
commit a5752eb9cb
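For context, here is a minimal illustrative sketch of the pattern this change migrates to. The `SquareTest` class below is hypothetical and not part of this commit; it assumes a TensorFlow 1.x graph-mode test.

```python
# Hypothetical example (not from this commit): what a test looks like after
# the migration, assuming TensorFlow 1.x graph-mode test infrastructure.
import tensorflow as tf
from tensorflow.python.platform import test


class SquareTest(test.TestCase):

  def testSquare(self):
    x = tf.constant([1.0, 2.0, 3.0])
    y = tf.square(x)
    # cached_session() names the actual behavior: the returned session may be
    # cached and reused across calls within the test, and it is not closed
    # when the `with` block exits. The deprecated self.test_session() behaved
    # the same way but its name suggested a fresh, scoped session.
    with self.cached_session() as sess:
      self.assertAllClose([1.0, 4.0, 9.0], sess.run(y))


if __name__ == "__main__":
  test.main()
```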
@@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)

for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)

for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))

@@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))

@@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))

@@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
@@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

@@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
@@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):

next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):

next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])

-with self.test_session() as sess:
+with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
@@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
-with self.test_session() as sess:
+with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})

@@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
@@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)

@@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
@@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):

get_next = dataset.make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))

@@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)

@@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
@@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase):
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase):
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
@@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)

which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)

# Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)

# Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
@@ -556,7 +556,7 @@ class BucketBySequenceLength(test.TestCase):
element_len, boundaries, batch_sizes))
batch, = dataset.make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
@@ -600,7 +600,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
@@ -637,7 +637,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))

@@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(100):
for i in range(10):
@@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
freqs[sess.run(next_element)] += 1
@@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):

@@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)

-with self.test_session() as sess:
+with self.cached_session() as sess:
if error is None:
dense_val, sparse_val = sess.run(
element, feed_dict={
@@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(stop_t)
element = get_single_element.reduce_dataset(dataset, sum_reducer)

-with self.test_session() as sess:
+with self.cached_session() as sess:
value = sess.run(element, feed_dict={stop_t: stop})
self.assertEqual(stop * (stop - 1) / 2, value)


@@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase):
get_op = gen_dataset_ops.indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(materialize)
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))

def testIdentityIndexedDataset(self):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(materialized.initializer)
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
for i in range(16):
@@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
itr = ds.make_initializable_iterator()
n = itr.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(itr.initializer)
for i in range(16):
output = sess.run(n)

@@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):

def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)

def _testEmptyInput(self, sloppy=False):
-with self.test_session() as sess:
+with self.cached_session() as sess:
# Empty input.
self._clear_coordination_events()
sess.run(
@@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):

def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output.
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
race_indices = {2, 8, 14}  # Sequence points when sloppy mode has race conds
# Mixture of non-empty and empty interleaved datasets.
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)

def testBlockLengthWithContentionSloppy(self):
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):

def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator()

-with self.test_session() as sess:
+with self.cached_session() as sess:
output_values = []
for _ in range(30):
output_values.append(sess.run(iterator.get_next()))
@@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
@@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(get_next)

def testErrorsInOutputFn(self):
-with self.test_session() as sess:
+with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
next_element = iterator.get_next()

results = []
-with self.test_session() as sess:
+with self.cached_session() as sess:
for _ in range(2):
elements = []
sess.run(iterator.initializer)

@@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_repeats):  # Dataset is repeated.
for i in range(10):  # 10 records.

@@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:

@@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase):
expected_values=None,
expected_err=None):

-with self.test_session() as sess:
+with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):

@@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)

-with self.test_session() as sess:
+with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
@@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = back_to_cpu_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase):
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()

-with self.test_session() as sess:
+with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):

@@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase):
self.assertEqual([tensor_shape.TensorShape([])] * 3,
[t.shape for t in get_next[1]])

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
@@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase):
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))

@@ -116,7 +116,7 @@ class ReadBatchFeaturesTest(
init_op = iterator.initializer
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):

@@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
class_func=lambda c, _: c,
seed=27)).make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
returned = []
while len(returned) < 4000:
returned.append(sess.run(get_next))
@@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):

get_next = dataset.make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):

get_next = dataset.make_one_shot_iterator().get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:

@@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase):
start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:

for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase):
make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:

for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)

@@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase):
def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
get_next = ds_fn().make_one_shot_iterator().get_next()
outputs = []
-with self.test_session() as sess:
+with self.cached_session() as sess:
for _ in range(num_outputs):
outputs.append(sess.run(get_next))
if verify_exhausted:

@@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
window_stride=window_stride_t)).make_initializable_iterator())
init_op = iterator.initializer

-with self.test_session() as sess:
+with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op)
# Slide: 1st batch.
actual = sess.run(get_next)
@@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,

@@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string), 2)
-with self.test_session() as sess:
+with self.cached_session() as sess:
for _ in range(2):  # Run twice to verify statelessness of db operations.
sess.run(
init_op,
@@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetJoinQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetNullTerminator(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetReuseSqlDataset(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadEmptyResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidDriverName(self):
init_op = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))[0]
-with self.test_session() as sess:
+with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidColumnName(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfQueryWithSyntaxError(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfInsertQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# table and place it in an `int32` tensor.
def testReadResultSetInt32VarCharColumnAsInt(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# in `bool` tensors.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={

@@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase):
# TODO(rachelim): support sparse tensor outputs
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
while True:
try:
op1 = sess.run(next1)
@@ -54,7 +54,7 @@ class DatasetTestBase(test.TestCase):
replacements=None):
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
try:
sess.run(next1)
raise ValueError(

@@ -69,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(iterator.initializer)
thread_ids = []
try:

@@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

-with self.test_session() as sess:
+with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)

@@ -92,7 +92,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
@@ -128,7 +128,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
expected = sess.run(
self._structuredElement(structure, np.concatenate(
([5], shape), axis=0), dtype))
@@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredElement(None, np.concatenate(([5], shape), axis=0),
@@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
expected = sess.run(
self._structuredSparseElement(structure,
np.concatenate(([5], shape), axis=0),
@@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredSparseElement(None,
@@ -321,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping.window_dataset(len(shapes))).apply(
grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
self._structuredElement(
@@ -352,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
@@ -380,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)

@@ -458,7 +458,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shapes, dtype).apply(grouping.window_dataset(
len(shapes))).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
expected = sess.run(
self._structuredRaggedSparseElement(structure, shapes, dtype,
padded_shape))
@@ -489,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected = sess.run(
self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
@@ -516,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
-with self.test_session() as sess:
+with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)


@@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")

def testWrite(self):
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
self.writer, feed_dict={
self.filename: self._createFile(),
@@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase):

def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
@@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase):

def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
-with self.test_session() as sess:
+with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={