Make get_concrete_function threadsafe even when the function has arguments.

PiperOrigin-RevId: 274000736
This commit is contained in:
A. Unique TensorFlower 2019-10-10 11:23:59 -07:00 committed by TensorFlower Gardener
parent 08f41c6216
commit e258e9e31c
2 changed files with 83 additions and 65 deletions

View File

@ -2264,6 +2264,8 @@ class Function(object):
(defining and calling) is thread-safe, but if users call other methods or
invoke the base `python_function` themselves, external synchronization is
necessary.
In addition, Function is not reentrant, so recursive functions need to call
the wrapped function, not the wrapper.
"""
def __init__(self,
@ -2322,7 +2324,8 @@ class Function(object):
def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
with self._lock:
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
@property
@ -2348,7 +2351,8 @@ class Function(object):
"""Returns a concrete function which cleans up its graph function."""
if self.input_signature:
args, kwargs = None, None
graph_function, _, _ = self._maybe_define_function(args, kwargs)
with self._lock:
graph_function, _, _ = self._maybe_define_function(args, kwargs)
return graph_function
def _get_concrete_function_internal(self, *args, **kwargs):
@ -2391,33 +2395,34 @@ class Function(object):
"inputs (%s), input_signature (%s)" %
(str(args), str(self.input_signature)))
args, kwargs = None, None
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
if self.input_signature:
args = self.input_signature
kwargs = {}
seen_names = set()
captured = object_identity.ObjectIdentitySet(
graph_function.graph.internal_captures)
# pylint: disable=protected-access
graph_function._arg_keywords = []
prefix_counts = {}
# pylint: enable=protected-access
num_positional = 0
for arg in graph_function.graph.inputs:
if arg in captured:
break
num_positional += 1
user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
proposal = user_arg_name
while proposal in seen_names:
index = prefix_counts.get(user_arg_name, 1)
proposal = "{}_{}".format(user_arg_name, index)
prefix_counts[user_arg_name] = index + 1
seen_names.add(proposal)
graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access
# Anything can be a positional argument, in the same order as .inputs
graph_function._num_positional_args = num_positional # pylint: disable=protected-access
return graph_function
with self._lock:
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
if self.input_signature:
args = self.input_signature
kwargs = {}
seen_names = set()
captured = object_identity.ObjectIdentitySet(
graph_function.graph.internal_captures)
# pylint: disable=protected-access
graph_function._arg_keywords = []
prefix_counts = {}
# pylint: enable=protected-access
num_positional = 0
for arg in graph_function.graph.inputs:
if arg in captured:
break
num_positional += 1
user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
proposal = user_arg_name
while proposal in seen_names:
index = prefix_counts.get(user_arg_name, 1)
proposal = "{}_{}".format(user_arg_name, index)
prefix_counts[user_arg_name] = index + 1
seen_names.add(proposal)
graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access
# Anything can be a positional argument, in the same order as .inputs
graph_function._num_positional_args = num_positional # pylint: disable=protected-access
return graph_function
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
@ -2595,6 +2600,8 @@ class Function(object):
`args` and `kwargs` can be None if this `Function` was created with an
`input_signature`.
Caller must hold self._lock.
Args:
args: The varargs for the Python function.
kwargs: The keyword args for the Python function.
@ -2622,43 +2629,42 @@ class Function(object):
"Arguments supplied to `defun`-generated functions must be"
" hashable. Original error: %s" % e)
with self._lock:
graph_function = self._function_cache.primary.get(cache_key, None)
if graph_function is not None:
return graph_function, args, kwargs
logging.vlog(1,
"Creating new FuncGraph for Python function %r (key: %r)",
self._python_function, cache_key)
logging.vlog(2,
"Python function signature [args: %s] [kwargs: %s]",
args,
kwargs)
# pylint: disable=protected-access
call_context_key = cache_key._replace(input_signature=None)
# pylint: disable=protected-access
ag_status = (
ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED)
with ag_ctx.ControlStatusCtx(
status=ag_status, options=self._autograph_options):
# Build a function with shape relaxation retracing if:
# 1. shape relaxation is explicitly enabled
# and 2. there's no provided input signature
# and 3. there's been a cache miss for this calling context
if (self._experimental_relax_shapes
and self.input_signature is None
and call_context_key in self._function_cache.missed):
return self._define_function_with_shape_relaxation(args, kwargs)
self._function_cache.missed.add(call_context_key)
graph_function = self._function_cache.primary.get(cache_key, None)
if graph_function is not None:
return graph_function, args, kwargs
logging.vlog(1,
"Creating new FuncGraph for Python function %r (key: %r)",
self._python_function, cache_key)
logging.vlog(2,
"Python function signature [args: %s] [kwargs: %s]",
args,
kwargs)
# pylint: disable=protected-access
call_context_key = cache_key._replace(input_signature=None)
# pylint: disable=protected-access
ag_status = (
ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED)
with ag_ctx.ControlStatusCtx(
status=ag_status, options=self._autograph_options):
# Build a function with shape relaxation retracing if:
# 1. shape relaxation is explicitly enabled
# and 2. there's no provided input signature
# and 3. there's been a cache miss for this calling context
if (self._experimental_relax_shapes
and self.input_signature is None
and call_context_key in self._function_cache.missed):
return self._define_function_with_shape_relaxation(args, kwargs)
self._function_cache.missed.add(call_context_key)
graph_function = self._function_cache.primary.get(cache_key, None)
if graph_function is None:
graph_function = self._create_graph_function(args, kwargs)
self._function_cache.primary[cache_key] = graph_function
return graph_function, args, kwargs
if graph_function is None:
graph_function = self._create_graph_function(args, kwargs)
self._function_cache.primary[cache_key] = graph_function
return graph_function, args, kwargs
def register(func, *args, **kwargs):

View File

@ -525,6 +525,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertLen(set(concrete_functions), 1)
def testGetConcreteFunctionThreadSafetyWithArgs(self):
@def_function.function
def add_100(*args):
return math_ops.add_n(args)
p = multiprocessing.pool.ThreadPool(2)
args = (constant_op.constant(1.),) * 100
f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
# I see about len(args) + max(0, len(args) - 3) arguments expected.
f1(*args)
del f2
def testInputSpecGraphFunction(self):
matmul = def_function.function(math_ops.matmul)