Guard ops modification and Session.run with a group lock. This lock allows multiple ops modifications to happen at the same time, but no Session.run can happen until the modifications are done. And vice-versa.
PiperOrigin-RevId: 202028326
This commit is contained in:
parent
79d11c035c
commit
55b3cac99d
tensorflow/python
@ -3925,7 +3925,7 @@ tf_cuda_library(
|
||||
|
||||
tf_py_test(
|
||||
name = "session_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["client/session_test.py"],
|
||||
additional_deps = [
|
||||
":array_ops",
|
||||
|
@ -1291,7 +1291,7 @@ class BaseSession(SessionInterface):
|
||||
raise type(e)(node_def, op, message)
|
||||
|
||||
def _extend_graph(self):
|
||||
with self._graph._lock: # pylint: disable=protected-access
|
||||
with self._graph._session_run_lock(): # pylint: disable=protected-access
|
||||
tf_session.ExtendSession(self._session)
|
||||
|
||||
# The threshold to run garbage collection to delete dead tensors.
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import random
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
def testParallelRunAndBuild(self):
|
||||
@staticmethod
|
||||
def _build_graph():
|
||||
time.sleep(random.random() * 0.1)
|
||||
# Do some graph construction. Try to exercise non-trivial paths.
|
||||
graph = ops.get_default_graph()
|
||||
gdef = None
|
||||
for _ in range(10):
|
||||
x = array_ops.placeholder(dtype=dtypes.float32)
|
||||
with ops.colocate_with(x):
|
||||
y = array_ops.placeholder(dtype=dtypes.float32)
|
||||
with ops.device('/cpu:0'):
|
||||
z = control_flow_ops.while_loop(
|
||||
lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
|
||||
with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
|
||||
gradients_impl.gradients(z, [x, y])
|
||||
if gdef is None:
|
||||
gdef = graph.as_graph_def()
|
||||
else:
|
||||
importer.import_graph_def(gdef, name='import')
|
||||
|
||||
def testParallelRunAndSingleBuild(self):
|
||||
with session.Session() as sess:
|
||||
c = constant_op.constant(5.0)
|
||||
stop = threading.Event()
|
||||
|
||||
def run_loop():
|
||||
while not stop.is_set():
|
||||
time.sleep(random.random() * 0.1)
|
||||
self.assertEqual(sess.run(c), 5.0)
|
||||
|
||||
threads = [self.checkedThread(target=run_loop) for _ in range(100)]
|
||||
threads = [self.checkedThread(target=run_loop) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# Do some graph construction. Try to exercise non-trivial paths.
|
||||
graph = ops.get_default_graph()
|
||||
gdef = None
|
||||
for _ in range(10):
|
||||
x = array_ops.placeholder(dtype=dtypes.float32)
|
||||
with ops.colocate_with(x):
|
||||
y = array_ops.placeholder(dtype=dtypes.float32)
|
||||
with ops.device('/cpu:0'):
|
||||
z = control_flow_ops.while_loop(
|
||||
lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
|
||||
with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
|
||||
gradients_impl.gradients(z, [x, y])
|
||||
if gdef is None:
|
||||
gdef = graph.as_graph_def()
|
||||
else:
|
||||
importer.import_graph_def(gdef, name='import')
|
||||
SessionTest._build_graph()
|
||||
|
||||
stop.set()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
def testParallelRunAndParallelBuild(self):
|
||||
with session.Session() as sess:
|
||||
c = constant_op.constant(5.0)
|
||||
stop = threading.Event()
|
||||
|
||||
def run_loop():
|
||||
while not stop.is_set():
|
||||
time.sleep(random.random() * 0.1)
|
||||
self.assertEqual(sess.run(c), 5.0)
|
||||
|
||||
run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
|
||||
for t in run_threads:
|
||||
t.start()
|
||||
|
||||
build_threads = [self.checkedThread(target=SessionTest._build_graph)
|
||||
for _ in range(10)]
|
||||
for t in build_threads:
|
||||
t.start()
|
||||
for t in build_threads:
|
||||
t.join()
|
||||
|
||||
# Let the run_threads run until the build threads are finished.
|
||||
stop.set()
|
||||
for t in run_threads:
|
||||
t.join()
|
||||
|
||||
def testRunFeedDict(self):
|
||||
with session.Session() as s:
|
||||
x = array_ops.zeros([2])
|
||||
|
@ -407,11 +407,11 @@ def import_graph_def(graph_def,
|
||||
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
|
||||
return_elements)
|
||||
|
||||
# _ProcessNewOps mutates the new operations. _lock ensures a Session.run
|
||||
# call cannot occur between creating the TF_Operations in the
|
||||
# _ProcessNewOps mutates the new operations. _mutation_lock ensures a
|
||||
# Session.run call cannot occur between creating the TF_Operations in the
|
||||
# TF_GraphImportGraphDefWithResults call and mutating the them in
|
||||
# _ProcessNewOps.
|
||||
with graph._lock: # pylint: disable=protected-access
|
||||
with graph._mutation_lock(): # pylint: disable=protected-access
|
||||
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
|
||||
try:
|
||||
results = c_api.TF_GraphImportGraphDefWithResults(
|
||||
|
@ -55,6 +55,7 @@ from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import decorator_utils
|
||||
from tensorflow.python.util import lock_util
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -2599,6 +2600,10 @@ def _name_from_scope_name(name):
|
||||
return name[:-1] if (name and name[-1] == "/") else name
|
||||
|
||||
|
||||
_MUTATION_LOCK_GROUP = 0
|
||||
_SESSION_RUN_LOCK_GROUP = 1
|
||||
|
||||
|
||||
@tf_export("Graph")
|
||||
class Graph(object):
|
||||
"""A TensorFlow computation, represented as a dataflow graph.
|
||||
@ -2648,20 +2653,21 @@ class Graph(object):
|
||||
|
||||
def __init__(self):
|
||||
"""Creates a new, empty Graph."""
|
||||
# Protects core state that can be returned via public accessors, as well as
|
||||
# synchronizes Session.run calls with methods that create and mutate ops
|
||||
# (e.g. Graph.create_op()). This synchronization is necessary because it's
|
||||
# illegal to modify an operation after it's been run. Thread-safety is
|
||||
# provided on a best-effort basis to support buggy programs, and is not
|
||||
# guaranteed by the public `tf.Graph` API.
|
||||
#
|
||||
# The lock must be reentrant because create_op can be called recursively due
|
||||
# to control flow. Without a reentrant lock, many methods would also need a
|
||||
# "locked" version or parameter (including generated code).
|
||||
# Protects core state that can be returned via public accessors.
|
||||
# Thread-safety is provided on a best-effort basis to support buggy
|
||||
# programs, and is not guaranteed by the public `tf.Graph` API.
|
||||
#
|
||||
# NOTE(mrry): This does not protect the various stacks. A warning will
|
||||
# be reported if these are used from multiple threads
|
||||
self._lock = threading.RLock()
|
||||
# The group lock synchronizes Session.run calls with methods that create
|
||||
# and mutate ops (e.g. Graph.create_op()). This synchronization is
|
||||
# necessary because it's illegal to modify an operation after it's been run.
|
||||
# The group lock allows any number of threads to mutate ops at the same time
|
||||
# but if any modification is going on, all Session.run calls have to wait.
|
||||
# Similarly, if one or more Session.run calls are going on, all mutate ops
|
||||
# have to wait until all Session.run calls have finished.
|
||||
self._group_lock = lock_util.GroupLock(num_groups=2)
|
||||
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
|
||||
self._next_id_counter = 0 # GUARDED_BY(self._lock)
|
||||
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
|
||||
@ -3192,9 +3198,9 @@ class Graph(object):
|
||||
|
||||
input_ops = set([t.op for t in inputs])
|
||||
control_inputs = self._control_dependencies_for_inputs(input_ops)
|
||||
# _create_op_helper mutates the new Operation. _lock ensures a Session.run
|
||||
# call cannot occur between creating and mutating the op.
|
||||
with self._lock:
|
||||
# _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
|
||||
# Session.run call cannot occur between creating and mutating the op.
|
||||
with self._mutation_lock():
|
||||
ret = Operation(
|
||||
node_def,
|
||||
self,
|
||||
@ -4719,6 +4725,20 @@ class Graph(object):
|
||||
else:
|
||||
self._graph_control_dependencies_stack = control_dependencies
|
||||
|
||||
def _mutation_lock(self):
|
||||
"""Returns a lock to guard code that creates & mutates ops.
|
||||
|
||||
See the comment for self._group_lock for more info.
|
||||
"""
|
||||
return self._group_lock.group(_MUTATION_LOCK_GROUP)
|
||||
|
||||
def _session_run_lock(self):
|
||||
"""Returns a lock to guard code for Session.run.
|
||||
|
||||
See the comment for self._group_lock for more info.
|
||||
"""
|
||||
return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
|
||||
|
||||
|
||||
# TODO(agarwal): currently device directives in an outer eager scope will not
|
||||
# apply to inner graph mode code. Fix that.
|
||||
|
@ -2943,9 +2943,10 @@ class WhileContext(ControlFlowContext):
|
||||
loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
|
||||
try:
|
||||
self.Enter()
|
||||
# _BuildLoop calls _update_input in several places. _lock ensures a
|
||||
# Session.run call cannot occur between creating and mutating new ops.
|
||||
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
||||
# _BuildLoop calls _update_input in several places. _mutation_lock()
|
||||
# ensures a Session.run call cannot occur between creating and mutating
|
||||
# new ops.
|
||||
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
|
||||
original_body_result, exit_vars = self._BuildLoop(
|
||||
pred, body, original_loop_vars, loop_vars, shape_invariants)
|
||||
finally:
|
||||
|
@ -534,10 +534,10 @@ def gradients(ys,
|
||||
RuntimeError: if called in Eager mode.
|
||||
|
||||
"""
|
||||
# Creating the gradient graph for control flow mutates Operations. _lock
|
||||
# ensures a Session.run call cannot occur between creating and mutating new
|
||||
# ops.
|
||||
with ops.get_default_graph()._lock: # pylint: disable=protected-access
|
||||
# Creating the gradient graph for control flow mutates Operations.
|
||||
# _mutation_lock ensures a Session.run call cannot occur between creating and
|
||||
# mutating new ops.
|
||||
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
|
||||
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
gate_gradients, aggregation_method, stop_gradients)
|
||||
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
@ -48,7 +47,7 @@ class GroupLockTest(test.TestCase, parameterized.TestCase):
|
||||
finished.add(thread_id)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=thread_fn, args=(i,))
|
||||
self.checkedThread(target=thread_fn, args=(i,))
|
||||
for i in range(num_threads)
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user