Raise an error if several scopes are not properly nested, e.g. due to
using __enter__ and __exit__ on the context manager directly instead of using "with" blocks. In preparation for creating a tf.distribute.set_strategy() API which will manually call strategy.scope().__enter__ and .__exit__. PiperOrigin-RevId: 239124241
This commit is contained in:
parent
f1e0fa966d
commit
6bdbc4689e
@ -465,6 +465,9 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
|
|||||||
dtypes.string, shape=(len(example_weights),))
|
dtypes.string, shape=(len(example_weights),))
|
||||||
examples['example_ids'] = example_ids
|
examples['example_ids'] = example_ids
|
||||||
variables = make_variable_dict(1, 1)
|
variables = make_variable_dict(1, 1)
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
for num_shards in _SHARD_NUMBERS:
|
for num_shards in _SHARD_NUMBERS:
|
||||||
for num_loss_partitions in _NUM_LOSS_PARTITIONS:
|
for num_loss_partitions in _NUM_LOSS_PARTITIONS:
|
||||||
with self._single_threaded_test_session():
|
with self._single_threaded_test_session():
|
||||||
|
@ -19,9 +19,10 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import enum
|
||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
import enum
|
import six
|
||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import device_util
|
from tensorflow.python.distribute import device_util
|
||||||
@ -189,10 +190,31 @@ class _CurrentDistributionContext(object):
|
|||||||
|
|
||||||
def __exit__(self, exception_type, exception_value, traceback):
|
def __exit__(self, exception_type, exception_value, traceback):
|
||||||
if self._device_scope:
|
if self._device_scope:
|
||||||
self._device_scope.__exit__(exception_type, exception_value, traceback)
|
try:
|
||||||
self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
|
self._device_scope.__exit__(exception_type, exception_value, traceback)
|
||||||
|
except RuntimeError as e:
|
||||||
|
six.raise_from(
|
||||||
|
RuntimeError("Device scope nesting error: move call to "
|
||||||
|
"tf.distribute.set_strategy() out of `with` scope."),
|
||||||
|
e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._var_creator_scope.__exit__(
|
||||||
|
exception_type, exception_value, traceback)
|
||||||
|
except RuntimeError as e:
|
||||||
|
six.raise_from(
|
||||||
|
RuntimeError("Variable creator scope nesting error: move call to "
|
||||||
|
"tf.distribute.set_strategy() out of `with` scope."),
|
||||||
|
e)
|
||||||
|
|
||||||
if self._var_scope:
|
if self._var_scope:
|
||||||
self._var_scope.__exit__(exception_type, exception_value, traceback)
|
try:
|
||||||
|
self._var_scope.__exit__(exception_type, exception_value, traceback)
|
||||||
|
except RuntimeError as e:
|
||||||
|
six.raise_from(
|
||||||
|
RuntimeError("Variable scope nesting error: move call to "
|
||||||
|
"tf.distribute.set_strategy() out of `with` scope."),
|
||||||
|
e)
|
||||||
_pop_per_thread_mode()
|
_pop_per_thread_mode()
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribute_lib
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -109,6 +110,53 @@ class TestStrategyTest(test.TestCase):
|
|||||||
variable_scope.variable(1.0, name="baz"))
|
variable_scope.variable(1.0, name="baz"))
|
||||||
_assert_in_default_state(self)
|
_assert_in_default_state(self)
|
||||||
|
|
||||||
|
def testScopeDeviceNestingError(self):
|
||||||
|
_assert_in_default_state(self)
|
||||||
|
dist = _TestStrategy()
|
||||||
|
# Open a device scope with dist.scope().
|
||||||
|
dist.extended._default_device = "/device:GPU:0"
|
||||||
|
scope = dist.scope()
|
||||||
|
scope.__enter__()
|
||||||
|
self.assertIs(dist, ds_context.get_strategy())
|
||||||
|
with ops.device("/device:CPU:0"):
|
||||||
|
with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
_assert_in_default_state(self)
|
||||||
|
|
||||||
|
def testScopeVarCreatorNestingError(self):
|
||||||
|
|
||||||
|
def creator(next_creator, **kwargs):
|
||||||
|
return next_creator(**kwargs)
|
||||||
|
|
||||||
|
_assert_in_default_state(self)
|
||||||
|
dist = _TestStrategy()
|
||||||
|
scope = dist.scope()
|
||||||
|
scope.__enter__()
|
||||||
|
self.assertIs(dist, ds_context.get_strategy())
|
||||||
|
with variable_scope.variable_creator_scope(creator):
|
||||||
|
with self.assertRaisesRegexp(RuntimeError,
|
||||||
|
"Variable creator scope nesting error"):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
_assert_in_default_state(self)
|
||||||
|
|
||||||
|
def testScopeVarScopeNestingError(self):
|
||||||
|
# We create a new graph here to simplify clean-up, since the error
|
||||||
|
# we are triggering happens in the middle of scope.__exit__() and
|
||||||
|
# leaves us in a weird state.
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
_assert_in_default_state(self)
|
||||||
|
dist = _TestStrategy()
|
||||||
|
scope = dist.scope()
|
||||||
|
scope.__enter__()
|
||||||
|
self.assertIs(dist, ds_context.get_strategy())
|
||||||
|
with variable_scope.variable_scope("AA"):
|
||||||
|
with self.assertRaisesRegexp(RuntimeError,
|
||||||
|
"Variable scope nesting error"):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
_assert_in_default_state(self)
|
||||||
|
|
||||||
def testSettingSynchronizationAndAggregation(self):
|
def testSettingSynchronizationAndAggregation(self):
|
||||||
_assert_in_default_state(self)
|
_assert_in_default_state(self)
|
||||||
dist = _TestStrategy()
|
dist = _TestStrategy()
|
||||||
|
@ -560,6 +560,7 @@ class Context(object):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If name is not a string or is an invalid device name.
|
ValueError: If name is not a string or is an invalid device name.
|
||||||
|
RuntimeError: If device scopes are not properly nested.
|
||||||
"""
|
"""
|
||||||
eager_context = self._thread_local_data
|
eager_context = self._thread_local_data
|
||||||
old_device_name = eager_context.device_name
|
old_device_name = eager_context.device_name
|
||||||
@ -595,6 +596,8 @@ class Context(object):
|
|||||||
eager_context.device_spec = new_device_spec
|
eager_context.device_spec = new_device_spec
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
if eager_context.device_spec is not new_device_spec:
|
||||||
|
raise RuntimeError("Exiting device scope without proper scope nesting")
|
||||||
eager_context.device_name = old_device_name
|
eager_context.device_name = old_device_name
|
||||||
eager_context.device_spec = old_device_spec
|
eager_context.device_spec = old_device_spec
|
||||||
|
|
||||||
|
@ -3088,18 +3088,25 @@ class Graph(object):
|
|||||||
Yields:
|
Yields:
|
||||||
`_variable_creator_scope` is a context manager with a side effect, but
|
`_variable_creator_scope` is a context manager with a side effect, but
|
||||||
doesn't return a value.
|
doesn't return a value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If variable creator scopes are not properly nested.
|
||||||
"""
|
"""
|
||||||
# This step makes a copy of the existing stack, and it also initializes
|
# This step keeps a reference to the existing stack, and it also initializes
|
||||||
# self._thread_local._variable_creator_stack if it doesn't exist yet.
|
# self._thread_local._variable_creator_stack if it doesn't exist yet.
|
||||||
old = list(self._variable_creator_stack)
|
old = self._variable_creator_stack
|
||||||
stack = self._thread_local._variable_creator_stack # pylint: disable=protected-access
|
new = list(old)
|
||||||
stack.append((priority, creator))
|
new.append((priority, creator))
|
||||||
# Sorting is stable, so we'll put higher-priority creators later in the list
|
# Sorting is stable, so we'll put higher-priority creators later in the list
|
||||||
# but otherwise maintain registration order.
|
# but otherwise maintain registration order.
|
||||||
stack.sort(key=lambda item: item[0])
|
new.sort(key=lambda item: item[0])
|
||||||
|
self._thread_local._variable_creator_stack = new # pylint: disable=protected-access
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
if self._thread_local._variable_creator_stack is not new: # pylint: disable=protected-access
|
||||||
|
raise RuntimeError(
|
||||||
|
"Exiting variable_creator_scope without proper nesting.")
|
||||||
self._thread_local._variable_creator_stack = old # pylint: disable=protected-access
|
self._thread_local._variable_creator_stack = old # pylint: disable=protected-access
|
||||||
|
|
||||||
# Note: this method is private because the API of tf.Graph() is public and
|
# Note: this method is private because the API of tf.Graph() is public and
|
||||||
@ -3108,7 +3115,25 @@ class Graph(object):
|
|||||||
def _variable_creator_stack(self):
|
def _variable_creator_stack(self):
|
||||||
if not hasattr(self._thread_local, "_variable_creator_stack"):
|
if not hasattr(self._thread_local, "_variable_creator_stack"):
|
||||||
self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access
|
self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access
|
||||||
return list(self._thread_local._variable_creator_stack) # pylint: disable=protected-access
|
|
||||||
|
# This previously returned a copy of the stack instead of the stack itself,
|
||||||
|
# to guard against accidental mutation. Consider, however, code that wants
|
||||||
|
# to save and restore the variable creator stack:
|
||||||
|
# def f():
|
||||||
|
# original_stack = graph._variable_creator_stack
|
||||||
|
# graph._variable_creator_stack = new_stack
|
||||||
|
# ... # Some code
|
||||||
|
# graph._variable_creator_stack = original_stack
|
||||||
|
#
|
||||||
|
# And lets say you have some code that calls this function with some
|
||||||
|
# variable_creator:
|
||||||
|
# def g():
|
||||||
|
# with variable_scope.variable_creator_scope(creator):
|
||||||
|
# f()
|
||||||
|
# When exiting the variable creator scope, it would see a different stack
|
||||||
|
# object than it expected leading to a "Exiting variable_creator_scope
|
||||||
|
# without proper nesting" error.
|
||||||
|
return self._thread_local._variable_creator_stack # pylint: disable=protected-access
|
||||||
|
|
||||||
@_variable_creator_stack.setter
|
@_variable_creator_stack.setter
|
||||||
def _variable_creator_stack(self, variable_creator_stack):
|
def _variable_creator_stack(self, variable_creator_stack):
|
||||||
@ -4405,11 +4430,18 @@ class Graph(object):
|
|||||||
Yields:
|
Yields:
|
||||||
A context manager that specifies the default device to use for newly
|
A context manager that specifies the default device to use for newly
|
||||||
created ops.
|
created ops.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If device scopes are not properly nested.
|
||||||
"""
|
"""
|
||||||
self._add_device_to_stack(device_name_or_function, offset=2)
|
self._add_device_to_stack(device_name_or_function, offset=2)
|
||||||
|
old_top_of_stack = self._device_function_stack.peek_objs()[0]
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
new_top_of_stack = self._device_function_stack.peek_objs()[0]
|
||||||
|
if old_top_of_stack is not new_top_of_stack:
|
||||||
|
raise RuntimeError("Exiting device scope without proper scope nesting.")
|
||||||
self._device_function_stack.pop_obj()
|
self._device_function_stack.pop_obj()
|
||||||
|
|
||||||
def _apply_device_functions(self, op):
|
def _apply_device_functions(self, op):
|
||||||
@ -5119,9 +5151,7 @@ class Graph(object):
|
|||||||
def device(device_name_or_function):
|
def device(device_name_or_function):
|
||||||
"""Wrapper for `Graph.device()` using the default graph.
|
"""Wrapper for `Graph.device()` using the default graph.
|
||||||
|
|
||||||
See
|
See `tf.Graph.device` for more details.
|
||||||
`tf.Graph.device`
|
|
||||||
for more details.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device_name_or_function: The device name or function to use in
|
device_name_or_function: The device name or function to use in
|
||||||
|
@ -1295,6 +1295,22 @@ class DeviceTest(test_util.TensorFlowTestCase):
|
|||||||
device: "/device:CPU:5" }
|
device: "/device:CPU:5" }
|
||||||
""", gd)
|
""", gd)
|
||||||
|
|
||||||
|
def testNestingErrorGraph(self):
|
||||||
|
g = ops.Graph()
|
||||||
|
scope = g.device("/device:GPU:8")
|
||||||
|
scope.__enter__()
|
||||||
|
with g.device("/device:GPU:9"):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
|
||||||
|
def testNestingErrorEager(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
scope = ops.device("/device:CPU:0")
|
||||||
|
scope.__enter__()
|
||||||
|
with ops.device(None):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
|
||||||
def testNoneClearsDefault(self):
|
def testNoneClearsDefault(self):
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with g.device("/job:worker/replica:2/device:CPU:1"):
|
with g.device("/job:worker/replica:2/device:CPU:1"):
|
||||||
|
@ -437,6 +437,9 @@ class ConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testParallelApplyGrad(self):
|
def testParallelApplyGrad(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.ConditionalAccumulator(
|
q = data_flow_ops.ConditionalAccumulator(
|
||||||
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
|
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
|
||||||
@ -463,6 +466,9 @@ class ConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testParallelTakeGrad(self):
|
def testParallelTakeGrad(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.ConditionalAccumulator(
|
q = data_flow_ops.ConditionalAccumulator(
|
||||||
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
|
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
|
||||||
@ -496,6 +502,9 @@ class ConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testAccumulatorApplyAndBlockingTake(self):
|
def testAccumulatorApplyAndBlockingTake(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.ConditionalAccumulator(
|
q = data_flow_ops.ConditionalAccumulator(
|
||||||
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
|
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
@ -33,6 +34,9 @@ class AssignOpTest(test.TestCase):
|
|||||||
# contain benign and deliberate data races when multiple threads update
|
# contain benign and deliberate data races when multiple threads update
|
||||||
# the same parameters without a lock.
|
# the same parameters without a lock.
|
||||||
def testParallelUpdateWithoutLocking(self):
|
def testParallelUpdateWithoutLocking(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
ones_t = array_ops.fill([1024, 1024], 1.0)
|
ones_t = array_ops.fill([1024, 1024], 1.0)
|
||||||
p = variables.Variable(array_ops.zeros([1024, 1024]))
|
p = variables.Variable(array_ops.zeros([1024, 1024]))
|
||||||
@ -60,6 +64,9 @@ class AssignOpTest(test.TestCase):
|
|||||||
self.assertTrue((vals <= ones * 20).all())
|
self.assertTrue((vals <= ones * 20).all())
|
||||||
|
|
||||||
def testParallelAssignWithoutLocking(self):
|
def testParallelAssignWithoutLocking(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
ones_t = array_ops.fill([1024, 1024], float(1))
|
ones_t = array_ops.fill([1024, 1024], float(1))
|
||||||
p = variables.Variable(array_ops.zeros([1024, 1024]))
|
p = variables.Variable(array_ops.zeros([1024, 1024]))
|
||||||
@ -92,6 +99,9 @@ class AssignOpTest(test.TestCase):
|
|||||||
# returning the output tensors. This issue will be resolved with the new
|
# returning the output tensors. This issue will be resolved with the new
|
||||||
# resource variables.
|
# resource variables.
|
||||||
def testParallelUpdateWithLocking(self):
|
def testParallelUpdateWithLocking(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
zeros_t = array_ops.fill([1024, 1024], 0.0)
|
zeros_t = array_ops.fill([1024, 1024], 0.0)
|
||||||
ones_t = array_ops.fill([1024, 1024], 1.0)
|
ones_t = array_ops.fill([1024, 1024], 1.0)
|
||||||
@ -119,6 +129,9 @@ class AssignOpTest(test.TestCase):
|
|||||||
self.assertAllEqual(vals, ones * 20)
|
self.assertAllEqual(vals, ones * 20)
|
||||||
|
|
||||||
def testParallelAssignWithLocking(self):
|
def testParallelAssignWithLocking(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
zeros_t = array_ops.fill([1024, 1024], 0.0)
|
zeros_t = array_ops.fill([1024, 1024], 0.0)
|
||||||
ones_t = array_ops.fill([1024, 1024], 1.0)
|
ones_t = array_ops.fill([1024, 1024], 1.0)
|
||||||
|
@ -42,6 +42,11 @@ from tensorflow.python.util import compat
|
|||||||
@test_util.run_v1_only("FIFOQueue removed from v2")
|
@test_util.run_v1_only("FIFOQueue removed from v2")
|
||||||
class FIFOQueueTest(test.TestCase):
|
class FIFOQueueTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
|
|
||||||
def testConstructor(self):
|
def testConstructor(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q")
|
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, name="Q")
|
||||||
|
@ -120,6 +120,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertEqual(4, q.size().eval())
|
self.assertEqual(4, q.size().eval())
|
||||||
|
|
||||||
def testParallelEnqueue(self):
|
def testParallelEnqueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
|
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
|
||||||
@ -146,6 +149,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertItemsEqual(elems, results)
|
self.assertItemsEqual(elems, results)
|
||||||
|
|
||||||
def testParallelDequeue(self):
|
def testParallelDequeue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
|
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
|
||||||
@ -184,6 +190,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertEqual([elems[i]], vals)
|
self.assertEqual([elems[i]], vals)
|
||||||
|
|
||||||
def testEnqueueAndBlockingDequeue(self):
|
def testEnqueueAndBlockingDequeue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0]
|
elems = [10.0, 20.0, 30.0]
|
||||||
@ -627,6 +636,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.evaluate(dequeued_t)
|
self.evaluate(dequeued_t)
|
||||||
|
|
||||||
def testParallelEnqueueMany(self):
|
def testParallelEnqueueMany(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
|
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
|
||||||
elems = [10.0 * x for x in range(100)]
|
elems = [10.0 * x for x in range(100)]
|
||||||
@ -646,6 +658,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
|
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
|
||||||
|
|
||||||
def testParallelDequeueMany(self):
|
def testParallelDequeueMany(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
|
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
|
||||||
elems = [10.0 * x for x in range(1000)]
|
elems = [10.0 * x for x in range(1000)]
|
||||||
@ -668,6 +683,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertItemsEqual(elems, dequeued_elems)
|
self.assertItemsEqual(elems, dequeued_elems)
|
||||||
|
|
||||||
def testParallelDequeueUpTo(self):
|
def testParallelDequeueUpTo(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
|
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
|
||||||
elems = [10.0 * x for x in range(1000)]
|
elems = [10.0 * x for x in range(1000)]
|
||||||
@ -692,6 +710,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertItemsEqual(elems, dequeued_elems)
|
self.assertItemsEqual(elems, dequeued_elems)
|
||||||
|
|
||||||
def testParallelEnqueueAndDequeue(self):
|
def testParallelEnqueueAndDequeue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),))
|
q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),))
|
||||||
initial_elements = [10.0] * 49
|
initial_elements = [10.0] * 49
|
||||||
@ -795,6 +816,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertEqual(0, q.size().eval())
|
self.assertEqual(0, q.size().eval())
|
||||||
|
|
||||||
def testBlockingDequeueMany(self):
|
def testBlockingDequeueMany(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -822,6 +846,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertAllEqual(elems, dequeued_elems)
|
self.assertAllEqual(elems, dequeued_elems)
|
||||||
|
|
||||||
def testBlockingDequeueUpTo(self):
|
def testBlockingDequeueUpTo(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -892,6 +919,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.evaluate(dequeued_t)
|
self.evaluate(dequeued_t)
|
||||||
|
|
||||||
def testBlockingDequeueFromClosedQueue(self):
|
def testBlockingDequeueFromClosedQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -940,6 +970,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
dequeue_thread.join()
|
dequeue_thread.join()
|
||||||
|
|
||||||
def testBlockingDequeueFromClosedEmptyQueue(self):
|
def testBlockingDequeueFromClosedEmptyQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
close_op = q.close()
|
close_op = q.close()
|
||||||
@ -960,6 +993,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
dequeue_thread.join()
|
dequeue_thread.join()
|
||||||
|
|
||||||
def testBlockingDequeueManyFromClosedQueue(self):
|
def testBlockingDequeueManyFromClosedQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -985,6 +1021,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
dequeue_thread.join()
|
dequeue_thread.join()
|
||||||
|
|
||||||
def testBlockingDequeueManyButNotAllFromClosedQueue(self):
|
def testBlockingDequeueManyButNotAllFromClosedQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -1010,6 +1049,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
dequeue_thread.join()
|
dequeue_thread.join()
|
||||||
|
|
||||||
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
|
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -1080,6 +1122,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertEqual(0, q.size().eval())
|
self.assertEqual(0, q.size().eval())
|
||||||
|
|
||||||
def testBlockingDequeueManyFromClosedEmptyQueue(self):
|
def testBlockingDequeueManyFromClosedEmptyQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
close_op = q.close()
|
close_op = q.close()
|
||||||
@ -1100,6 +1145,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
dequeue_thread.join()
|
dequeue_thread.join()
|
||||||
|
|
||||||
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
|
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
|
||||||
close_op = q.close()
|
close_op = q.close()
|
||||||
@ -1147,6 +1195,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
enqueue_op.run()
|
enqueue_op.run()
|
||||||
|
|
||||||
def testBlockingEnqueueToFullQueue(self):
|
def testBlockingEnqueueToFullQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -1170,6 +1221,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
def testBlockingEnqueueManyToFullQueue(self):
|
def testBlockingEnqueueManyToFullQueue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -1197,6 +1251,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
def testBlockingEnqueueBeforeClose(self):
|
def testBlockingEnqueueBeforeClose(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0, 40.0]
|
elems = [10.0, 20.0, 30.0, 40.0]
|
||||||
@ -1234,6 +1291,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
self.assertEqual(0, q.size().eval())
|
self.assertEqual(0, q.size().eval())
|
||||||
|
|
||||||
def testBlockingEnqueueManyBeforeClose(self):
|
def testBlockingEnqueueManyBeforeClose(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
|
||||||
elems = [10.0, 20.0, 30.0]
|
elems = [10.0, 20.0, 30.0]
|
||||||
@ -1397,6 +1457,9 @@ class PaddingFIFOQueueTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testResetOfBlockingOperation(self):
|
def testResetOfBlockingOperation(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),))
|
q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),))
|
||||||
dequeue_op = q_empty.dequeue()
|
dequeue_op = q_empty.dequeue()
|
||||||
|
@ -27,6 +27,7 @@ import numpy as np
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
@ -69,6 +70,9 @@ class PriorityQueueTest(test.TestCase):
|
|||||||
self.assertEqual(missed, set())
|
self.assertEqual(missed, set())
|
||||||
|
|
||||||
def testRoundTripInsertMultiThreadedReadOnceSorts(self):
|
def testRoundTripInsertMultiThreadedReadOnceSorts(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
|
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
|
||||||
(), ()))
|
(), ()))
|
||||||
@ -115,6 +119,9 @@ class PriorityQueueTest(test.TestCase):
|
|||||||
self.assertEqual(missed, set())
|
self.assertEqual(missed, set())
|
||||||
|
|
||||||
def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
|
def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
|
q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
|
||||||
|
|
||||||
@ -165,6 +172,9 @@ class PriorityQueueTest(test.TestCase):
|
|||||||
self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
|
self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
|
||||||
|
|
||||||
def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
|
def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
|
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
|
||||||
|
|
||||||
@ -221,6 +231,9 @@ class PriorityQueueTest(test.TestCase):
|
|||||||
self.assertAllEqual(set(dequeued), set(all_enqueued_values))
|
self.assertAllEqual(set(dequeued), set(all_enqueued_values))
|
||||||
|
|
||||||
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
|
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
|
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
|
||||||
(), ()))
|
(), ()))
|
||||||
|
@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -42,6 +43,9 @@ class RandomShuffleQueueTest(test.TestCase):
|
|||||||
# Useful for debugging when a test times out.
|
# Useful for debugging when a test times out.
|
||||||
super(RandomShuffleQueueTest, self).setUp()
|
super(RandomShuffleQueueTest, self).setUp()
|
||||||
tf_logging.error("Starting: %s", self._testMethodName)
|
tf_logging.error("Starting: %s", self._testMethodName)
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
super(RandomShuffleQueueTest, self).tearDown()
|
super(RandomShuffleQueueTest, self).tearDown()
|
||||||
@ -1237,6 +1241,9 @@ class RandomShuffleQueueTest(test.TestCase):
|
|||||||
self.evaluate(enqueue_many_op)
|
self.evaluate(enqueue_many_op)
|
||||||
|
|
||||||
def testResetOfBlockingOperation(self):
|
def testResetOfBlockingOperation(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, (
|
q_empty = data_flow_ops.RandomShuffleQueue(5, 0, dtypes_lib.float32, (
|
||||||
(),))
|
(),))
|
||||||
|
@ -269,6 +269,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testParallelApplyGradMean(self):
|
def testParallelApplyGradMean(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.SparseConditionalAccumulator(
|
q = data_flow_ops.SparseConditionalAccumulator(
|
||||||
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
|
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
|
||||||
@ -301,6 +304,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testParallelApplyGradSum(self):
|
def testParallelApplyGradSum(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.SparseConditionalAccumulator(
|
q = data_flow_ops.SparseConditionalAccumulator(
|
||||||
dtypes_lib.float32,
|
dtypes_lib.float32,
|
||||||
@ -336,6 +342,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testParallelTakeGrad(self):
|
def testParallelTakeGrad(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.SparseConditionalAccumulator(
|
q = data_flow_ops.SparseConditionalAccumulator(
|
||||||
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
|
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
|
||||||
@ -376,6 +385,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testAccumulatorApplyAndBlockingTake(self):
|
def testAccumulatorApplyAndBlockingTake(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.SparseConditionalAccumulator(
|
q = data_flow_ops.SparseConditionalAccumulator(
|
||||||
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
|
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
|
||||||
@ -412,6 +424,9 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testAccumulatorCancel(self):
|
def testAccumulatorCancel(self):
|
||||||
|
# We need each thread to keep its own device stack or the device scopes
|
||||||
|
# won't be properly nested.
|
||||||
|
ops.get_default_graph().switch_to_thread_local()
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
q = data_flow_ops.SparseConditionalAccumulator(
|
q = data_flow_ops.SparseConditionalAccumulator(
|
||||||
dtypes_lib.float32,
|
dtypes_lib.float32,
|
||||||
|
@ -148,6 +148,17 @@ class VariableScopeTest(test.TestCase):
|
|||||||
w = variable_scope.get_variable("w", [])
|
w = variable_scope.get_variable("w", [])
|
||||||
self.assertEqual(w.constraint, constraint)
|
self.assertEqual(w.constraint, constraint)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
@run_inside_wrap_function_in_eager_mode
|
||||||
|
def testVarScopeNestingError(self):
|
||||||
|
with variable_scope.variable_scope("aa"):
|
||||||
|
scope = variable_scope.variable_scope("bb")
|
||||||
|
scope.__enter__()
|
||||||
|
with variable_scope.variable_scope("cc"):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
|
||||||
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
# TODO(mihaimaruseac): Not converted to use wrap_function because of
|
||||||
# TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string>
|
# TypeError: Fetch argument <tf.Variable 'string:0' shape=() dtype=string>
|
||||||
# has invalid type <class '...ResourceVariable'>, must be a string or Tensor.
|
# has invalid type <class '...ResourceVariable'>, must be a string or Tensor.
|
||||||
@ -1704,6 +1715,26 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
|
|||||||
aggregation=variable_scope.VariableAggregation.MEAN)
|
aggregation=variable_scope.VariableAggregation.MEAN)
|
||||||
self.assertTrue(called[0])
|
self.assertTrue(called[0])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
@run_inside_wrap_function_in_eager_mode
|
||||||
|
def testVariableCreatorNestingError(self):
|
||||||
|
|
||||||
|
def creator(next_creator, **kwargs):
|
||||||
|
return next_creator(**kwargs)
|
||||||
|
|
||||||
|
# Save the state so we can clean up at the end.
|
||||||
|
graph = ops.get_default_graph()
|
||||||
|
old_creator_stack = graph._variable_creator_stack
|
||||||
|
|
||||||
|
try:
|
||||||
|
scope = variable_scope.variable_creator_scope(creator)
|
||||||
|
scope.__enter__()
|
||||||
|
with variable_scope.variable_creator_scope(creator):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
scope.__exit__(None, None, None)
|
||||||
|
finally:
|
||||||
|
graph._variable_creator_stack = old_creator_stack
|
||||||
|
|
||||||
|
|
||||||
class PartitionInfoTest(test.TestCase):
|
class PartitionInfoTest(test.TestCase):
|
||||||
|
|
||||||
|
@ -1835,6 +1835,7 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name
|
|||||||
self._constraint = constraint
|
self._constraint = constraint
|
||||||
self._var_store = _get_default_variable_store()
|
self._var_store = _get_default_variable_store()
|
||||||
self._var_scope_store = get_variable_scope_store()
|
self._var_scope_store = get_variable_scope_store()
|
||||||
|
self._last_variable_scope_object = None
|
||||||
if isinstance(self._name_or_scope, VariableScope):
|
if isinstance(self._name_or_scope, VariableScope):
|
||||||
self._new_name = self._name_or_scope.name
|
self._new_name = self._name_or_scope.name
|
||||||
name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access
|
name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access
|
||||||
@ -1931,9 +1932,13 @@ class _pure_variable_scope(object): # pylint: disable=invalid-name
|
|||||||
variable_scope_object.set_use_resource(self._use_resource)
|
variable_scope_object.set_use_resource(self._use_resource)
|
||||||
self._var_scope_store.open_variable_scope(self._new_name)
|
self._var_scope_store.open_variable_scope(self._new_name)
|
||||||
self._var_scope_store.current_scope = variable_scope_object
|
self._var_scope_store.current_scope = variable_scope_object
|
||||||
|
self._last_variable_scope_object = variable_scope_object
|
||||||
return variable_scope_object
|
return variable_scope_object
|
||||||
|
|
||||||
def __exit__(self, type_arg, value_arg, traceback_arg):
|
def __exit__(self, type_arg, value_arg, traceback_arg):
|
||||||
|
if (self._var_scope_store.current_scope is not
|
||||||
|
self._last_variable_scope_object):
|
||||||
|
raise RuntimeError("Improper nesting of variable_scope.")
|
||||||
# If jumping out from a non-prolonged scope, restore counts.
|
# If jumping out from a non-prolonged scope, restore counts.
|
||||||
if isinstance(self._name_or_scope, VariableScope):
|
if isinstance(self._name_or_scope, VariableScope):
|
||||||
self._var_scope_store.variable_scopes_count = self._old_subscopes
|
self._var_scope_store.variable_scopes_count = self._old_subscopes
|
||||||
@ -2225,11 +2230,10 @@ class variable_scope(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return self._enter_scope_uncached()
|
return self._enter_scope_uncached()
|
||||||
except Exception:
|
finally:
|
||||||
if self._in_graph_mode and not self._building_function:
|
if (self._in_graph_mode and not self._building_function and
|
||||||
if self._graph_context_manager is not None:
|
self._graph_context_manager is not None):
|
||||||
self._graph_context_manager.__exit__(*sys.exc_info())
|
self._graph_context_manager.__exit__(*sys.exc_info())
|
||||||
raise
|
|
||||||
|
|
||||||
def _enter_scope_uncached(self):
|
def _enter_scope_uncached(self):
|
||||||
"""Enters the context manager when there is no cached scope yet.
|
"""Enters the context manager when there is no cached scope yet.
|
||||||
|
Loading…
Reference in New Issue
Block a user