[TF CriticalSection] Make deadlock detection friendly for non-distributed eager execution.

PiperOrigin-RevId: 266177528
This commit is contained in:
Eugene Brevdo 2019-08-29 10:51:14 -07:00 committed by TensorFlower Gardener
parent b2c7258633
commit 7c94976ae9
2 changed files with 84 additions and 149 deletions
tensorflow/python

View File

@ -170,7 +170,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
[signature.op for signature in
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testRecursiveCriticalSectionAccessIsIllegal(self):
# This does not work properly in eager mode. Eager users will
# just hit a deadlock if they do this. But at least it'll be easier
@ -181,9 +180,7 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
return cs.execute(lambda: add(x))
with self.assertRaisesRegexp(
ValueError,
r"attempts to directly access the CriticalSection in which it "
r"would be running"):
ValueError, r"Attempting to lock a CriticalSection in which we are"):
cs.execute(lambda: fn(1.0))
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
@ -313,7 +310,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
"body_args_capture'\n"
"==============\n")
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
# This does not work properly in eager mode. Eager users will
# just hit a deadlock if they do this. But at least it'll be easier
@ -324,12 +320,11 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
def fn(x):
return cs_same.execute(lambda: add(x))
with self.assertRaisesRegexp(
ValueError,
r"attempts to directly access the CriticalSection in which it "
r"would be running"):
ValueError, r"Attempting to lock a CriticalSection in which we are"):
cs.execute(lambda: fn(1.0))
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
@test_util.run_v1_only(
"b/123955885 Can't identify consumed resources in eager mode")
def testMultipleCSExecutionsRequestSameResource(self):
cs0 = critical_section_ops.CriticalSection()
cs1 = critical_section_ops.CriticalSection()
@ -403,32 +398,6 @@ class CriticalSectionTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1, get_first())
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
#
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
# cs = critical_section_ops.CriticalSection()
# r = cs.execute(lambda x: x + 1, 1.0)
# graph = ops.get_default_graph()
# meta_graph = saver_lib.export_meta_graph(
# graph=graph, collection_list=graph.get_all_collection_keys())
# graph_copy = ops.Graph()
# with graph_copy.as_default():
# _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported")
# restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)
# restored_exec = ops.get_collection(
# critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
# self.assertEqual(1, len(restored_cs))
# self.assertEqual(1, len(restored_exec))
# self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name)
# self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
# def testToProto(self):
# cs = critical_section_ops.CriticalSection(shared_name="cs")
# proto = cs.to_proto()
# self.assertEqual(proto.critical_section_name, cs._handle.name)
# cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
# self.assertEqual(cs_copy._handle, cs._handle)
if __name__ == "__main__":
test.main()

View File

@ -19,9 +19,8 @@ from __future__ import division
from __future__ import print_function
import collections
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.core.protobuf import critical_section_pb2
import contextlib
import threading
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
@ -63,14 +62,66 @@ def _identity(x):
return array_ops.identity(x)
def _get_device_or_colocation(op):
return op.device or _get_colocation(op)
def _get_colocation(op):
"""Get colocation symbol from op, if any."""
try:
return op.get_attr("_class")
except ValueError:
except (ValueError, AttributeError):
return None
_CRITICAL_SECTION_STACK = threading.local()
def _get_critical_section_stack():
try:
return _CRITICAL_SECTION_STACK.value
except AttributeError:
_CRITICAL_SECTION_STACK.value = []
return _CRITICAL_SECTION_STACK.value
@contextlib.contextmanager
def _push_critical_section_stack(signature):
"""Push a CriticalSection._signature to the thread-local stack.
If the signature is already on the stack, raise an error because it means
we're trying to execute inside the same locked CriticalSection, which
will create a deadlock.
Args:
signature: Tuple of the type `CriticalSection._signature`. Uniquely
identifies a CriticalSection by its `shared_name`, `container`,
and device.
Yields:
An empty value. The context is guaranteed to run without deadlock.
Raises:
ValueError: If the signature is already on the stack.
RuntimeError: If another thread or function modifies the current stack
entry during the yield.
"""
stack = _get_critical_section_stack()
if signature in stack:
raise ValueError(
"Attempting to lock a CriticalSection in which we are "
"already running. This is illegal and may cause deadlocks.")
stack.append(signature)
try:
yield
finally:
received_signature = stack.pop()
if received_signature != signature:
raise RuntimeError(
"CriticalSection stack inconsistency: expected signature "
"{} but saw {}".format(signature, received_signature))
@tf_export("CriticalSection")
class CriticalSection(object):
"""Critical section.
@ -149,22 +200,10 @@ class CriticalSection(object):
raise ValueError("critical_section_def and shared_name are "
"mutually exclusive.")
if critical_section_def:
self._init_from_proto(critical_section_def, import_scope=import_scope)
raise ValueError("critical_section_def is not supported.")
else:
self._init_from_args(name, shared_name)
def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name
raise NotImplementedError("Not yet implemented")
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# assert isinstance(
# critical_section_def, critical_section_pb2.CriticalSectionDef)
# # Create from critical_section_def.
# g = ops.get_default_graph()
# self._handle = g.as_graph_element(
# ops.prepend_name_scope(
# critical_section_def.critical_section_name,
# import_scope=import_scope))
def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name
"""Initialize the CriticalSection from constructor arguments."""
with ops.name_scope(name, "CriticalSection", []) as name:
@ -178,6 +217,12 @@ class CriticalSection(object):
container = ""
self._handle = gen_resource_variable_ops.mutex_v2(
shared_name=shared_name, container=container, name=name)
# Get a uniquely identifying signature for the handle.
self._signature = (
container,
# If shared_name is empty, a unique CriticalSection is created.
shared_name or id(self._handle),
_get_device_or_colocation(self._handle))
if not context.executing_eagerly():
ops.add_to_collections(CRITICAL_SECTIONS, self)
@ -217,26 +262,26 @@ class CriticalSection(object):
without `exclusive_resource_access=True`, a `ValueError` will be raised.
"""
with ops.name_scope(name, "critical_section_execute", []):
# Ensure that mutex locking only happens *after* all args and
# kwargs have been executed. This avoids certain types of deadlocks.
lock = gen_resource_variable_ops.mutex_lock(self._handle)
with _push_critical_section_stack(self._signature):
lock = gen_resource_variable_ops.mutex_lock(self._handle)
if not context.executing_eagerly():
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
# Operations created by other threads.
with ops.get_default_graph()._lock: # pylint: disable=protected-access
existing_ops = ops.get_default_graph().get_operations()
if not context.executing_eagerly():
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
# Operations created by other threads.
with ops.get_default_graph()._lock: # pylint: disable=protected-access
existing_ops = ops.get_default_graph().get_operations()
with ops.control_dependencies([lock]):
r = fn()
# TODO(ebrevdo): If creating critical sections in a python loop,
# this makes graph creation time quadratic. Revisit if this
# becomes a problem.
created_ops = (set(ops.get_default_graph().get_operations())
.difference(existing_ops))
else:
with ops.control_dependencies([lock]):
r = fn()
# TODO(ebrevdo): If creating critical sections in a python loop, this
# makes graph creation time quadratic. Revisit if this
# becomes a problem.
created_ops = (set(ops.get_default_graph().get_operations())
.difference(existing_ops))
else:
with ops.control_dependencies([lock]):
r = fn()
if not context.executing_eagerly():
self._add_control_dependencies_to_lock(created_ops, lock.op)
@ -255,9 +300,9 @@ class CriticalSection(object):
# the execute(), themselves attempt to access the
# CriticalSection. This will cause a deadlock.
if any(self._is_self_handle(x) for x in captured_resources):
raise ValueError("The function fn attempts to directly access the "
"CriticalSection in which it would be running. "
"This is illegal and would cause deadlocks.")
raise ValueError(
"Attempting to lock a CriticalSection in which we are "
"already running. This is illegal and may cause deadlocks.")
self._check_multiple_access_to_resources(
captured_resources, exclusive_resource_access)
@ -374,82 +419,3 @@ class CriticalSection(object):
"of this resource. Did you mean to call execute with keyword "
"argument exclusive_resource_access=False?" %
(list(resource_intersection), self._handle, sg, sg.handle))
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def to_proto(self, export_scope=None):
# """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer.
# Args:
# export_scope: Optional `string`. Name scope to remove.
# Returns:
# A `CriticalSectionDef` protocol buffer, or `None` if the
# `CriticalSection` is not in the specified name scope.
# """
# if export_scope is None or self.handle.name.startswith(export_scope):
# cs_def = critical_section_pb2.CriticalSectionDef()
# cs_def.critical_section_name = ops.strip_name_scope(
# self._handle.name, export_scope)
# return cs_def
# else:
# return None
# @staticmethod
# def from_proto(critical_section_def, import_scope=None):
# return CriticalSection(
# critical_section_def=critical_section_def, import_scope=import_scope)
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def _execution_to_proto_fn(execution_signature, export_scope=None):
# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# Args:
# execution_signature: Instance of `_ExecutionSignature`.
# export_scope: The export scope, if any.
# Returns:
# An instance of `CriticalSectionExecutionDef`.
# """
# if (export_scope is None
# or execution_signature.op.name.startswith(export_scope)):
# op_def = critical_section_pb2.CriticalSectionExecutionDef()
# op_def.execute_in_critical_section_name = ops.strip_name_scope(
# execution_signature.op.name, export_scope)
# op_def.exclusive_resource_access = (
# execution_signature.exclusive_resource_access)
# return op_def
# else:
# return None
# def _execution_from_proto_fn(op_def, import_scope=None):
# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# assert isinstance(
# op_def, critical_section_pb2.CriticalSectionExecutionDef)
# # Create from op_def.
# g = ops.get_default_graph()
# execution_op = g.as_graph_element(
# ops.prepend_name_scope(
# op_def.execute_in_critical_section_name,
# import_scope=import_scope))
# return _ExecutionSignature(
# op=execution_op,
# exclusive_resource_access=op_def.exclusive_resource_access)
# ops.register_proto_function(
# CRITICAL_SECTIONS,
# proto_type=critical_section_pb2.CriticalSectionDef,
# to_proto=CriticalSection.to_proto,
# from_proto=CriticalSection.from_proto)
# ops.register_proto_function(
# CRITICAL_SECTION_EXECUTIONS,
# proto_type=critical_section_pb2.CriticalSectionExecutionDef,
# to_proto=_execution_to_proto_fn,
# from_proto=_execution_from_proto_fn)