Raise an error when scopes are not properly nested, e.g. because
__enter__ and __exit__ were called on the context manager directly
instead of using "with" blocks. This is in preparation for a
tf.distribute.set_strategy() API, which will manually call
strategy.scope().__enter__ and .__exit__.

PiperOrigin-RevId: 239124241
A. Unique TensorFlower 2019-03-18 21:49:22 -07:00 committed by TensorFlower Gardener
parent f1e0fa966d
commit 6bdbc4689e
15 changed files with 300 additions and 18 deletions
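For orientation, here is the misuse this change turns into a loud failure, condensed from the tests added in this commit. The _TestStrategy stub and the device strings come from distribute_lib_test.py below; tf.distribute.set_strategy() itself does not exist yet, so the manual __enter__/__exit__ calls stand in for what it would do.

    from tensorflow.python.framework import ops
    # _TestStrategy is the stub strategy defined in distribute_lib_test.py.

    dist = _TestStrategy()
    dist.extended._default_device = "/device:GPU:0"  # so scope() opens a device scope
    scope = dist.scope()
    scope.__enter__()                                # what set_strategy() would do

    with ops.device("/device:CPU:0"):
      # The strategy scope is now interleaved with a device scope; exiting it
      # here is improper nesting and (as the test asserts with
      # assertRaisesRegexp) raises:
      #   RuntimeError: Device scope nesting error: move call to
      #   tf.distribute.set_strategy() out of `with` scope.
      scope.__exit__(None, None, None)

    scope.__exit__(None, None, None)                 # the correct place to exit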


@@ -465,6 +465,9 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
         dtypes.string, shape=(len(example_weights),))
     examples['example_ids'] = example_ids
     variables = make_variable_dict(1, 1)
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     for num_shards in _SHARD_NUMBERS:
       for num_loss_partitions in _NUM_LOSS_PARTITIONS:
         with self._single_threaded_test_session():
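The same three-line addition recurs throughout the multithreaded kernel tests in this commit. The reason: once device scopes verify nesting on exit, scopes entered from several threads on a single shared graph-level stack interleave and trip the check, so each affected test first switches the graph's stacks to thread-local storage. A generic sketch of the pattern (the test class and method below are illustrative, not from the commit):

    from tensorflow.python.framework import ops
    from tensorflow.python.platform import test

    class MyThreadedKernelTest(test.TestCase):  # illustrative

      def testParallelSomething(self):
        # Give every worker thread its own device/control-flow stacks; with a
        # single shared stack, device scopes entered from concurrent threads
        # interleave and the new exit-time nesting check raises RuntimeError.
        ops.get_default_graph().switch_to_thread_local()
        with self.cached_session() as sess:
          pass  # build the graph here, then run ops from several threads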


@@ -19,9 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import enum
 import threading
 import weakref
-import enum
+import six
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import device_util
@@ -189,10 +190,31 @@ class _CurrentDistributionContext(object):
   def __exit__(self, exception_type, exception_value, traceback):
     if self._device_scope:
-      self._device_scope.__exit__(exception_type, exception_value, traceback)
-    self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
+      try:
+        self._device_scope.__exit__(exception_type, exception_value, traceback)
+      except RuntimeError as e:
+        six.raise_from(
+            RuntimeError("Device scope nesting error: move call to "
+                         "tf.distribute.set_strategy() out of `with` scope."),
+            e)
+    try:
+      self._var_creator_scope.__exit__(
+          exception_type, exception_value, traceback)
+    except RuntimeError as e:
+      six.raise_from(
+          RuntimeError("Variable creator scope nesting error: move call to "
+                       "tf.distribute.set_strategy() out of `with` scope."),
+          e)
     if self._var_scope:
-      self._var_scope.__exit__(exception_type, exception_value, traceback)
+      try:
+        self._var_scope.__exit__(exception_type, exception_value, traceback)
+      except RuntimeError as e:
+        six.raise_from(
+            RuntimeError("Variable scope nesting error: move call to "
+                         "tf.distribute.set_strategy() out of `with` scope."),
+            e)
     _pop_per_thread_mode()
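six.raise_from chains the low-level nesting RuntimeError raised by the framework onto a message that points at the actual fix. The pattern in isolation (the helper function below is illustrative, not part of the commit):

    import six

    def exit_device_scope(device_scope):
      """Illustrative: exit a scope, translating nesting errors for the user."""
      try:
        device_scope.__exit__(None, None, None)
      except RuntimeError as e:
        # Keep the original error as the cause (__cause__ on Python 3,
        # emulated on Python 2 by six) while surfacing an actionable message.
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)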


@@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
@@ -109,6 +110,53 @@ class TestStrategyTest(test.TestCase):
           variable_scope.variable(1.0, name="baz"))
     _assert_in_default_state(self)
 
+  def testScopeDeviceNestingError(self):
+    _assert_in_default_state(self)
+    dist = _TestStrategy()
+    # Open a device scope with dist.scope().
+    dist.extended._default_device = "/device:GPU:0"
+    scope = dist.scope()
+    scope.__enter__()
+    self.assertIs(dist, ds_context.get_strategy())
+    with ops.device("/device:CPU:0"):
+      with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"):
+        scope.__exit__(None, None, None)
+    scope.__exit__(None, None, None)
+    _assert_in_default_state(self)
+
+  def testScopeVarCreatorNestingError(self):
+
+    def creator(next_creator, **kwargs):
+      return next_creator(**kwargs)
+
+    _assert_in_default_state(self)
+    dist = _TestStrategy()
+    scope = dist.scope()
+    scope.__enter__()
+    self.assertIs(dist, ds_context.get_strategy())
+    with variable_scope.variable_creator_scope(creator):
+      with self.assertRaisesRegexp(RuntimeError,
+                                   "Variable creator scope nesting error"):
+        scope.__exit__(None, None, None)
+    scope.__exit__(None, None, None)
+    _assert_in_default_state(self)
+
+  def testScopeVarScopeNestingError(self):
+    # We create a new graph here to simplify clean-up, since the error
+    # we are triggering happens in the middle of scope.__exit__() and
+    # leaves us in a weird state.
+    with ops.Graph().as_default():
+      _assert_in_default_state(self)
+      dist = _TestStrategy()
+      scope = dist.scope()
+      scope.__enter__()
+      self.assertIs(dist, ds_context.get_strategy())
+      with variable_scope.variable_scope("AA"):
+        with self.assertRaisesRegexp(RuntimeError,
+                                     "Variable scope nesting error"):
+          scope.__exit__(None, None, None)
+    _assert_in_default_state(self)
+
   def testSettingSynchronizationAndAggregation(self):
     _assert_in_default_state(self)
     dist = _TestStrategy()


@@ -560,6 +560,7 @@ class Context(object):
     Raises:
       ValueError: If name is not a string or is an invalid device name.
+      RuntimeError: If device scopes are not properly nested.
     """
     eager_context = self._thread_local_data
     old_device_name = eager_context.device_name
@@ -595,6 +596,8 @@
       eager_context.device_spec = new_device_spec
       yield
     finally:
+      if eager_context.device_spec is not new_device_spec:
+        raise RuntimeError("Exiting device scope without proper scope nesting")
       eager_context.device_name = old_device_name
       eager_context.device_spec = old_device_spec
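Both this eager-mode check and the graph-side checks below rely on the same mechanism: remember the exact object this scope installed, and verify by identity on exit that it is still the current one; if an inner scope is still open, something newer is on top. A self-contained sketch of that pattern outside TensorFlow (all names here are illustrative):

    import contextlib

    class ScopeStack(object):
      """Tracks nested scopes and fails loudly on out-of-order exits."""

      def __init__(self):
        self._stack = []

      @contextlib.contextmanager
      def scope(self, value):
        entry = [value]  # a fresh object, so `is` identifies this exact push
        self._stack.append(entry)
        try:
          yield
        finally:
          # If an inner scope is still open, something newer sits on top and
          # popping now would remove the wrong entry; raise instead.
          if self._stack[-1] is not entry:
            raise RuntimeError("Exiting scope without proper nesting.")
          self._stack.pop()

    stack = ScopeStack()
    outer = stack.scope("outer")
    outer.__enter__()
    with stack.scope("inner"):
      pass                            # exiting outer here would raise
    outer.__exit__(None, None, None)  # fine: inner already closed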


@@ -3088,18 +3088,25 @@ class Graph(object):
     Yields:
       `_variable_creator_scope` is a context manager with a side effect, but
       doesn't return a value.
+
+    Raises:
+      RuntimeError: If variable creator scopes are not properly nested.
     """
-    # This step makes a copy of the existing stack, and it also initializes
+    # This step keeps a reference to the existing stack, and it also initializes
     # self._thread_local._variable_creator_stack if it doesn't exist yet.
-    old = list(self._variable_creator_stack)
-    stack = self._thread_local._variable_creator_stack  # pylint: disable=protected-access
-    stack.append((priority, creator))
+    old = self._variable_creator_stack
+    new = list(old)
+    new.append((priority, creator))
     # Sorting is stable, so we'll put higher-priority creators later in the list
     # but otherwise maintain registration order.
-    stack.sort(key=lambda item: item[0])
+    new.sort(key=lambda item: item[0])
+    self._thread_local._variable_creator_stack = new  # pylint: disable=protected-access
     try:
       yield
     finally:
+      if self._thread_local._variable_creator_stack is not new:  # pylint: disable=protected-access
+        raise RuntimeError(
+            "Exiting variable_creator_scope without proper nesting.")
       self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
 
   # Note: this method is private because the API of tf.Graph() is public and
@@ -3108,7 +3115,25 @@
   def _variable_creator_stack(self):
     if not hasattr(self._thread_local, "_variable_creator_stack"):
       self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
-    return list(self._thread_local._variable_creator_stack)  # pylint: disable=protected-access
+
+    # This previously returned a copy of the stack instead of the stack itself,
+    # to guard against accidental mutation. Consider, however, code that wants
+    # to save and restore the variable creator stack:
+    #   def f():
+    #     original_stack = graph._variable_creator_stack
+    #     graph._variable_creator_stack = new_stack
+    #     ...  # Some code
+    #     graph._variable_creator_stack = original_stack
+    #
+    # And lets say you have some code that calls this function with some
+    # variable_creator:
+    #   def g():
+    #     with variable_scope.variable_creator_scope(creator):
+    #       f()
+    # When exiting the variable creator scope, it would see a different stack
+    # object than it expected leading to a "Exiting variable_creator_scope
+    # without proper nesting" error.
+    return self._thread_local._variable_creator_stack  # pylint: disable=protected-access
 
   @_variable_creator_stack.setter
   def _variable_creator_stack(self, variable_creator_stack):
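The save-and-restore scenario in that comment, written out as runnable code. This is only a sketch built on the private _variable_creator_stack property, meant to show why returning the stack object itself (rather than a copy) keeps the exit-time identity check happy:

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variable_scope

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    def f():
      graph = ops.get_default_graph()
      # With the property returning the stack itself, original_stack is the
      # very list the open variable_creator_scope installed, so restoring it
      # below leaves that same object in place and the identity check passes.
      # Under the old copy-returning behavior, original_stack would be a copy,
      # and restoring it would trip "Exiting variable_creator_scope without
      # proper nesting."
      original_stack = graph._variable_creator_stack
      graph._variable_creator_stack = list(original_stack)  # swap in a new stack
      # ... some code ...
      graph._variable_creator_stack = original_stack        # restore

    def g():
      with variable_scope.variable_creator_scope(creator):
        f()

    with ops.Graph().as_default():
      g()  # no RuntimeError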
@@ -4405,11 +4430,18 @@
     Yields:
       A context manager that specifies the default device to use for newly
       created ops.
+
+    Raises:
+      RuntimeError: If device scopes are not properly nested.
     """
     self._add_device_to_stack(device_name_or_function, offset=2)
+    old_top_of_stack = self._device_function_stack.peek_objs()[0]
     try:
       yield
     finally:
+      new_top_of_stack = self._device_function_stack.peek_objs()[0]
+      if old_top_of_stack is not new_top_of_stack:
+        raise RuntimeError("Exiting device scope without proper scope nesting.")
       self._device_function_stack.pop_obj()
 
   def _apply_device_functions(self, op):
@@ -5119,9 +5151,7 @@
 def device(device_name_or_function):
   """Wrapper for `Graph.device()` using the default graph.
 
-  See
-  `tf.Graph.device`
-  for more details.
+  See `tf.Graph.device` for more details.
 
   Args:
     device_name_or_function: The device name or function to use in


@@ -1295,6 +1295,22 @@ class DeviceTest(test_util.TensorFlowTestCase):
         device: "/device:CPU:5" }
     """, gd)
 
+  def testNestingErrorGraph(self):
+    g = ops.Graph()
+    scope = g.device("/device:GPU:8")
+    scope.__enter__()
+    with g.device("/device:GPU:9"):
+      with self.assertRaises(RuntimeError):
+        scope.__exit__(None, None, None)
+
+  def testNestingErrorEager(self):
+    with context.eager_mode():
+      scope = ops.device("/device:CPU:0")
+      scope.__enter__()
+      with ops.device(None):
+        with self.assertRaises(RuntimeError):
+          scope.__exit__(None, None, None)
+
   def testNoneClearsDefault(self):
     g = ops.Graph()
     with g.device("/job:worker/replica:2/device:CPU:1"):


@@ -437,6 +437,9 @@ class ConditionalAccumulatorTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testParallelApplyGrad(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -463,6 +466,9 @@ class ConditionalAccumulatorTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testParallelTakeGrad(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -496,6 +502,9 @@ class ConditionalAccumulatorTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testAccumulatorApplyAndBlockingTake(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))


@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
@@ -33,6 +34,9 @@ class AssignOpTest(test.TestCase):
   # contain benign and deliberate data races when multiple threads update
   # the same parameters without a lock.
   def testParallelUpdateWithoutLocking(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       ones_t = array_ops.fill([1024, 1024], 1.0)
       p = variables.Variable(array_ops.zeros([1024, 1024]))
@@ -60,6 +64,9 @@
     self.assertTrue((vals <= ones * 20).all())
 
   def testParallelAssignWithoutLocking(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       ones_t = array_ops.fill([1024, 1024], float(1))
       p = variables.Variable(array_ops.zeros([1024, 1024]))
@@ -92,6 +99,9 @@
   # returning the output tensors. This issue will be resolved with the new
   # resource variables.
   def testParallelUpdateWithLocking(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       zeros_t = array_ops.fill([1024, 1024], 0.0)
       ones_t = array_ops.fill([1024, 1024], 1.0)
@@ -119,6 +129,9 @@
     self.assertAllEqual(vals, ones * 20)
 
   def testParallelAssignWithLocking(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       zeros_t = array_ops.fill([1024, 1024], 0.0)
       ones_t = array_ops.fill([1024, 1024], 1.0)


@@ -42,6 +42,11 @@ from tensorflow.python.util import compat
 @test_util.run_v1_only("FIFOQueue removed from v2")
 class FIFOQueueTest(test.TestCase):
 
+  def setUp(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
+
   def testConstructor(self):
     with ops.Graph().as_default():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q")


@@ -120,6 +120,9 @@ class PaddingFIFOQueueTest(test.TestCase):
     self.assertEqual(4, q.size().eval())
 
   def testParallelEnqueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -146,6 +149,9 @@
     self.assertItemsEqual(elems, results)
 
   def testParallelDequeue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -184,6 +190,9 @@
     self.assertEqual([elems[i]], vals)
 
   def testEnqueueAndBlockingDequeue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0]
@@ -627,6 +636,9 @@
     self.evaluate(dequeued_t)
 
   def testParallelEnqueueMany(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
       elems = [10.0 * x for x in range(100)]
@@ -646,6 +658,9 @@
     self.assertItemsEqual(dequeued_t.eval(), elems * 10)
 
   def testParallelDequeueMany(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
       elems = [10.0 * x for x in range(1000)]
@@ -668,6 +683,9 @@
     self.assertItemsEqual(elems, dequeued_elems)
 
   def testParallelDequeueUpTo(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
       elems = [10.0 * x for x in range(1000)]
@@ -692,6 +710,9 @@
     self.assertItemsEqual(elems, dequeued_elems)
 
   def testParallelEnqueueAndDequeue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),))
       initial_elements = [10.0] * 49
@@ -795,6 +816,9 @@
     self.assertEqual(0, q.size().eval())
 
   def testBlockingDequeueMany(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -822,6 +846,9 @@
     self.assertAllEqual(elems, dequeued_elems)
 
   def testBlockingDequeueUpTo(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -892,6 +919,9 @@
     self.evaluate(dequeued_t)
 
   def testBlockingDequeueFromClosedQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -940,6 +970,9 @@
     dequeue_thread.join()
 
   def testBlockingDequeueFromClosedEmptyQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       close_op = q.close()
@@ -960,6 +993,9 @@
     dequeue_thread.join()
 
   def testBlockingDequeueManyFromClosedQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -985,6 +1021,9 @@
     dequeue_thread.join()
 
   def testBlockingDequeueManyButNotAllFromClosedQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -1010,6 +1049,9 @@
     dequeue_thread.join()
 
   def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -1080,6 +1122,9 @@
     self.assertEqual(0, q.size().eval())
 
   def testBlockingDequeueManyFromClosedEmptyQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       close_op = q.close()
@@ -1100,6 +1145,9 @@
     dequeue_thread.join()
 
   def testBlockingDequeueUpToFromClosedEmptyQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       close_op = q.close()
@@ -1147,6 +1195,9 @@
     enqueue_op.run()
 
   def testBlockingEnqueueToFullQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -1170,6 +1221,9 @@
     thread.join()
 
   def testBlockingEnqueueManyToFullQueue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -1197,6 +1251,9 @@
     thread.join()
 
   def testBlockingEnqueueBeforeClose(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
@@ -1234,6 +1291,9 @@
     self.assertEqual(0, q.size().eval())
 
   def testBlockingEnqueueManyBeforeClose(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0]
@@ -1397,6 +1457,9 @@
   @test_util.run_deprecated_v1
   def testResetOfBlockingOperation(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),))
       dequeue_op = q_empty.dequeue()


@@ -27,6 +27,7 @@ import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
@@ -69,6 +70,9 @@ class PriorityQueueTest(test.TestCase):
     self.assertEqual(missed, set())
 
   def testRoundTripInsertMultiThreadedReadOnceSorts(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
           (), ()))
@@ -115,6 +119,9 @@
     self.assertEqual(missed, set())
 
   def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
@@ -165,6 +172,9 @@
     self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
 
   def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
@@ -221,6 +231,9 @@
     self.assertAllEqual(set(dequeued), set(all_enqueued_values))
 
   def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
          (), ()))


@@ -27,6 +27,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
@@ -42,6 +43,9 @@ class RandomShuffleQueueTest(test.TestCase):
     # Useful for debugging when a test times out.
     super(RandomShuffleQueueTest, self).setUp()
     tf_logging.error("Starting: %s", self._testMethodName)
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
 
   def tearDown(self):
     super(RandomShuffleQueueTest, self).tearDown()
@@ -1237,6 +1241,9 @@
     self.evaluate(enqueue_many_op)
 
   def testResetOfBlockingOperation(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, (
          (),))


@@ -269,6 +269,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testParallelApplyGradMean(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -301,6 +304,9 @@
   @test_util.run_v1_only("b/120545219")
   def testParallelApplyGradSum(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32,
@@ -336,6 +342,9 @@
   @test_util.run_v1_only("b/120545219")
   def testParallelTakeGrad(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -376,6 +385,9 @@
   @test_util.run_v1_only("b/120545219")
   def testAccumulatorApplyAndBlockingTake(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -412,6 +424,9 @@
   @test_util.run_v1_only("b/120545219")
   def testAccumulatorCancel(self):
+    # We need each thread to keep its own device stack or the device scopes
+    # won't be properly nested.
+    ops.get_default_graph().switch_to_thread_local()
     with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
          dtypes_lib.float32,


@@ -148,6 +148,17 @@ class VariableScopeTest(test.TestCase):
       w = variable_scope.get_variable("w", [])
       self.assertEqual(w.constraint, constraint)
 
+  @test_util.run_in_graph_and_eager_modes
+  @run_inside_wrap_function_in_eager_mode
+  def testVarScopeNestingError(self):
+    with variable_scope.variable_scope("aa"):
+      scope = variable_scope.variable_scope("bb")
+      scope.__enter__()
+      with variable_scope.variable_scope("cc"):
+        with self.assertRaises(RuntimeError):
+          scope.__exit__(None, None, None)
+      scope.__exit__(None, None, None)
+
   # TODO(mihaimaruseac): Not converted to use wrap_function because of
   # TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string>
   # has invalid type <class '...ResourceVariable'>, must be a string or Tensor.
@@ -1704,6 +1715,26 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
          aggregation=variable_scope.VariableAggregation.MEAN)
       self.assertTrue(called[0])
 
+  @test_util.run_in_graph_and_eager_modes
+  @run_inside_wrap_function_in_eager_mode
+  def testVariableCreatorNestingError(self):
+
+    def creator(next_creator, **kwargs):
+      return next_creator(**kwargs)
+
+    # Save the state so we can clean up at the end.
+    graph = ops.get_default_graph()
+    old_creator_stack = graph._variable_creator_stack
+    try:
+      scope = variable_scope.variable_creator_scope(creator)
+      scope.__enter__()
+      with variable_scope.variable_creator_scope(creator):
+        with self.assertRaises(RuntimeError):
+          scope.__exit__(None, None, None)
+    finally:
+      graph._variable_creator_stack = old_creator_stack
+
 
 class PartitionInfoTest(test.TestCase):


@@ -1835,6 +1835,7 @@ class _pure_variable_scope(object):  # pylint: disable=invalid-name
     self._constraint = constraint
     self._var_store = _get_default_variable_store()
     self._var_scope_store = get_variable_scope_store()
+    self._last_variable_scope_object = None
     if isinstance(self._name_or_scope, VariableScope):
       self._new_name = self._name_or_scope.name
       name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
@@ -1931,9 +1932,13 @@
       variable_scope_object.set_use_resource(self._use_resource)
     self._var_scope_store.open_variable_scope(self._new_name)
     self._var_scope_store.current_scope = variable_scope_object
+    self._last_variable_scope_object = variable_scope_object
     return variable_scope_object
 
   def __exit__(self, type_arg, value_arg, traceback_arg):
+    if (self._var_scope_store.current_scope is not
+        self._last_variable_scope_object):
+      raise RuntimeError("Improper nesting of variable_scope.")
     # If jumping out from a non-prolonged scope, restore counts.
     if isinstance(self._name_or_scope, VariableScope):
       self._var_scope_store.variable_scopes_count = self._old_subscopes
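At the user level this check surfaces as shown below, a condensed restatement of testVarScopeNestingError above; the message raised directly by _pure_variable_scope is "Improper nesting of variable_scope.", while the distribute-level wrapper re-raises it as "Variable scope nesting error" when it propagates through a strategy scope.

    from tensorflow.python.ops import variable_scope

    with variable_scope.variable_scope("aa"):
      scope = variable_scope.variable_scope("bb")
      scope.__enter__()
      with variable_scope.variable_scope("cc"):
        try:
          scope.__exit__(None, None, None)  # improper nesting: "cc" is current
        except RuntimeError as e:
          print(e)                          # e.g. Improper nesting of variable_scope.
      scope.__exit__(None, None, None)      # correct: "cc" has been closed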
@@ -2225,11 +2230,10 @@ class variable_scope(object):
     try:
       return self._enter_scope_uncached()
-    except Exception:
-      if self._in_graph_mode and not self._building_function:
-        if self._graph_context_manager is not None:
-          self._graph_context_manager.__exit__(*sys.exc_info())
-      raise
+    finally:
+      if (self._in_graph_mode and not self._building_function and
+          self._graph_context_manager is not None):
+        self._graph_context_manager.__exit__(*sys.exc_info())
 
   def _enter_scope_uncached(self):
     """Enters the context manager when there is no cached scope yet.