From 6bdbc4689e58f8f13e3b2db59afea801c609c191 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Mar 2019 21:49:22 -0700 Subject: [PATCH] Raise an error if several scopes are not properly nested, e.g. due to using __enter__ and __exit__ on the context manager directly instead of using "with" blocks. In preparation for creating a tf.distribute.set_strategy() API which will manually call strategy.scope().__enter__ and .__exit__. PiperOrigin-RevId: 239124241 --- .../python/kernel_tests/sdca_ops_test.py | 3 + .../python/distribute/distribute_lib.py | 30 +++++++-- .../python/distribute/distribute_lib_test.py | 48 ++++++++++++++ tensorflow/python/eager/context.py | 3 + tensorflow/python/framework/ops.py | 48 +++++++++++--- tensorflow/python/framework/ops_test.py | 16 +++++ .../conditional_accumulator_test.py | 9 +++ .../dense_update_ops_no_tsan_test.py | 13 ++++ .../python/kernel_tests/fifo_queue_test.py | 5 ++ .../kernel_tests/padding_fifo_queue_test.py | 63 +++++++++++++++++++ .../kernel_tests/priority_queue_test.py | 13 ++++ .../random/random_shuffle_queue_test.py | 7 +++ .../sparse_conditional_accumulator_test.py | 15 +++++ .../kernel_tests/variable_scope_test.py | 31 +++++++++ tensorflow/python/ops/variable_scope.py | 14 +++-- 15 files changed, 300 insertions(+), 18 deletions(-) diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index d49834dc860..9dea5eff337 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -465,6 +465,9 @@ class SdcaWithLogisticLossTest(SdcaModelTest): dtypes.string, shape=(len(example_weights),)) examples['example_ids'] = example_ids variables = make_variable_dict(1, 1) + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
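# A minimal illustrative sketch (not from the patch) of why the multi-threaded
# tests in this change now call switch_to_thread_local(): by default a Graph
# keeps a single device stack shared by every thread, so device scopes entered
# and exited concurrently interleave on that one stack, and the top-of-stack
# identity check added to Graph.device() below can fire even though each thread
# nested its own scopes correctly. With thread-local stacks each thread only
# sees its own pushes and pops.
import threading

from tensorflow.python.framework import ops

g = ops.Graph()
g.switch_to_thread_local()  # without this, the pops below may interleave and
                            # hit "Exiting device scope without proper scope
                            # nesting."

def _worker():
  with g.device("/device:CPU:0"):
    pass  # push and pop a device scope on this thread's own stack

threads = [threading.Thread(target=_worker) for _ in range(4)]
for t in threads:
  t.start()
for t in threads:
  t.join()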
+ ops.get_default_graph().switch_to_thread_local() for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 05b928b9a2b..e46a7992e4b 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -19,9 +19,10 @@ from __future__ import division from __future__ import print_function import copy +import enum import threading import weakref -import enum +import six from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util @@ -189,10 +190,31 @@ class _CurrentDistributionContext(object): def __exit__(self, exception_type, exception_value, traceback): if self._device_scope: - self._device_scope.__exit__(exception_type, exception_value, traceback) - self._var_creator_scope.__exit__(exception_type, exception_value, traceback) + try: + self._device_scope.__exit__(exception_type, exception_value, traceback) + except RuntimeError as e: + six.raise_from( + RuntimeError("Device scope nesting error: move call to " + "tf.distribute.set_strategy() out of `with` scope."), + e) + + try: + self._var_creator_scope.__exit__( + exception_type, exception_value, traceback) + except RuntimeError as e: + six.raise_from( + RuntimeError("Variable creator scope nesting error: move call to " + "tf.distribute.set_strategy() out of `with` scope."), + e) + if self._var_scope: - self._var_scope.__exit__(exception_type, exception_value, traceback) + try: + self._var_scope.__exit__(exception_type, exception_value, traceback) + except RuntimeError as e: + six.raise_from( + RuntimeError("Variable scope nesting error: move call to " + "tf.distribute.set_strategy() out of `with` scope."), + e) _pop_per_thread_mode() diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index 391a70c562f..a3289c041ff 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -109,6 +110,53 @@ class TestStrategyTest(test.TestCase): variable_scope.variable(1.0, name="baz")) _assert_in_default_state(self) + def testScopeDeviceNestingError(self): + _assert_in_default_state(self) + dist = _TestStrategy() + # Open a device scope with dist.scope(). 
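# A minimal sketch of the misuse these tests exercise, written against the
# plain graph-level device scope (it mirrors testNestingErrorGraph in
# ops_test.py further down in this patch): driving __enter__/__exit__ by hand
# and closing the outer scope while an inner scope is still open now raises a
# RuntimeError instead of silently leaving the device stack corrupted.
from tensorflow.python.framework import ops

g = ops.Graph()
outer = g.device("/device:GPU:0")
outer.__enter__()
with g.device("/device:GPU:1"):
  try:
    outer.__exit__(None, None, None)  # the inner scope is still active
  except RuntimeError as e:
    print(e)  # Exiting device scope without proper scope nesting.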
+ dist.extended._default_device = "/device:GPU:0" + scope = dist.scope() + scope.__enter__() + self.assertIs(dist, ds_context.get_strategy()) + with ops.device("/device:CPU:0"): + with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"): + scope.__exit__(None, None, None) + scope.__exit__(None, None, None) + _assert_in_default_state(self) + + def testScopeVarCreatorNestingError(self): + + def creator(next_creator, **kwargs): + return next_creator(**kwargs) + + _assert_in_default_state(self) + dist = _TestStrategy() + scope = dist.scope() + scope.__enter__() + self.assertIs(dist, ds_context.get_strategy()) + with variable_scope.variable_creator_scope(creator): + with self.assertRaisesRegexp(RuntimeError, + "Variable creator scope nesting error"): + scope.__exit__(None, None, None) + scope.__exit__(None, None, None) + _assert_in_default_state(self) + + def testScopeVarScopeNestingError(self): + # We create a new graph here to simplify clean-up, since the error + # we are triggering happens in the middle of scope.__exit__() and + # leaves us in a weird state. + with ops.Graph().as_default(): + _assert_in_default_state(self) + dist = _TestStrategy() + scope = dist.scope() + scope.__enter__() + self.assertIs(dist, ds_context.get_strategy()) + with variable_scope.variable_scope("AA"): + with self.assertRaisesRegexp(RuntimeError, + "Variable scope nesting error"): + scope.__exit__(None, None, None) + _assert_in_default_state(self) + def testSettingSynchronizationAndAggregation(self): _assert_in_default_state(self) dist = _TestStrategy() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index a31d22f68e1..ad7fe539512 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -560,6 +560,7 @@ class Context(object): Raises: ValueError: If name is not a string or is an invalid device name. + RuntimeError: If device scopes are not properly nested. """ eager_context = self._thread_local_data old_device_name = eager_context.device_name @@ -595,6 +596,8 @@ class Context(object): eager_context.device_spec = new_device_spec yield finally: + if eager_context.device_spec is not new_device_spec: + raise RuntimeError("Exiting device scope without proper scope nesting") eager_context.device_name = old_device_name eager_context.device_spec = old_device_spec diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 8dfcf381626..687f3e674ee 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3088,18 +3088,25 @@ class Graph(object): Yields: `_variable_creator_scope` is a context manager with a side effect, but doesn't return a value. + + Raises: + RuntimeError: If variable creator scopes are not properly nested. """ - # This step makes a copy of the existing stack, and it also initializes + # This step keeps a reference to the existing stack, and it also initializes # self._thread_local._variable_creator_stack if it doesn't exist yet. - old = list(self._variable_creator_stack) - stack = self._thread_local._variable_creator_stack # pylint: disable=protected-access - stack.append((priority, creator)) + old = self._variable_creator_stack + new = list(old) + new.append((priority, creator)) # Sorting is stable, so we'll put higher-priority creators later in the list # but otherwise maintain registration order. 
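# A toy illustration (made-up creator names) of the stability claim above:
# creators registered at the same priority keep their registration order, while
# one registered with a larger priority value sorts after all of them.
new = [(100, "a"), (100, "b")]
new.append((200, "high"))
new.sort(key=lambda item: item[0])  # -> [(100, 'a'), (100, 'b'), (200, 'high')]
new.append((100, "c"))
new.sort(key=lambda item: item[0])  # -> [(100, 'a'), (100, 'b'), (100, 'c'), (200, 'high')]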
- stack.sort(key=lambda item: item[0]) + new.sort(key=lambda item: item[0]) + self._thread_local._variable_creator_stack = new # pylint: disable=protected-access try: yield finally: + if self._thread_local._variable_creator_stack is not new: # pylint: disable=protected-access + raise RuntimeError( + "Exiting variable_creator_scope without proper nesting.") self._thread_local._variable_creator_stack = old # pylint: disable=protected-access # Note: this method is private because the API of tf.Graph() is public and @@ -3108,7 +3115,25 @@ class Graph(object): def _variable_creator_stack(self): if not hasattr(self._thread_local, "_variable_creator_stack"): self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access - return list(self._thread_local._variable_creator_stack) # pylint: disable=protected-access + + # This previously returned a copy of the stack instead of the stack itself, + # to guard against accidental mutation. Consider, however, code that wants + # to save and restore the variable creator stack: + # def f(): + # original_stack = graph._variable_creator_stack + # graph._variable_creator_stack = new_stack + # ... # Some code + # graph._variable_creator_stack = original_stack + # + # And lets say you have some code that calls this function with some + # variable_creator: + # def g(): + # with variable_scope.variable_creator_scope(creator): + # f() + # When exiting the variable creator scope, it would see a different stack + # object than it expected leading to a "Exiting variable_creator_scope + # without proper nesting" error. + return self._thread_local._variable_creator_stack # pylint: disable=protected-access @_variable_creator_stack.setter def _variable_creator_stack(self, variable_creator_stack): @@ -4405,11 +4430,18 @@ class Graph(object): Yields: A context manager that specifies the default device to use for newly created ops. + + Raises: + RuntimeError: If device scopes are not properly nested. """ self._add_device_to_stack(device_name_or_function, offset=2) + old_top_of_stack = self._device_function_stack.peek_objs()[0] try: yield finally: + new_top_of_stack = self._device_function_stack.peek_objs()[0] + if old_top_of_stack is not new_top_of_stack: + raise RuntimeError("Exiting device scope without proper scope nesting.") self._device_function_stack.pop_obj() def _apply_device_functions(self, op): @@ -5119,9 +5151,7 @@ class Graph(object): def device(device_name_or_function): """Wrapper for `Graph.device()` using the default graph. - See - `tf.Graph.device` - for more details. + See `tf.Graph.device` for more details. 
Args: device_name_or_function: The device name or function to use in diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 7d9799a1a7e..77955eb2525 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1295,6 +1295,22 @@ class DeviceTest(test_util.TensorFlowTestCase): device: "/device:CPU:5" } """, gd) + def testNestingErrorGraph(self): + g = ops.Graph() + scope = g.device("/device:GPU:8") + scope.__enter__() + with g.device("/device:GPU:9"): + with self.assertRaises(RuntimeError): + scope.__exit__(None, None, None) + + def testNestingErrorEager(self): + with context.eager_mode(): + scope = ops.device("/device:CPU:0") + scope.__enter__() + with ops.device(None): + with self.assertRaises(RuntimeError): + scope.__exit__(None, None, None) + def testNoneClearsDefault(self): g = ops.Graph() with g.device("/job:worker/replica:2/device:CPU:1"): diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py index 32a20587508..37afb32e36b 100644 --- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py @@ -437,6 +437,9 @@ class ConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testParallelApplyGrad(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) @@ -463,6 +466,9 @@ class ConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testParallelTakeGrad(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) @@ -496,6 +502,9 @@ class ConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testAccumulatorApplyAndBlockingTake(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py index a778bf231bb..a60729ac7fd 100644 --- a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py +++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -33,6 +34,9 @@ class AssignOpTest(test.TestCase): # contain benign and deliberate data races when multiple threads update # the same parameters without a lock. def testParallelUpdateWithoutLocking(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
+ ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: ones_t = array_ops.fill([1024, 1024], 1.0) p = variables.Variable(array_ops.zeros([1024, 1024])) @@ -60,6 +64,9 @@ class AssignOpTest(test.TestCase): self.assertTrue((vals <= ones * 20).all()) def testParallelAssignWithoutLocking(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: ones_t = array_ops.fill([1024, 1024], float(1)) p = variables.Variable(array_ops.zeros([1024, 1024])) @@ -92,6 +99,9 @@ class AssignOpTest(test.TestCase): # returning the output tensors. This issue will be resolved with the new # resource variables. def testParallelUpdateWithLocking(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: zeros_t = array_ops.fill([1024, 1024], 0.0) ones_t = array_ops.fill([1024, 1024], 1.0) @@ -119,6 +129,9 @@ class AssignOpTest(test.TestCase): self.assertAllEqual(vals, ones * 20) def testParallelAssignWithLocking(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: zeros_t = array_ops.fill([1024, 1024], 0.0) ones_t = array_ops.fill([1024, 1024], 1.0) diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index b88b43ff507..ebc372bf57c 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -42,6 +42,11 @@ from tensorflow.python.util import compat @test_util.run_v1_only("FIFOQueue removed from v2") class FIFOQueueTest(test.TestCase): + def setUp(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() + def testConstructor(self): with ops.Graph().as_default(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q") diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py index 214eaa0160e..38ac69c5c1e 100644 --- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py @@ -120,6 +120,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertEqual(4, q.size().eval()) def testParallelEnqueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] @@ -146,6 +149,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertItemsEqual(elems, results) def testParallelDequeue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
+ ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] @@ -184,6 +190,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0] @@ -627,6 +636,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.evaluate(dequeued_t) def testParallelEnqueueMany(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),)) elems = [10.0 * x for x in range(100)] @@ -646,6 +658,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertItemsEqual(dequeued_t.eval(), elems * 10) def testParallelDequeueMany(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),)) elems = [10.0 * x for x in range(1000)] @@ -668,6 +683,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertItemsEqual(elems, dequeued_elems) def testParallelDequeueUpTo(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),)) elems = [10.0 * x for x in range(1000)] @@ -692,6 +710,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertItemsEqual(elems, dequeued_elems) def testParallelEnqueueAndDequeue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),)) initial_elements = [10.0] * 49 @@ -795,6 +816,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertEqual(0, q.size().eval()) def testBlockingDequeueMany(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -822,6 +846,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertAllEqual(elems, dequeued_elems) def testBlockingDequeueUpTo(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -892,6 +919,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.evaluate(dequeued_t) def testBlockingDequeueFromClosedQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
+ ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -940,6 +970,9 @@ class PaddingFIFOQueueTest(test.TestCase): dequeue_thread.join() def testBlockingDequeueFromClosedEmptyQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) close_op = q.close() @@ -960,6 +993,9 @@ class PaddingFIFOQueueTest(test.TestCase): dequeue_thread.join() def testBlockingDequeueManyFromClosedQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -985,6 +1021,9 @@ class PaddingFIFOQueueTest(test.TestCase): dequeue_thread.join() def testBlockingDequeueManyButNotAllFromClosedQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -1010,6 +1049,9 @@ class PaddingFIFOQueueTest(test.TestCase): dequeue_thread.join() def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -1080,6 +1122,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertEqual(0, q.size().eval()) def testBlockingDequeueManyFromClosedEmptyQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) close_op = q.close() @@ -1100,6 +1145,9 @@ class PaddingFIFOQueueTest(test.TestCase): dequeue_thread.join() def testBlockingDequeueUpToFromClosedEmptyQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),)) close_op = q.close() @@ -1147,6 +1195,9 @@ class PaddingFIFOQueueTest(test.TestCase): enqueue_op.run() def testBlockingEnqueueToFullQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -1170,6 +1221,9 @@ class PaddingFIFOQueueTest(test.TestCase): thread.join() def testBlockingEnqueueManyToFullQueue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
+ ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -1197,6 +1251,9 @@ class PaddingFIFOQueueTest(test.TestCase): thread.join() def testBlockingEnqueueBeforeClose(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0, 40.0] @@ -1234,6 +1291,9 @@ class PaddingFIFOQueueTest(test.TestCase): self.assertEqual(0, q.size().eval()) def testBlockingEnqueueManyBeforeClose(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),)) elems = [10.0, 20.0, 30.0] @@ -1397,6 +1457,9 @@ class PaddingFIFOQueueTest(test.TestCase): @test_util.run_deprecated_v1 def testResetOfBlockingOperation(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),)) dequeue_op = q_empty.dequeue() diff --git a/tensorflow/python/kernel_tests/priority_queue_test.py b/tensorflow/python/kernel_tests/priority_queue_test.py index 84f395dd343..c183fc0db48 100644 --- a/tensorflow/python/kernel_tests/priority_queue_test.py +++ b/tensorflow/python/kernel_tests/priority_queue_test.py @@ -27,6 +27,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -69,6 +70,9 @@ class PriorityQueueTest(test.TestCase): self.assertEqual(missed, set()) def testRoundTripInsertMultiThreadedReadOnceSorts(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( (), ())) @@ -115,6 +119,9 @@ class PriorityQueueTest(test.TestCase): self.assertEqual(missed, set()) def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (())) @@ -165,6 +172,9 @@ class PriorityQueueTest(test.TestCase): self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values)) def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
+ ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (())) @@ -221,6 +231,9 @@ class PriorityQueueTest(test.TestCase): self.assertAllEqual(set(dequeued), set(all_enqueued_values)) def testRoundTripInsertManyMultiThreadedReadOnceSorts(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( (), ())) diff --git a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py index 4a8144fadb4..1fdf1aaf315 100644 --- a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py +++ b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py @@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -42,6 +43,9 @@ class RandomShuffleQueueTest(test.TestCase): # Useful for debugging when a test times out. super(RandomShuffleQueueTest, self).setUp() tf_logging.error("Starting: %s", self._testMethodName) + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() def tearDown(self): super(RandomShuffleQueueTest, self).tearDown() @@ -1237,6 +1241,9 @@ class RandomShuffleQueueTest(test.TestCase): self.evaluate(enqueue_many_op) def testResetOfBlockingOperation(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, ( (),)) diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py index 4a967b65628..67b42d02b88 100644 --- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py @@ -269,6 +269,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testParallelApplyGradMean(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2])) @@ -301,6 +304,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testParallelApplyGradSum(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. 
+ ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, @@ -336,6 +342,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testParallelTakeGrad(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2])) @@ -376,6 +385,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testAccumulatorApplyAndBlockingTake(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2])) @@ -412,6 +424,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): @test_util.run_v1_only("b/120545219") def testAccumulatorCancel(self): + # We need each thread to keep its own device stack or the device scopes + # won't be properly nested. + ops.get_default_graph().switch_to_thread_local() with self.cached_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 451eb385306..af337c6d9b7 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -148,6 +148,17 @@ class VariableScopeTest(test.TestCase): w = variable_scope.get_variable("w", []) self.assertEqual(w.constraint, constraint) + @test_util.run_in_graph_and_eager_modes + @run_inside_wrap_function_in_eager_mode + def testVarScopeNestingError(self): + with variable_scope.variable_scope("aa"): + scope = variable_scope.variable_scope("bb") + scope.__enter__() + with variable_scope.variable_scope("cc"): + with self.assertRaises(RuntimeError): + scope.__exit__(None, None, None) + scope.__exit__(None, None, None) + # TODO(mihaimaruseac): Not converted to use wrap_function because of # TypeError: Fetch argument # has invalid type , must be a string or Tensor. @@ -1704,6 +1715,26 @@ class VariableScopeWithCustomGetterTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) self.assertTrue(called[0]) + @test_util.run_in_graph_and_eager_modes + @run_inside_wrap_function_in_eager_mode + def testVariableCreatorNestingError(self): + + def creator(next_creator, **kwargs): + return next_creator(**kwargs) + + # Save the state so we can clean up at the end. 
+ graph = ops.get_default_graph() + old_creator_stack = graph._variable_creator_stack + + try: + scope = variable_scope.variable_creator_scope(creator) + scope.__enter__() + with variable_scope.variable_creator_scope(creator): + with self.assertRaises(RuntimeError): + scope.__exit__(None, None, None) + finally: + graph._variable_creator_stack = old_creator_stack + class PartitionInfoTest(test.TestCase): diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index bcef6e60e33..39d5d2816e1 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1835,6 +1835,7 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name self._constraint = constraint self._var_store = _get_default_variable_store() self._var_scope_store = get_variable_scope_store() + self._last_variable_scope_object = None if isinstance(self._name_or_scope, VariableScope): self._new_name = self._name_or_scope.name name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access @@ -1931,9 +1932,13 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name variable_scope_object.set_use_resource(self._use_resource) self._var_scope_store.open_variable_scope(self._new_name) self._var_scope_store.current_scope = variable_scope_object + self._last_variable_scope_object = variable_scope_object return variable_scope_object def __exit__(self, type_arg, value_arg, traceback_arg): + if (self._var_scope_store.current_scope is not + self._last_variable_scope_object): + raise RuntimeError("Improper nesting of variable_scope.") # If jumping out from a non-prolonged scope, restore counts. if isinstance(self._name_or_scope, VariableScope): self._var_scope_store.variable_scopes_count = self._old_subscopes @@ -2225,11 +2230,10 @@ class variable_scope(object): try: return self._enter_scope_uncached() - except Exception: - if self._in_graph_mode and not self._building_function: - if self._graph_context_manager is not None: - self._graph_context_manager.__exit__(*sys.exc_info()) - raise + finally: + if (self._in_graph_mode and not self._building_function and + self._graph_context_manager is not None): + self._graph_context_manager.__exit__(*sys.exc_info()) def _enter_scope_uncached(self): """Enters the context manager when there is no cached scope yet.
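# The scope types touched by this patch (graph and eager device scopes,
# variable creator scopes, variable scopes, and the distribution-strategy scope
# that wraps them) all rely on the same detection idea: remember the object on
# top of the relevant stack when the scope is entered and, on __exit__, raise
# if something else is now on top. A condensed, self-contained sketch of that
# pattern (toy names, not the patch's actual classes):
import contextlib

_stack = []

@contextlib.contextmanager
def tracked_scope(payload):
  _stack.append(payload)
  entered = _stack[-1]  # remember the entry this scope pushed
  try:
    yield
  finally:
    if _stack[-1] is not entered:  # an inner scope is still open
      raise RuntimeError("Exiting scope without proper scope nesting.")
    _stack.pop()

# Properly nested "with" blocks are unaffected:
with tracked_scope("outer"):
  with tracked_scope("inner"):
    pass

# Calling __enter__ and __exit__ by hand out of order is what now raises; this
# is the situation the commit message anticipates for the planned
# tf.distribute.set_strategy(), which will drive scope().__enter__/__exit__
# manually:
outer = tracked_scope("outer")
outer.__enter__()
inner = tracked_scope("inner")
inner.__enter__()
try:
  outer.__exit__(None, None, None)  # "inner" is still on the stack
except RuntimeError:
  pass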