Expose CriticalSection in core as tf.CriticalSection.

PiperOrigin-RevId: 233146763
This commit is contained in:
Eugene Brevdo 2019-02-08 16:24:59 -08:00 committed by TensorFlower Gardener
parent ce5a8d8ff7
commit c7621b86bc
14 changed files with 164 additions and 65 deletions

View File

@ -32,7 +32,6 @@ tf_custom_op_py_library(
"python/ops/arg_scope.py",
"python/ops/audio_ops.py",
"python/ops/checkpoint_ops.py",
"python/ops/critical_section_ops.py",
"python/ops/ops.py",
"python/ops/prettyprint_ops.py",
"python/ops/script_ops.py",
@ -172,27 +171,6 @@ py_test(
],
)
cuda_py_test(
name = "critical_section_test",
size = "medium",
srcs = ["python/ops/critical_section_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
":framework_py",
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:platform_test",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
],
)
py_test(
name = "ops_test",
size = "small",

View File

@ -94,8 +94,6 @@
@@smart_constant_value
@@smart_case
@@CriticalSection
@@BoundedTensorSpec
@@TensorSpec

View File

@ -22,7 +22,6 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.framework.python.ops.arg_scope import *
from tensorflow.contrib.framework.python.ops.checkpoint_ops import *
from tensorflow.contrib.framework.python.ops.critical_section_ops import *
from tensorflow.contrib.framework.python.ops.ops import *
from tensorflow.contrib.framework.python.ops.prettyprint_ops import *
from tensorflow.contrib.framework.python.ops.script_ops import *

View File

@ -2737,6 +2737,22 @@ py_library(
],
)
py_library(
name = "critical_section_ops",
srcs = ["ops/critical_section_ops.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
":control_flow_ops",
":dtypes",
":framework_ops",
":resource_variable_ops_gen",
":tensor_array_ops",
":util",
"//tensorflow/python/eager:context",
],
)
py_library(
name = "list_ops",
srcs = ["ops/list_ops.py"],
@ -3148,6 +3164,7 @@ py_library(
":clip_ops",
":confusion_matrix",
":control_flow_ops",
":critical_section_ops",
":cudnn_rnn_grad",
":data_flow_grad",
":data_flow_ops",

View File

@ -3592,3 +3592,24 @@ cuda_py_test(
grpc_enabled = True,
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "critical_section_test",
size = "medium",
srcs = ["critical_section_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:platform_test",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:critical_section_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
],
)

View File

@ -18,14 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import critical_section_ops
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import critical_section_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@ -48,7 +49,7 @@ class CriticalSectionTest(test.TestCase):
return array_ops.identity(c)
num_concurrent = 100
r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)]
r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
self.evaluate(v.initializer)
r_value = self.evaluate(r)
self.assertAllClose([2.0 * i for i in range(num_concurrent)],
@ -74,7 +75,7 @@ class CriticalSectionTest(test.TestCase):
array_ops.identity(inner_cond), true_fn, lambda: c)
def execute():
return cs.execute(fn, 1.0, 2.0)
return cs.execute(lambda: fn(1.0, 2.0))
r = [
control_flow_ops.cond(array_ops.identity(outer_cond),
@ -92,6 +93,7 @@ class CriticalSectionTest(test.TestCase):
else:
self.assertAllClose([0] * num_concurrent, r_value)
@test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
def testCriticalSectionInParallelDoesntDeadlockOnError(self):
# No eager mode execution of this test because eager does not
# run fn() in parallel, which is where the deadlock could
@ -103,12 +105,23 @@ class CriticalSectionTest(test.TestCase):
error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
with ops.control_dependencies([error]):
return v.read_value()
num_concurrent = 2
r = [cs.execute(fn, i) for i in range(num_concurrent)]
@def_function.function(autograph=False)
def run_concurrently():
return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]
if not context.executing_eagerly():
run_concurrently = run_concurrently()
self.evaluate(v.initializer)
for _ in range(100):
with self.assertRaisesOpError("Error"):
self.evaluate(r)
if context.executing_eagerly():
run_concurrently()
else:
self.evaluate(run_concurrently)
@test_util.run_in_graph_and_eager_modes
def testCreateCriticalSectionFnReturnsOp(self):
@ -123,17 +136,20 @@ class CriticalSectionTest(test.TestCase):
return control_flow_ops.no_op()
num_concurrent = 100
r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)]
r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
for _ in range(num_concurrent)]
self.evaluate(v.initializer)
self.evaluate(r)
final_v = self.evaluate(v)
self.assertAllClose(2.0 * num_concurrent, final_v)
@test_util.run_v1_only("Collections don't exist in TF2")
def testCollection(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
self.assertIn(
cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute")
add = lambda x: x + 1
execute = cs.execute(lambda: add(1.0), name="my_execute")
execute_op = [
x for x in execute.graph.get_operations()
if "my_execute" in x.name and "MutexLock" in x.type
@ -143,18 +159,21 @@ class CriticalSectionTest(test.TestCase):
[signature.op for signature in
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testRecursiveCriticalSectionAccessIsIllegal(self):
# This does not work properly in eager mode. Eager users will
# just hit a deadlock if they do this. But at least it'll be easier
# to debug.
cs = critical_section_ops.CriticalSection()
add = lambda y: y + 1
def fn(x):
return cs.execute(lambda y: y + 1, x)
return cs.execute(lambda: add(x))
with self.assertRaisesRegexp(
ValueError,
r"attempts to directly access the CriticalSection in which it "
r"would be running"):
cs.execute(fn, 1.0)
cs.execute(lambda: fn(1.0))
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
# This one is subtle; and we're being overly cautious here. The
@ -174,24 +193,24 @@ class CriticalSectionTest(test.TestCase):
# operations are finished before anything runs within the critical section.
cs = critical_section_ops.CriticalSection(shared_name="cs")
fn = array_ops.identity
to_capture = cs.execute(fn, 1.0)
to_capture = cs.execute(lambda: fn(1.0))
fn_captures = lambda x: x + to_capture
to_capture_too = array_ops.identity(to_capture)
ex_0 = cs.execute(fn_captures, 1.0)
ex_0 = cs.execute(lambda: fn_captures(1.0))
with ops.control_dependencies([to_capture]):
# This is OK because to_capture will execute before this next call
ex_1 = cs.execute(fn_captures, 1.0)
ex_1 = cs.execute(lambda: fn_captures(1.0))
dependency = array_ops.identity(to_capture)
fn_captures_dependency = lambda x: x + dependency
ex_2 = cs.execute(fn_captures_dependency, 1.0)
ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))
with ops.control_dependencies([to_capture_too]):
ex_3 = cs.execute(fn_captures_dependency, 1.0)
ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))
# Ensure there's no actual deadlock on to_execute.
self.assertEquals(2.0, self.evaluate(ex_0))
@ -217,6 +236,8 @@ class CriticalSectionTest(test.TestCase):
body_implicit_capture,
[0, 0],
parallel_iterations=25)
# For consistency between eager and graph mode.
i_n = array_ops.identity(i_n)
logging.warn(
"\n==============\nRunning "
"'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
@ -242,6 +263,8 @@ class CriticalSectionTest(test.TestCase):
body_implicit_capture_protected,
[0, 0],
parallel_iterations=25)
# For consistency between eager and graph mode.
i_n = array_ops.identity(i_n)
logging.warn(
"\n==============\nRunning "
"'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
@ -258,13 +281,15 @@ class CriticalSectionTest(test.TestCase):
# This version is ok because j is an argument to fn and we can
# ensure there's a control dependency on j.
fn = lambda x: x + 1
return (i + 1, cs.execute(fn, j))
return (i + 1, cs.execute(lambda: fn(j)))
(i_n, j_n) = control_flow_ops.while_loop(
lambda i, _: i < 1000,
body_args_capture,
[0, 0],
parallel_iterations=25)
# For consistency between eager and graph mode.
i_n = array_ops.identity(i_n)
logging.warn(
"\n==============\nRunning "
"'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
@ -277,20 +302,23 @@ class CriticalSectionTest(test.TestCase):
"body_args_capture'\n"
"==============\n")
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
# This does not work properly in eager mode. Eager users will
# just hit a deadlock if they do this. But at least it'll be easier
# to debug.
cs = critical_section_ops.CriticalSection(shared_name="cs")
cs_same = critical_section_ops.CriticalSection(shared_name="cs")
add = lambda x: x + 1
def fn(x):
return cs_same.execute(lambda x: x+1, x)
return cs_same.execute(lambda: add(x))
with self.assertRaisesRegexp(
ValueError,
r"attempts to directly access the CriticalSection in which it "
r"would be running"):
cs.execute(fn, 1.0)
cs.execute(lambda: fn(1.0))
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testMultipleCSExecutionsRequestSameResource(self):
cs0 = critical_section_ops.CriticalSection()
cs1 = critical_section_ops.CriticalSection()
@ -328,8 +356,11 @@ class CriticalSectionTest(test.TestCase):
# Note, here v must be a resource variable (or something similar),
# otherwise it gets hoisted into the while_loop by the time we add
# control dependencies to the lock_op.
def body(i):
add_j = lambda j: v + j + 1
return cs.execute(lambda: add_j(i))
out = control_flow_ops.while_loop(
lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0])
lambda i: i < 10, body, [0])
self.evaluate(v.initializer)
self.assertEqual(10, self.evaluate(out))

View File

@ -31,6 +31,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
__all__ = ["CriticalSection"]
# Graph Keys
@ -66,6 +70,7 @@ def _get_colocation(op):
return None
@tf_export("CriticalSection")
class CriticalSection(object):
"""Critical section.
@ -179,37 +184,36 @@ class CriticalSection(object):
def name(self):
return self._handle.op.name
def execute(self, fn, *args, **kwargs):
"""Execute function `fn(*args, **kwargs)` inside the CriticalSection.
def execute(self, fn, exclusive_resource_access=True, name=None):
"""Execute function `fn()` inside the critical section.
`fn` should not accept any arguments. To pass extra arguments when
calling `fn` in the critical section, create a lambda:
```python
critical_section.execute(lambda: fn(*my_args, **my_kwargs))
```
Args:
fn: The function to execute. Must return at least one tensor.
*args: Additional positional arguments to `fn`.
**kwargs: Additional keyword arguments to `fn`.
Several keywords are reserved for `execute`. These are:
- name; The name to use when creating the execute operation.
- exclusive_resource_access; Whether the resources required by
`fn` should be exclusive to this `CriticalSection`. Default: `True`.
You may want to set this to `False` if you will be accessing a
resource in read-only mode in two different CriticalSections.
exclusive_resource_access: Whether the resources required by
`fn` should be exclusive to this `CriticalSection`. Default: `True`.
You may want to set this to `False` if you will be accessing a
resource in read-only mode in two different CriticalSections.
name: The name to use when creating the execute operation.
Returns:
The tensors returned from `fn(*args, **kwargs)`.
The tensors returned from `fn()`.
Raises:
ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
or lazy way that may cause a deadlock.
ValueError: If `exclusive_resource_access` is not provided (is `True`) and
ValueError: If `exclusive_resource_access == True` and
another `CriticalSection` has an execution requesting the same
resources as in `*args`, `**kwargs`, and any additionally captured
inputs in `fn`. Note, even if `exclusive_resource_access` is `True`,
if another execution in another `CriticalSection` was created without
`exclusive_resource_access=True`, a `ValueError` will be raised.
resources as `fn`. Note, even if `exclusive_resource_access` is
`True`, if another execution in another `CriticalSection` was created
without `exclusive_resource_access=True`, a `ValueError` will be raised.
"""
name = kwargs.pop("name", None)
exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
with ops.name_scope(name, "critical_section_execute", []):
# Ensure that mutex locking only happens *after* all args and
@ -222,7 +226,7 @@ class CriticalSection(object):
with ops.get_default_graph()._lock: # pylint: disable=protected-access
existing_ops = ops.get_default_graph().get_operations()
with ops.control_dependencies([lock]):
r = fn(*args, **kwargs)
r = fn()
# TODO(ebrevdo): If creating critical sections in a python loop, this
# makes graph creation time quadratic. Revisit if this
# becomes a problem.
@ -230,7 +234,7 @@ class CriticalSection(object):
.difference(existing_ops))
else:
with ops.control_dependencies([lock]):
r = fn(*args, **kwargs)
r = fn()
if not context.executing_eagerly():
self._add_control_dependencies_to_lock(created_ops, lock.op)

View File

@ -54,6 +54,7 @@ from tensorflow.python.ops.control_flow_ops import tuple # pylint: disable=rede
# pylint: enable=redefined-builtin
from tensorflow.python.eager import wrap_function
from tensorflow.python.ops.control_flow_ops import while_loop
from tensorflow.python.ops.critical_section_ops import *
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.functional_ops import *
from tensorflow.python.ops.gradients import *

View File

@ -0,0 +1,17 @@
path: "tensorflow.CriticalSection"
tf_class {
is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>"
is_instance: "<type \'object\'>"
member {
name: "name"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "execute"
argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
}

View File

@ -32,6 +32,10 @@ tf_module {
name: "ConfigProto"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
name: "CriticalSection"
mtype: "<type \'type\'>"
}
member {
name: "DType"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,17 @@
path: "tensorflow.CriticalSection"
tf_class {
is_instance: "<class \'tensorflow.python.ops.critical_section_ops.CriticalSection\'>"
is_instance: "<type \'object\'>"
member {
name: "name"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'name\', \'shared_name\', \'critical_section_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "execute"
argspec: "args=[\'self\', \'fn\', \'exclusive_resource_access\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "AggregationMethod"
mtype: "<type \'type\'>"
}
member {
name: "CriticalSection"
mtype: "<type \'type\'>"
}
member {
name: "DType"
mtype: "<type \'type\'>"

View File

@ -540,6 +540,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.data.experimental.unbatch",
"tf.contrib.data.unique":
"tf.data.experimental.unique",
"tf.contrib.framework.CriticalSection":
"tf.CriticalSection",
"tf.contrib.framework.is_tensor":
"tf.is_tensor",
"tf.contrib.framework.nest.assert_same_structure":

View File

@ -1157,6 +1157,12 @@ def _log_prob(self, x):
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
def test_CriticalSection_upgrade(self):
text = "tf.contrib.framework.CriticalSection(shared_name='blah')"
expected = "tf.CriticalSection(shared_name='blah')"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
def test_sample_distorted_bounding_box(self):
# pylint: disable=line-too-long
text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"