Guard ops modification and Session.run with a group lock. This lock allows multiple ops modifications to happen at the same time, but no Session.run can happen until the modifications are done, and vice versa: no ops modification can happen while any Session.run is in progress.

PiperOrigin-RevId: 202028326
This commit is contained in:
Priya Gupta 2018-06-25 15:17:56 -07:00 committed by TensorFlower Gardener
parent 79d11c035c
commit 55b3cac99d
8 changed files with 98 additions and 45 deletions

View File

@ -3925,7 +3925,7 @@ tf_cuda_library(
tf_py_test(
name = "session_test",
size = "small",
size = "medium",
srcs = ["client/session_test.py"],
additional_deps = [
":array_ops",

View File

@ -1291,7 +1291,7 @@ class BaseSession(SessionInterface):
raise type(e)(node_def, op, message)
def _extend_graph(self):
with self._graph._lock: # pylint: disable=protected-access
with self._graph._session_run_lock(): # pylint: disable=protected-access
tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors.

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
import random
import os
import sys
import threading
@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase):
for t in threads:
t.join()
def testParallelRunAndBuild(self):
# Graph-construction helper for the parallel run/build tests in this class:
# it is called directly and is also used as a thread target. It adds ops to
# whatever graph `ops.get_default_graph()` returns.
# NOTE(review): leading indentation is not visible in this view, so the exact
# nesting of the `with` / `if` blocks below should be confirmed against the
# real file.
@staticmethod
def _build_graph():
# Pause for a random interval below 0.1s before building (presumably to
# stagger concurrent callers -- confirm intent).
time.sleep(random.random() * 0.1)
# Do some graph construction. Try to exercise non-trivial paths.
graph = ops.get_default_graph()
# Holds the GraphDef snapshot once it has been taken; while it is None the
# snapshot branch below runs, afterwards the import branch runs instead.
gdef = None
for _ in range(10):
x = array_ops.placeholder(dtype=dtypes.float32)
# Second placeholder, created under a colocation scope tied to `x`.
with ops.colocate_with(x):
y = array_ops.placeholder(dtype=dtypes.float32)
# While loop built under an explicit '/cpu:0' device scope.
with ops.device('/cpu:0'):
z = control_flow_ops.while_loop(
lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
# Gradients of the loop output w.r.t. both placeholders, built under an attr
# scope that supplies a boolean `_a` attribute.
with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
gradients_impl.gradients(z, [x, y])
# First time through: snapshot the graph. Later: re-import that snapshot
# under the name 'import'.
if gdef is None:
gdef = graph.as_graph_def()
else:
importer.import_graph_def(gdef, name='import')
def testParallelRunAndSingleBuild(self):
with session.Session() as sess:
c = constant_op.constant(5.0)
stop = threading.Event()
def run_loop():
while not stop.is_set():
time.sleep(random.random() * 0.1)
self.assertEqual(sess.run(c), 5.0)
threads = [self.checkedThread(target=run_loop) for _ in range(100)]
threads = [self.checkedThread(target=run_loop) for _ in range(10)]
for t in threads:
t.start()
# Do some graph construction. Try to exercise non-trivial paths.
graph = ops.get_default_graph()
gdef = None
for _ in range(10):
x = array_ops.placeholder(dtype=dtypes.float32)
with ops.colocate_with(x):
y = array_ops.placeholder(dtype=dtypes.float32)
with ops.device('/cpu:0'):
z = control_flow_ops.while_loop(
lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
gradients_impl.gradients(z, [x, y])
if gdef is None:
gdef = graph.as_graph_def()
else:
importer.import_graph_def(gdef, name='import')
SessionTest._build_graph()
stop.set()
for t in threads:
t.join()
def testParallelRunAndParallelBuild(self):
  """Runs Session.run in 10 threads while 10 other threads build the graph.

  The runner threads repeatedly evaluate a constant until `stop` is set. The
  builder threads each call `SessionTest._build_graph` once. The runners are
  always signalled and joined, even if a builder thread fails, so a builder
  failure is reported as a test failure instead of leaving the runner threads
  spinning forever.
  """
  with session.Session() as sess:
    c = constant_op.constant(5.0)
    stop = threading.Event()

    def run_loop():
      # Keep issuing Session.run calls, with a small random pause between
      # them, until the main thread asks us to stop.
      while not stop.is_set():
        time.sleep(random.random() * 0.1)
        self.assertEqual(sess.run(c), 5.0)

    run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
    for t in run_threads:
      t.start()

    try:
      build_threads = [self.checkedThread(target=SessionTest._build_graph)
                       for _ in range(10)]
      for t in build_threads:
        t.start()
      for t in build_threads:
        t.join()
    finally:
      # Let the run_threads run until the build threads are finished, then
      # stop them. Doing this in `finally` guarantees the runners terminate
      # even when starting or joining a build thread raises.
      stop.set()
      for t in run_threads:
        t.join()
def testRunFeedDict(self):
with session.Session() as s:
x = array_ops.zeros([2])

View File

@ -407,11 +407,11 @@ def import_graph_def(graph_def,
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
return_elements)
# _ProcessNewOps mutates the new operations. _lock ensures a Session.run
# call cannot occur between creating the TF_Operations in the
# _ProcessNewOps mutates the new operations. _mutation_lock ensures a
# Session.run call cannot occur between creating the TF_Operations in the
# TF_GraphImportGraphDefWithResults call and mutating them in
# _ProcessNewOps.
with graph._lock: # pylint: disable=protected-access
with graph._mutation_lock(): # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
try:
results = c_api.TF_GraphImportGraphDefWithResults(

View File

@ -55,6 +55,7 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
@ -2599,6 +2600,10 @@ def _name_from_scope_name(name):
return name[:-1] if (name and name[-1] == "/") else name
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1
@tf_export("Graph")
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@ -2648,20 +2653,21 @@ class Graph(object):
def __init__(self):
"""Creates a new, empty Graph."""
# Protects core state that can be returned via public accessors, as well as
# synchronizes Session.run calls with methods that create and mutate ops
# (e.g. Graph.create_op()). This synchronization is necessary because it's
# illegal to modify an operation after it's been run. Thread-safety is
# provided on a best-effort basis to support buggy programs, and is not
# guaranteed by the public `tf.Graph` API.
#
# The lock must be reentrant because create_op can be called recursively due
# to control flow. Without a reentrant lock, many methods would also need a
# "locked" version or parameter (including generated code).
# Protects core state that can be returned via public accessors.
# Thread-safety is provided on a best-effort basis to support buggy
# programs, and is not guaranteed by the public `tf.Graph` API.
#
# NOTE(mrry): This does not protect the various stacks. A warning will
# be reported if these are used from multiple threads.
self._lock = threading.RLock()
# The group lock synchronizes Session.run calls with methods that create
# and mutate ops (e.g. Graph.create_op()). This synchronization is
# necessary because it's illegal to modify an operation after it's been run.
# The group lock allows any number of threads to mutate ops at the same time
# but if any modification is going on, all Session.run calls have to wait.
# Similarly, if one or more Session.run calls are going on, all op
# mutations have to wait until all Session.run calls have finished.
self._group_lock = lock_util.GroupLock(num_groups=2)
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
@ -3192,9 +3198,9 @@ class Graph(object):
input_ops = set([t.op for t in inputs])
control_inputs = self._control_dependencies_for_inputs(input_ops)
# _create_op_helper mutates the new Operation. _lock ensures a Session.run
# call cannot occur between creating and mutating the op.
with self._lock:
# _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
# Session.run call cannot occur between creating and mutating the op.
with self._mutation_lock():
ret = Operation(
node_def,
self,
@ -4719,6 +4725,20 @@ class Graph(object):
else:
self._graph_control_dependencies_stack = control_dependencies
def _mutation_lock(self):
  """Returns a lock guarding code that creates and mutates ops.

  The lock is the `_MUTATION_LOCK_GROUP` member of `self._group_lock`; the
  comment on `self._group_lock` describes how the groups interact.
  """
  group_lock = self._group_lock
  return group_lock.group(_MUTATION_LOCK_GROUP)
def _session_run_lock(self):
  """Returns a lock guarding code that performs Session.run.

  The lock is the `_SESSION_RUN_LOCK_GROUP` member of `self._group_lock`; the
  comment on `self._group_lock` describes how the groups interact.
  """
  group_lock = self._group_lock
  return group_lock.group(_SESSION_RUN_LOCK_GROUP)
# TODO(agarwal): currently device directives in an outer eager scope will not
# apply to inner graph mode code. Fix that.

View File

@ -2943,9 +2943,10 @@ class WhileContext(ControlFlowContext):
loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
try:
self.Enter()
# _BuildLoop calls _update_input in several places. _lock ensures a
# Session.run call cannot occur between creating and mutating new ops.
with ops.get_default_graph()._lock: # pylint: disable=protected-access
# _BuildLoop calls _update_input in several places. _mutation_lock()
# ensures a Session.run call cannot occur between creating and mutating
# new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
original_body_result, exit_vars = self._BuildLoop(
pred, body, original_loop_vars, loop_vars, shape_invariants)
finally:

View File

@ -534,10 +534,10 @@ def gradients(ys,
RuntimeError: if called in Eager mode.
"""
# Creating the gradient graph for control flow mutates Operations. _lock
# ensures a Session.run call cannot occur between creating and mutating new
# ops.
with ops.get_default_graph()._lock: # pylint: disable=protected-access
# Creating the gradient graph for control flow mutates Operations.
# _mutation_lock ensures a Session.run call cannot occur between creating and
# mutating new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients)

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import random
import threading
import time
from absl.testing import parameterized
@ -48,7 +47,7 @@ class GroupLockTest(test.TestCase, parameterized.TestCase):
finished.add(thread_id)
threads = [
threading.Thread(target=thread_fn, args=(i,))
self.checkedThread(target=thread_fn, args=(i,))
for i in range(num_threads)
]