[TF CriticalSection] Make deadlock detection friendly for non-distributed eager execution.
PiperOrigin-RevId: 266177528
This commit is contained in:
parent
b2c7258633
commit
7c94976ae9
tensorflow/python
@ -170,7 +170,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
||||
[signature.op for signature in
|
||||
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
|
||||
|
||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
||||
def testRecursiveCriticalSectionAccessIsIllegal(self):
|
||||
# This does not work properly in eager mode. Eager users will
|
||||
# just hit a deadlock if they do this. But at least it'll be easier
|
||||
@ -181,9 +180,7 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
||||
return cs.execute(lambda: add(x))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"attempts to directly access the CriticalSection in which it "
|
||||
r"would be running"):
|
||||
ValueError, r"Attempting to lock a CriticalSection in which we are"):
|
||||
cs.execute(lambda: fn(1.0))
|
||||
|
||||
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
|
||||
@ -313,7 +310,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
||||
"body_args_capture'\n"
|
||||
"==============\n")
|
||||
|
||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
||||
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
|
||||
# This does not work properly in eager mode. Eager users will
|
||||
# just hit a deadlock if they do this. But at least it'll be easier
|
||||
@ -324,12 +320,11 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
||||
def fn(x):
|
||||
return cs_same.execute(lambda: add(x))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r"attempts to directly access the CriticalSection in which it "
|
||||
r"would be running"):
|
||||
ValueError, r"Attempting to lock a CriticalSection in which we are"):
|
||||
cs.execute(lambda: fn(1.0))
|
||||
|
||||
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
|
||||
@test_util.run_v1_only(
|
||||
"b/123955885 Can't identify consumed resources in eager mode")
|
||||
def testMultipleCSExecutionsRequestSameResource(self):
|
||||
cs0 = critical_section_ops.CriticalSection()
|
||||
cs1 = critical_section_ops.CriticalSection()
|
||||
@ -403,32 +398,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertEqual(1, get_first())
|
||||
|
||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
||||
#
|
||||
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
|
||||
# cs = critical_section_ops.CriticalSection()
|
||||
# r = cs.execute(lambda x: x + 1, 1.0)
|
||||
# graph = ops.get_default_graph()
|
||||
# meta_graph = saver_lib.export_meta_graph(
|
||||
# graph=graph, collection_list=graph.get_all_collection_keys())
|
||||
# graph_copy = ops.Graph()
|
||||
# with graph_copy.as_default():
|
||||
# _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported")
|
||||
# restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)
|
||||
# restored_exec = ops.get_collection(
|
||||
# critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
|
||||
# self.assertEqual(1, len(restored_cs))
|
||||
# self.assertEqual(1, len(restored_exec))
|
||||
# self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name)
|
||||
# self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
|
||||
|
||||
# def testToProto(self):
|
||||
# cs = critical_section_ops.CriticalSection(shared_name="cs")
|
||||
# proto = cs.to_proto()
|
||||
# self.assertEqual(proto.critical_section_name, cs._handle.name)
|
||||
# cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
|
||||
# self.assertEqual(cs_copy._handle, cs._handle)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -19,9 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
||||
# from tensorflow.core.protobuf import critical_section_pb2
|
||||
import contextlib
|
||||
import threading
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -63,14 +62,66 @@ def _identity(x):
|
||||
return array_ops.identity(x)
|
||||
|
||||
|
||||
def _get_device_or_colocation(op):
  """Return the device of `op`, falling back to its colocation symbol.

  Args:
    op: Object exposing a `device` attribute.

  Returns:
    `op.device` when it is truthy; otherwise whatever `_get_colocation(op)`
    returns.
  """
  device = op.device
  if device:
    return device
  # No explicit device string: use the colocation symbol (may be None).
  return _get_colocation(op)
|
||||
|
||||
|
||||
def _get_colocation(op):
  """Get colocation symbol from op, if any.

  Args:
    op: Object to query; expected to provide a `get_attr` method.

  Returns:
    The value of the `_class` attribute, or `None` when the lookup raises
    `ValueError` or when `op` has no `get_attr` method at all.
  """
  try:
    return op.get_attr("_class")
  except (ValueError, AttributeError):
    # ValueError: `get_attr` rejected the lookup (presumably the attr is not
    # set on this op -- TODO confirm against tf.Operation.get_attr).
    # AttributeError: the object does not implement `get_attr` at all, which
    # is presumably the case for eager handles -- TODO confirm.
    return None
|
||||
|
||||
|
||||
# Thread-local holder; each thread lazily gets its own `value` list of
# CriticalSection signatures (see `_get_critical_section_stack`).
_CRITICAL_SECTION_STACK = threading.local()


def _get_critical_section_stack():
  """Return the calling thread's CriticalSection signature stack.

  The list is created on first access from a given thread and the same list
  object is returned on every subsequent call from that thread.

  Returns:
    A mutable `list` stored on the thread-local `_CRITICAL_SECTION_STACK`.
  """
  if not hasattr(_CRITICAL_SECTION_STACK, "value"):
    _CRITICAL_SECTION_STACK.value = []
  return _CRITICAL_SECTION_STACK.value
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _push_critical_section_stack(signature):
  """Hold `signature` on the current thread's stack for the `with` body.

  Entering the context with a signature that is already on the stack means
  we are trying to execute inside a CriticalSection that is already locked,
  which would deadlock; this is rejected up front with an error.

  Args:
    signature: Tuple of the type `CriticalSection._signature`. Uniquely
      identifies a CriticalSection by its `shared_name`, `container`,
      and device.

  Yields:
    An empty value. The context is guaranteed to run without deadlock.

  Raises:
    ValueError: If the signature is already on the stack.
    RuntimeError: If the entry popped on exit is not `signature`, i.e. the
      stack was modified during the yield.
  """
  active_signatures = _get_critical_section_stack()
  if signature in active_signatures:
    raise ValueError(
        "Attempting to lock a CriticalSection in which we are "
        "already running. This is illegal and may cause deadlocks.")
  active_signatures.append(signature)
  try:
    yield
  finally:
    # Always unwind, even if the body raised, then verify that what we
    # removed is the entry we pushed.
    popped = active_signatures.pop()
    if popped != signature:
      raise RuntimeError(
          "CriticalSection stack inconsistency: expected signature "
          "{} but saw {}".format(signature, popped))
|
||||
|
||||
|
||||
@tf_export("CriticalSection")
|
||||
class CriticalSection(object):
|
||||
"""Critical section.
|
||||
@ -149,22 +200,10 @@ class CriticalSection(object):
|
||||
raise ValueError("critical_section_def and shared_name are "
|
||||
"mutually exclusive.")
|
||||
if critical_section_def:
|
||||
self._init_from_proto(critical_section_def, import_scope=import_scope)
|
||||
raise ValueError("critical_section_def is not supported.")
|
||||
else:
|
||||
self._init_from_args(name, shared_name)
|
||||
|
||||
def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name
|
||||
raise NotImplementedError("Not yet implemented")
|
||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
||||
# assert isinstance(
|
||||
# critical_section_def, critical_section_pb2.CriticalSectionDef)
|
||||
# # Create from critical_section_def.
|
||||
# g = ops.get_default_graph()
|
||||
# self._handle = g.as_graph_element(
|
||||
# ops.prepend_name_scope(
|
||||
# critical_section_def.critical_section_name,
|
||||
# import_scope=import_scope))
|
||||
|
||||
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
|
||||
"""Initialize the CriticalSection from constructor arguments."""
|
||||
with ops.name_scope(name, "CriticalSection", []) as name:
|
||||
@ -178,6 +217,12 @@ class CriticalSection(object):
|
||||
container = ""
|
||||
self._handle = gen_resource_variable_ops.mutex_v2(
|
||||
shared_name=shared_name, container=container, name=name)
|
||||
# Get a uniquely identifying signature for the handle.
|
||||
self._signature = (
|
||||
container,
|
||||
# If shared_name is empty, a unique CriticalSection is created.
|
||||
shared_name or id(self._handle),
|
||||
_get_device_or_colocation(self._handle))
|
||||
|
||||
if not context.executing_eagerly():
|
||||
ops.add_to_collections(CRITICAL_SECTIONS, self)
|
||||
@ -217,26 +262,26 @@ class CriticalSection(object):
|
||||
without `exclusive_resource_access=True`, a `ValueError` will be raised.
|
||||
"""
|
||||
with ops.name_scope(name, "critical_section_execute", []):
|
||||
|
||||
# Ensure that mutex locking only happens *after* all args and
|
||||
# kwargs have been executed. This avoids certain types of deadlocks.
|
||||
lock = gen_resource_variable_ops.mutex_lock(self._handle)
|
||||
with _push_critical_section_stack(self._signature):
|
||||
lock = gen_resource_variable_ops.mutex_lock(self._handle)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
|
||||
# Operations created by other threads.
|
||||
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
||||
existing_ops = ops.get_default_graph().get_operations()
|
||||
if not context.executing_eagerly():
|
||||
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
|
||||
# Operations created by other threads.
|
||||
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
||||
existing_ops = ops.get_default_graph().get_operations()
|
||||
with ops.control_dependencies([lock]):
|
||||
r = fn()
|
||||
# TODO(ebrevdo): If creating critical sections in a python loop,
|
||||
# this makes graph creation time quadratic. Revisit if this
|
||||
# becomes a problem.
|
||||
created_ops = (set(ops.get_default_graph().get_operations())
|
||||
.difference(existing_ops))
|
||||
else:
|
||||
with ops.control_dependencies([lock]):
|
||||
r = fn()
|
||||
# TODO(ebrevdo): If creating critical sections in a python loop, this
|
||||
# makes graph creation time quadratic. Revisit if this
|
||||
# becomes a problem.
|
||||
created_ops = (set(ops.get_default_graph().get_operations())
|
||||
.difference(existing_ops))
|
||||
else:
|
||||
with ops.control_dependencies([lock]):
|
||||
r = fn()
|
||||
|
||||
if not context.executing_eagerly():
|
||||
self._add_control_dependencies_to_lock(created_ops, lock.op)
|
||||
@ -255,9 +300,9 @@ class CriticalSection(object):
|
||||
# the execute(), themselves attempt to access the
|
||||
# CriticalSection. This will cause a deadlock.
|
||||
if any(self._is_self_handle(x) for x in captured_resources):
|
||||
raise ValueError("The function fn attempts to directly access the "
|
||||
"CriticalSection in which it would be running. "
|
||||
"This is illegal and would cause deadlocks.")
|
||||
raise ValueError(
|
||||
"Attempting to lock a CriticalSection in which we are "
|
||||
"already running. This is illegal and may cause deadlocks.")
|
||||
|
||||
self._check_multiple_access_to_resources(
|
||||
captured_resources, exclusive_resource_access)
|
||||
@ -374,82 +419,3 @@ class CriticalSection(object):
|
||||
"of this resource. Did you mean to call execute with keyword "
|
||||
"argument exclusive_resource_access=False?" %
|
||||
(list(resource_intersection), self._handle, sg, sg.handle))
|
||||
|
||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
||||
|
||||
# def to_proto(self, export_scope=None):
|
||||
# """Converts a `CriticalSection` to a `CriticalSectionDef` protocol buffer.
|
||||
|
||||
# Args:
|
||||
# export_scope: Optional `string`. Name scope to remove.
|
||||
|
||||
# Returns:
|
||||
# A `CriticalSectionDef` protocol buffer, or `None` if the
|
||||
# `CriticalSection` is not in the specified name scope.
|
||||
# """
|
||||
# if export_scope is None or self.handle.name.startswith(export_scope):
|
||||
# cs_def = critical_section_pb2.CriticalSectionDef()
|
||||
# cs_def.critical_section_name = ops.strip_name_scope(
|
||||
# self._handle.name, export_scope)
|
||||
# return cs_def
|
||||
# else:
|
||||
# return None
|
||||
|
||||
# @staticmethod
|
||||
# def from_proto(critical_section_def, import_scope=None):
|
||||
# return CriticalSection(
|
||||
# critical_section_def=critical_section_def, import_scope=import_scope)
|
||||
|
||||
|
||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
||||
|
||||
# def _execution_to_proto_fn(execution_signature, export_scope=None):
|
||||
# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
|
||||
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
|
||||
|
||||
# Args:
|
||||
# execution_signature: Instance of `_ExecutionSignature`.
|
||||
# export_scope: The export scope, if any.
|
||||
|
||||
# Returns:
|
||||
# An instance of `CriticalSectionExecutionDef`.
|
||||
# """
|
||||
# if (export_scope is None
|
||||
# or execution_signature.op.name.startswith(export_scope)):
|
||||
# op_def = critical_section_pb2.CriticalSectionExecutionDef()
|
||||
# op_def.execute_in_critical_section_name = ops.strip_name_scope(
|
||||
# execution_signature.op.name, export_scope)
|
||||
# op_def.exclusive_resource_access = (
|
||||
# execution_signature.exclusive_resource_access)
|
||||
# return op_def
|
||||
# else:
|
||||
# return None
|
||||
|
||||
|
||||
# def _execution_from_proto_fn(op_def, import_scope=None):
|
||||
# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
|
||||
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
|
||||
# assert isinstance(
|
||||
# op_def, critical_section_pb2.CriticalSectionExecutionDef)
|
||||
|
||||
# # Create from op_def.
|
||||
# g = ops.get_default_graph()
|
||||
# execution_op = g.as_graph_element(
|
||||
# ops.prepend_name_scope(
|
||||
# op_def.execute_in_critical_section_name,
|
||||
# import_scope=import_scope))
|
||||
# return _ExecutionSignature(
|
||||
# op=execution_op,
|
||||
# exclusive_resource_access=op_def.exclusive_resource_access)
|
||||
|
||||
# ops.register_proto_function(
|
||||
# CRITICAL_SECTIONS,
|
||||
# proto_type=critical_section_pb2.CriticalSectionDef,
|
||||
# to_proto=CriticalSection.to_proto,
|
||||
# from_proto=CriticalSection.from_proto)
|
||||
|
||||
# ops.register_proto_function(
|
||||
# CRITICAL_SECTION_EXECUTIONS,
|
||||
# proto_type=critical_section_pb2.CriticalSectionExecutionDef,
|
||||
# to_proto=_execution_to_proto_fn,
|
||||
# from_proto=_execution_from_proto_fn)
|
||||
|
Loading…
Reference in New Issue
Block a user