[TF CriticalSection] Make deadlock detection friendly for non-distributed eager execution.
PiperOrigin-RevId: 266177528
This commit is contained in:
parent
b2c7258633
commit
7c94976ae9
tensorflow/python
@ -170,7 +170,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
|||||||
[signature.op for signature in
|
[signature.op for signature in
|
||||||
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
|
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
|
||||||
|
|
||||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
|
||||||
def testRecursiveCriticalSectionAccessIsIllegal(self):
|
def testRecursiveCriticalSectionAccessIsIllegal(self):
|
||||||
# This does not work properly in eager mode. Eager users will
|
# This does not work properly in eager mode. Eager users will
|
||||||
# just hit a deadlock if they do this. But at least it'll be easier
|
# just hit a deadlock if they do this. But at least it'll be easier
|
||||||
@ -181,9 +180,7 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
|||||||
return cs.execute(lambda: add(x))
|
return cs.execute(lambda: add(x))
|
||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError, r"Attempting to lock a CriticalSection in which we are"):
|
||||||
r"attempts to directly access the CriticalSection in which it "
|
|
||||||
r"would be running"):
|
|
||||||
cs.execute(lambda: fn(1.0))
|
cs.execute(lambda: fn(1.0))
|
||||||
|
|
||||||
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
|
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
|
||||||
@ -313,7 +310,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
|||||||
"body_args_capture'\n"
|
"body_args_capture'\n"
|
||||||
"==============\n")
|
"==============\n")
|
||||||
|
|
||||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
|
||||||
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
|
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
|
||||||
# This does not work properly in eager mode. Eager users will
|
# This does not work properly in eager mode. Eager users will
|
||||||
# just hit a deadlock if they do this. But at least it'll be easier
|
# just hit a deadlock if they do this. But at least it'll be easier
|
||||||
@ -324,12 +320,11 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def fn(x):
|
def fn(x):
|
||||||
return cs_same.execute(lambda: add(x))
|
return cs_same.execute(lambda: add(x))
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError,
|
ValueError, r"Attempting to lock a CriticalSection in which we are"):
|
||||||
r"attempts to directly access the CriticalSection in which it "
|
|
||||||
r"would be running"):
|
|
||||||
cs.execute(lambda: fn(1.0))
|
cs.execute(lambda: fn(1.0))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
@test_util.run_v1_only(
|
||||||
|
"b/123955885 Can't identify consumed resources in eager mode")
|
||||||
def testMultipleCSExecutionsRequestSameResource(self):
|
def testMultipleCSExecutionsRequestSameResource(self):
|
||||||
cs0 = critical_section_ops.CriticalSection()
|
cs0 = critical_section_ops.CriticalSection()
|
||||||
cs1 = critical_section_ops.CriticalSection()
|
cs1 = critical_section_ops.CriticalSection()
|
||||||
@ -403,32 +398,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(1, get_first())
|
self.assertEqual(1, get_first())
|
||||||
|
|
||||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
|
||||||
#
|
|
||||||
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
|
|
||||||
# cs = critical_section_ops.CriticalSection()
|
|
||||||
# r = cs.execute(lambda x: x + 1, 1.0)
|
|
||||||
# graph = ops.get_default_graph()
|
|
||||||
# meta_graph = saver_lib.export_meta_graph(
|
|
||||||
# graph=graph, collection_list=graph.get_all_collection_keys())
|
|
||||||
# graph_copy = ops.Graph()
|
|
||||||
# with graph_copy.as_default():
|
|
||||||
# _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported")
|
|
||||||
# restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)
|
|
||||||
# restored_exec = ops.get_collection(
|
|
||||||
# critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
|
|
||||||
# self.assertEqual(1, len(restored_cs))
|
|
||||||
# self.assertEqual(1, len(restored_exec))
|
|
||||||
# self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name)
|
|
||||||
# self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
|
|
||||||
|
|
||||||
# def testToProto(self):
|
|
||||||
# cs = critical_section_ops.CriticalSection(shared_name="cs")
|
|
||||||
# proto = cs.to_proto()
|
|
||||||
# self.assertEqual(proto.critical_section_name, cs._handle.name)
|
|
||||||
# cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
|
|
||||||
# self.assertEqual(cs_copy._handle, cs._handle)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -19,9 +19,8 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import contextlib
|
||||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
import threading
|
||||||
# from tensorflow.core.protobuf import critical_section_pb2
|
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -63,14 +62,66 @@ def _identity(x):
|
|||||||
return array_ops.identity(x)
|
return array_ops.identity(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_device_or_colocation(op):
|
||||||
|
return op.device or _get_colocation(op)
|
||||||
|
|
||||||
|
|
||||||
def _get_colocation(op):
|
def _get_colocation(op):
|
||||||
"""Get colocation symbol from op, if any."""
|
"""Get colocation symbol from op, if any."""
|
||||||
try:
|
try:
|
||||||
return op.get_attr("_class")
|
return op.get_attr("_class")
|
||||||
except ValueError:
|
except (ValueError, AttributeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
_CRITICAL_SECTION_STACK = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_critical_section_stack():
|
||||||
|
try:
|
||||||
|
return _CRITICAL_SECTION_STACK.value
|
||||||
|
except AttributeError:
|
||||||
|
_CRITICAL_SECTION_STACK.value = []
|
||||||
|
return _CRITICAL_SECTION_STACK.value
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _push_critical_section_stack(signature):
|
||||||
|
"""Push a CriticalSection._signature to the thread-local stack.
|
||||||
|
|
||||||
|
If the signature is already on the stack, raise an error because it means
|
||||||
|
we're trying to execute inside the same locked CriticalSection, which
|
||||||
|
will create a deadlock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signature: Tuple of the type `CriticalSection._signature`. Uniquely
|
||||||
|
identifies a CriticalSection by its `shared_name`, `container`,
|
||||||
|
and device.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An empty value. The context is guaranteed to run without deadlock.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the signature is already on the stack.
|
||||||
|
RuntimeError: If another thread or function modifies the current stack
|
||||||
|
entry during the yield.
|
||||||
|
"""
|
||||||
|
stack = _get_critical_section_stack()
|
||||||
|
if signature in stack:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempting to lock a CriticalSection in which we are "
|
||||||
|
"already running. This is illegal and may cause deadlocks.")
|
||||||
|
stack.append(signature)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
received_signature = stack.pop()
|
||||||
|
if received_signature != signature:
|
||||||
|
raise RuntimeError(
|
||||||
|
"CriticalSection stack inconsistency: expected signature "
|
||||||
|
"{} but saw {}".format(signature, received_signature))
|
||||||
|
|
||||||
|
|
||||||
@tf_export("CriticalSection")
|
@tf_export("CriticalSection")
|
||||||
class CriticalSection(object):
|
class CriticalSection(object):
|
||||||
"""Critical section.
|
"""Critical section.
|
||||||
@ -149,22 +200,10 @@ class CriticalSection(object):
|
|||||||
raise ValueError("critical_section_def and shared_name are "
|
raise ValueError("critical_section_def and shared_name are "
|
||||||
"mutually exclusive.")
|
"mutually exclusive.")
|
||||||
if critical_section_def:
|
if critical_section_def:
|
||||||
self._init_from_proto(critical_section_def, import_scope=import_scope)
|
raise ValueError("critical_section_def is not supported.")
|
||||||
else:
|
else:
|
||||||
self._init_from_args(name, shared_name)
|
self._init_from_args(name, shared_name)
|
||||||
|
|
||||||
def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name
|
|
||||||
raise NotImplementedError("Not yet implemented")
|
|
||||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
|
||||||
# assert isinstance(
|
|
||||||
# critical_section_def, critical_section_pb2.CriticalSectionDef)
|
|
||||||
# # Create from critical_section_def.
|
|
||||||
# g = ops.get_default_graph()
|
|
||||||
# self._handle = g.as_graph_element(
|
|
||||||
# ops.prepend_name_scope(
|
|
||||||
# critical_section_def.critical_section_name,
|
|
||||||
# import_scope=import_scope))
|
|
||||||
|
|
||||||
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
|
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
|
||||||
"""Initialize the CriticalSection from constructor arguments."""
|
"""Initialize the CriticalSection from constructor arguments."""
|
||||||
with ops.name_scope(name, "CriticalSection", []) as name:
|
with ops.name_scope(name, "CriticalSection", []) as name:
|
||||||
@ -178,6 +217,12 @@ class CriticalSection(object):
|
|||||||
container = ""
|
container = ""
|
||||||
self._handle = gen_resource_variable_ops.mutex_v2(
|
self._handle = gen_resource_variable_ops.mutex_v2(
|
||||||
shared_name=shared_name, container=container, name=name)
|
shared_name=shared_name, container=container, name=name)
|
||||||
|
# Get a uniquely identifying signature for the handle.
|
||||||
|
self._signature = (
|
||||||
|
container,
|
||||||
|
# If shared_name is empty, a unique CriticalSection is created.
|
||||||
|
shared_name or id(self._handle),
|
||||||
|
_get_device_or_colocation(self._handle))
|
||||||
|
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
ops.add_to_collections(CRITICAL_SECTIONS, self)
|
ops.add_to_collections(CRITICAL_SECTIONS, self)
|
||||||
@ -217,26 +262,26 @@ class CriticalSection(object):
|
|||||||
without `exclusive_resource_access=True`, a `ValueError` will be raised.
|
without `exclusive_resource_access=True`, a `ValueError` will be raised.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "critical_section_execute", []):
|
with ops.name_scope(name, "critical_section_execute", []):
|
||||||
|
|
||||||
# Ensure that mutex locking only happens *after* all args and
|
# Ensure that mutex locking only happens *after* all args and
|
||||||
# kwargs have been executed. This avoids certain types of deadlocks.
|
# kwargs have been executed. This avoids certain types of deadlocks.
|
||||||
lock = gen_resource_variable_ops.mutex_lock(self._handle)
|
with _push_critical_section_stack(self._signature):
|
||||||
|
lock = gen_resource_variable_ops.mutex_lock(self._handle)
|
||||||
|
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
|
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
|
||||||
# Operations created by other threads.
|
# Operations created by other threads.
|
||||||
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
||||||
existing_ops = ops.get_default_graph().get_operations()
|
existing_ops = ops.get_default_graph().get_operations()
|
||||||
|
with ops.control_dependencies([lock]):
|
||||||
|
r = fn()
|
||||||
|
# TODO(ebrevdo): If creating critical sections in a python loop,
|
||||||
|
# this makes graph creation time quadratic. Revisit if this
|
||||||
|
# becomes a problem.
|
||||||
|
created_ops = (set(ops.get_default_graph().get_operations())
|
||||||
|
.difference(existing_ops))
|
||||||
|
else:
|
||||||
with ops.control_dependencies([lock]):
|
with ops.control_dependencies([lock]):
|
||||||
r = fn()
|
r = fn()
|
||||||
# TODO(ebrevdo): If creating critical sections in a python loop, this
|
|
||||||
# makes graph creation time quadratic. Revisit if this
|
|
||||||
# becomes a problem.
|
|
||||||
created_ops = (set(ops.get_default_graph().get_operations())
|
|
||||||
.difference(existing_ops))
|
|
||||||
else:
|
|
||||||
with ops.control_dependencies([lock]):
|
|
||||||
r = fn()
|
|
||||||
|
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
self._add_control_dependencies_to_lock(created_ops, lock.op)
|
self._add_control_dependencies_to_lock(created_ops, lock.op)
|
||||||
@ -255,9 +300,9 @@ class CriticalSection(object):
|
|||||||
# the execute(), themselves attempt to access the
|
# the execute(), themselves attempt to access the
|
||||||
# CriticalSection. This will cause a deadlock.
|
# CriticalSection. This will cause a deadlock.
|
||||||
if any(self._is_self_handle(x) for x in captured_resources):
|
if any(self._is_self_handle(x) for x in captured_resources):
|
||||||
raise ValueError("The function fn attempts to directly access the "
|
raise ValueError(
|
||||||
"CriticalSection in which it would be running. "
|
"Attempting to lock a CriticalSection in which we are "
|
||||||
"This is illegal and would cause deadlocks.")
|
"already running. This is illegal and may cause deadlocks.")
|
||||||
|
|
||||||
self._check_multiple_access_to_resources(
|
self._check_multiple_access_to_resources(
|
||||||
captured_resources, exclusive_resource_access)
|
captured_resources, exclusive_resource_access)
|
||||||
@ -374,82 +419,3 @@ class CriticalSection(object):
|
|||||||
"of this resource. Did you mean to call execute with keyword "
|
"of this resource. Did you mean to call execute with keyword "
|
||||||
"argument exclusive_resource_access=False?" %
|
"argument exclusive_resource_access=False?" %
|
||||||
(list(resource_intersection), self._handle, sg, sg.handle))
|
(list(resource_intersection), self._handle, sg, sg.handle))
|
||||||
|
|
||||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
|
||||||
|
|
||||||
# def to_proto(self, export_scope=None):
|
|
||||||
# """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer.
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# export_scope: Optional `string`. Name scope to remove.
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# A `CriticalSectionDef` protocol buffer, or `None` if the
|
|
||||||
# `CriticalSection` is not in the specified name scope.
|
|
||||||
# """
|
|
||||||
# if export_scope is None or self.handle.name.startswith(export_scope):
|
|
||||||
# cs_def = critical_section_pb2.CriticalSectionDef()
|
|
||||||
# cs_def.critical_section_name = ops.strip_name_scope(
|
|
||||||
# self._handle.name, export_scope)
|
|
||||||
# return cs_def
|
|
||||||
# else:
|
|
||||||
# return None
|
|
||||||
|
|
||||||
# @staticmethod
|
|
||||||
# def from_proto(critical_section_def, import_scope=None):
|
|
||||||
# return CriticalSection(
|
|
||||||
# critical_section_def=critical_section_def, import_scope=import_scope)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
|
||||||
|
|
||||||
# def _execution_to_proto_fn(execution_signature, export_scope=None):
|
|
||||||
# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
|
|
||||||
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# execution_signature: Instance of `_ExecutionSignature`.
|
|
||||||
# export_scope: The export scope, if any.
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# An instance of `CriticalSectionExecutionDef`.
|
|
||||||
# """
|
|
||||||
# if (export_scope is None
|
|
||||||
# or execution_signature.op.name.startswith(export_scope)):
|
|
||||||
# op_def = critical_section_pb2.CriticalSectionExecutionDef()
|
|
||||||
# op_def.execute_in_critical_section_name = ops.strip_name_scope(
|
|
||||||
# execution_signature.op.name, export_scope)
|
|
||||||
# op_def.exclusive_resource_access = (
|
|
||||||
# execution_signature.exclusive_resource_access)
|
|
||||||
# return op_def
|
|
||||||
# else:
|
|
||||||
# return None
|
|
||||||
|
|
||||||
|
|
||||||
# def _execution_from_proto_fn(op_def, import_scope=None):
|
|
||||||
# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
|
|
||||||
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
|
|
||||||
# assert isinstance(
|
|
||||||
# op_def, critical_section_pb2.CriticalSectionExecutionDef)
|
|
||||||
|
|
||||||
# # Create from op_def.
|
|
||||||
# g = ops.get_default_graph()
|
|
||||||
# execution_op = g.as_graph_element(
|
|
||||||
# ops.prepend_name_scope(
|
|
||||||
# op_def.execute_in_critical_section_name,
|
|
||||||
# import_scope=import_scope))
|
|
||||||
# return _ExecutionSignature(
|
|
||||||
# op=execution_op,
|
|
||||||
# exclusive_resource_access=op_def.exclusive_resource_access)
|
|
||||||
|
|
||||||
# ops.register_proto_function(
|
|
||||||
# CRITICAL_SECTIONS,
|
|
||||||
# proto_type=critical_section_pb2.CriticalSectionDef,
|
|
||||||
# to_proto=CriticalSection.to_proto,
|
|
||||||
# from_proto=CriticalSection.from_proto)
|
|
||||||
|
|
||||||
# ops.register_proto_function(
|
|
||||||
# CRITICAL_SECTION_EXECUTIONS,
|
|
||||||
# proto_type=critical_section_pb2.CriticalSectionExecutionDef,
|
|
||||||
# to_proto=_execution_to_proto_fn,
|
|
||||||
# from_proto=_execution_from_proto_fn)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user