STT-tensorflow/tensorflow/python/training/queue_runner_test.py
2018-12-17 11:09:45 -08:00

356 lines
15 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for QueueRunner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import time
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl
_MockOp = collections.namedtuple("MockOp", ["name"])
@test_util.run_v1_only("QueueRunner removed from v2")
class QueueRunnerTest(test.TestCase):
  """Tests for `queue_runner_impl.QueueRunner` and its module-level helpers.

  Most tests build a small graph (usually a `count_up_to` op paired with a
  `FIFOQueue`), run the queue runner's threads, and then check variable
  values and the exceptions recorded in `qr.exceptions_raised`.
  """

  def testBasic(self):
    """A single enqueue op runs until CountUpTo raises OUT_OF_RANGE."""
    with self.cached_session() as sess:
      # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var = variables.VariableV1(zero64)
      count_up_to = var.count_up_to(3)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      self.evaluate(variables.global_variables_initializer())
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
      threads = qr.create_threads(sess)
      self.assertEqual(sorted(t.name for t in threads),
                       ["QueueRunnerThread-fifo_queue-CountUpTo:0"])
      for t in threads:
        t.start()
      for t in threads:
        t.join()
      # The terminating OUT_OF_RANGE must not be recorded as an exception.
      self.assertEqual(0, len(qr.exceptions_raised))
      # The variable should be 3.
      self.assertEqual(3, self.evaluate(var))

  def testTwoOps(self):
    """Each enqueue op gets its own thread and runs to completion."""
    with self.cached_session() as sess:
      # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var0 = variables.VariableV1(zero64)
      count_up_to_3 = var0.count_up_to(3)
      var1 = variables.VariableV1(zero64)
      count_up_to_30 = var1.count_up_to(30)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
      threads = qr.create_threads(sess)
      self.assertEqual(sorted(t.name for t in threads),
                       ["QueueRunnerThread-fifo_queue-CountUpTo:0",
                        "QueueRunnerThread-fifo_queue-CountUpTo_1:0"])
      # Initializing after thread creation is fine: the threads are not
      # started until below.
      self.evaluate(variables.global_variables_initializer())
      for t in threads:
        t.start()
      for t in threads:
        t.join()
      self.assertEqual(0, len(qr.exceptions_raised))
      self.assertEqual(3, self.evaluate(var0))
      self.assertEqual(30, self.evaluate(var1))

  def testExceptionsCaptured(self):
    """Errors from enqueue ops are collected in `exceptions_raised`."""
    with self.cached_session() as sess:
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"),
                                                 _MockOp("so fail")])
      threads = qr.create_threads(sess)
      self.evaluate(variables.global_variables_initializer())
      for t in threads:
        t.start()
      for t in threads:
        t.join()
      exceptions = qr.exceptions_raised
      self.assertEqual(2, len(exceptions))
      self.assertIn("Operation not in the graph", str(exceptions[0]))
      self.assertIn("Operation not in the graph", str(exceptions[1]))

  def testRealDequeueEnqueue(self):
    """Closing the upstream queue cleanly stops and closes the fed queue."""
    with self.cached_session() as sess:
      q0 = data_flow_ops.FIFOQueue(3, dtypes.float32)
      enqueue0 = q0.enqueue((10.0,))
      close0 = q0.close()
      q1 = data_flow_ops.FIFOQueue(30, dtypes.float32)
      enqueue1 = q1.enqueue((q0.dequeue(),))
      dequeue1 = q1.dequeue()
      qr = queue_runner_impl.QueueRunner(q1, [enqueue1])
      threads = qr.create_threads(sess)
      for t in threads:
        t.start()
      # Enqueue 2 values, then close queue0.
      enqueue0.run()
      enqueue0.run()
      close0.run()
      # Wait for the queue runner to terminate.
      for t in threads:
        t.join()
      # It should have terminated cleanly.
      self.assertEqual(0, len(qr.exceptions_raised))
      # The 2 values should be in queue1.
      self.assertEqual(10.0, self.evaluate(dequeue1))
      self.assertEqual(10.0, self.evaluate(dequeue1))
      # And queue1 should now be closed.
      with self.assertRaisesRegexp(errors_impl.OutOfRangeError, "is closed"):
        self.evaluate(dequeue1)

  def testRespectCoordShouldStop(self):
    """Threads started after request_stop() exit without running the op."""
    with self.cached_session() as sess:
      # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var = variables.VariableV1(zero64)
      count_up_to = var.count_up_to(3)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      self.evaluate(variables.global_variables_initializer())
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
      # Ask the coordinator to stop. The queue runner should
      # finish immediately.
      coord = coordinator.Coordinator()
      coord.request_stop()
      threads = qr.create_threads(sess, coord)
      self.assertEqual(sorted(t.name for t in threads),
                       ["QueueRunnerThread-fifo_queue-CountUpTo:0",
                        "QueueRunnerThread-fifo_queue-close_on_stop"])
      for t in threads:
        t.start()
      coord.join()
      self.assertEqual(0, len(qr.exceptions_raised))
      # The variable should be 0: the enqueue op never ran.
      self.assertEqual(0, self.evaluate(var))

  def testRequestStopOnException(self):
    """An enqueue-op failure is re-raised by the coordinator's join()."""
    with self.cached_session() as sess:
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")])
      coord = coordinator.Coordinator()
      threads = qr.create_threads(sess, coord)
      for t in threads:
        t.start()
      # The exception should be re-raised when joining.
      with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
        coord.join()

  def testGracePeriod(self):
    """request_stop() unblocks an enqueue waiting on a full queue."""
    with self.cached_session() as sess:
      # The enqueue will quickly block.
      queue = data_flow_ops.FIFOQueue(2, dtypes.float32)
      enqueue = queue.enqueue((10.0,))
      dequeue = queue.dequeue()
      qr = queue_runner_impl.QueueRunner(queue, [enqueue])
      coord = coordinator.Coordinator()
      qr.create_threads(sess, coord, start=True)
      # Dequeue one element and then request stop.
      dequeue.op.run()
      time.sleep(0.02)
      coord.request_stop()
      # We should be able to join because the RequestStop() will cause
      # the queue to be closed and the enqueue to terminate.
      coord.join(stop_grace_period_secs=1.0)

  def testMultipleSessions(self):
    """One runner creates the same number of threads for each session."""
    with self.cached_session() as sess:
      with session.Session() as other_sess:
        zero64 = constant_op.constant(0, dtype=dtypes.int64)
        var = variables.VariableV1(zero64)
        count_up_to = var.count_up_to(3)
        queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        coord = coordinator.Coordinator()
        qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
        # NOTE that this test does not actually start the threads.
        threads = qr.create_threads(sess, coord=coord)
        other_threads = qr.create_threads(other_sess, coord=coord)
        self.assertEqual(len(threads), len(other_threads))

  def testIgnoreMultiStarts(self):
    """A second create_threads() call for the same session returns []."""
    with self.cached_session() as sess:
      # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var = variables.VariableV1(zero64)
      count_up_to = var.count_up_to(3)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      self.evaluate(variables.global_variables_initializer())
      coord = coordinator.Coordinator()
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
      # NOTE that this test does not actually start the threads.
      threads = qr.create_threads(sess, coord=coord)
      # Guard against a vacuous pass: the first call must create threads.
      self.assertTrue(threads)
      new_threads = qr.create_threads(sess, coord=coord)
      self.assertEqual([], new_threads)

  def testThreads(self):
    """start=True launches threads; exceptions reflect only the latest run."""
    with self.cached_session() as sess:
      # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var = variables.VariableV1(zero64)
      count_up_to = var.count_up_to(3)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      self.evaluate(variables.global_variables_initializer())
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to,
                                                 _MockOp("bad_op")])
      threads = qr.create_threads(sess, start=True)
      self.assertEqual(sorted(t.name for t in threads),
                       ["QueueRunnerThread-fifo_queue-CountUpTo:0",
                        "QueueRunnerThread-fifo_queue-bad_op"])
      for t in threads:
        t.join()
      exceptions = qr.exceptions_raised
      self.assertEqual(1, len(exceptions))
      self.assertIn("Operation not in the graph", str(exceptions[0]))
      # A second run records its own single failure rather than
      # accumulating onto the first run's list.
      threads = qr.create_threads(sess, start=True)
      for t in threads:
        t.join()
      exceptions = qr.exceptions_raised
      self.assertEqual(1, len(exceptions))
      self.assertIn("Operation not in the graph", str(exceptions[0]))

  def testName(self):
    """The runner takes the queue's scoped name and is collected under it."""
    with ops.name_scope("scope"):
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue")
      qr = queue_runner_impl.QueueRunner(queue, [control_flow_ops.no_op()])
      self.assertEqual("scope/queue", qr.name)
      queue_runner_impl.add_queue_runner(qr)
      self.assertEqual(
          1, len(ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS, "scope")))

  def testStartQueueRunners(self):
    """start_queue_runners() starts runners added via add_queue_runner()."""
    # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    zero64 = constant_op.constant(0, dtype=dtypes.int64)
    var = variables.VariableV1(zero64)
    count_up_to = var.count_up_to(3)
    queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    init_op = variables.global_variables_initializer()
    qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    queue_runner_impl.add_queue_runner(qr)
    with self.cached_session() as sess:
      init_op.run()
      threads = queue_runner_impl.start_queue_runners(sess)
      for t in threads:
        t.join()
      self.assertEqual(0, len(qr.exceptions_raised))
      # The variable should be 3.
      self.assertEqual(3, self.evaluate(var))

  def testStartQueueRunnersRaisesIfNotASession(self):
    """start_queue_runners() rejects a non-session argument with TypeError."""
    zero64 = constant_op.constant(0, dtype=dtypes.int64)
    var = variables.VariableV1(zero64)
    count_up_to = var.count_up_to(3)
    queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    init_op = variables.global_variables_initializer()
    qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    queue_runner_impl.add_queue_runner(qr)
    with self.cached_session():
      init_op.run()
      with self.assertRaisesRegexp(TypeError, "tf.Session"):
        queue_runner_impl.start_queue_runners("NotASession")

  def testStartQueueRunnersIgnoresMonitoredSession(self):
    """start_queue_runners() returns no threads for a MonitoredSession."""
    zero64 = constant_op.constant(0, dtype=dtypes.int64)
    var = variables.VariableV1(zero64)
    count_up_to = var.count_up_to(3)
    queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    init_op = variables.global_variables_initializer()
    qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
    queue_runner_impl.add_queue_runner(qr)
    with self.cached_session():
      init_op.run()
      # Use the MonitoredSession as a context manager so it is closed when
      # the check finishes instead of being leaked past the test.
      with monitored_session.MonitoredSession() as mon_sess:
        threads = queue_runner_impl.start_queue_runners(mon_sess)
        self.assertFalse(threads)

  def testStartQueueRunnersNonDefaultGraph(self):
    """Runners registered on a non-default graph start for its session."""
    # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
    graph = ops.Graph()
    with graph.as_default():
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      var = variables.VariableV1(zero64)
      count_up_to = var.count_up_to(3)
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
      init_op = variables.global_variables_initializer()
      qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
      queue_runner_impl.add_queue_runner(qr)
    with self.session(graph=graph) as sess:
      init_op.run()
      threads = queue_runner_impl.start_queue_runners(sess)
      for t in threads:
        t.join()
      self.assertEqual(0, len(qr.exceptions_raised))
      # The variable should be 3.
      self.assertEqual(3, self.evaluate(var))

  def testQueueRunnerSerializationRoundTrip(self):
    """to_proto/from_proto round-trips ops and queue-closed error types."""
    graph = ops.Graph()
    with graph.as_default():
      queue = data_flow_ops.FIFOQueue(10, dtypes.float32, name="queue")
      enqueue_op = control_flow_ops.no_op(name="enqueue")
      close_op = control_flow_ops.no_op(name="close")
      cancel_op = control_flow_ops.no_op(name="cancel")
      qr0 = queue_runner_impl.QueueRunner(
          queue, [enqueue_op],
          close_op,
          cancel_op,
          queue_closed_exception_types=(errors_impl.OutOfRangeError,
                                        errors_impl.CancelledError))
      qr0_proto = queue_runner_impl.QueueRunner.to_proto(qr0)
      qr0_recon = queue_runner_impl.QueueRunner.from_proto(qr0_proto)
      self.assertEqual("queue", qr0_recon.queue.name)
      self.assertEqual(1, len(qr0_recon.enqueue_ops))
      self.assertEqual(enqueue_op, qr0_recon.enqueue_ops[0])
      self.assertEqual(close_op, qr0_recon.close_op)
      self.assertEqual(cancel_op, qr0_recon.cancel_op)
      self.assertEqual(
          (errors_impl.OutOfRangeError, errors_impl.CancelledError),
          qr0_recon.queue_closed_exception_types)

      # Assert we reconstruct an OutOfRangeError for QueueRunners
      # created before QueueRunnerDef had a queue_closed_exception_types field.
      del qr0_proto.queue_closed_exception_types[:]
      qr0_legacy_recon = queue_runner_impl.QueueRunner.from_proto(qr0_proto)
      self.assertEqual("queue", qr0_legacy_recon.queue.name)
      self.assertEqual(1, len(qr0_legacy_recon.enqueue_ops))
      self.assertEqual(enqueue_op, qr0_legacy_recon.enqueue_ops[0])
      self.assertEqual(close_op, qr0_legacy_recon.close_op)
      self.assertEqual(cancel_op, qr0_legacy_recon.cancel_op)
      self.assertEqual((errors_impl.OutOfRangeError,),
                       qr0_legacy_recon.queue_closed_exception_types)
if __name__ == "__main__":
  # When run as a script, hand control to TensorFlow's test entry point.
  test.main()