Make get_concrete_function threadsafe even when the function has arguments.
PiperOrigin-RevId: 274000736
This commit is contained in:
parent
08f41c6216
commit
e258e9e31c
@ -2264,6 +2264,8 @@ class Function(object):
|
||||
(defining and calling) is thread-safe, but if users call other methods or
|
||||
invoke the base `python_function` themselves, external synchronization is
|
||||
necessary.
|
||||
In addition, Function is not reentrant, so recursive functions need to call
|
||||
the wrapped function, not the wrapper.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -2322,7 +2324,8 @@ class Function(object):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls a graph function specialized to the inputs."""
|
||||
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
||||
with self._lock:
|
||||
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
||||
return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
@ -2348,7 +2351,8 @@ class Function(object):
|
||||
"""Returns a concrete function which cleans up its graph function."""
|
||||
if self.input_signature:
|
||||
args, kwargs = None, None
|
||||
graph_function, _, _ = self._maybe_define_function(args, kwargs)
|
||||
with self._lock:
|
||||
graph_function, _, _ = self._maybe_define_function(args, kwargs)
|
||||
return graph_function
|
||||
|
||||
def _get_concrete_function_internal(self, *args, **kwargs):
|
||||
@ -2391,33 +2395,34 @@ class Function(object):
|
||||
"inputs (%s), input_signature (%s)" %
|
||||
(str(args), str(self.input_signature)))
|
||||
args, kwargs = None, None
|
||||
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
||||
if self.input_signature:
|
||||
args = self.input_signature
|
||||
kwargs = {}
|
||||
seen_names = set()
|
||||
captured = object_identity.ObjectIdentitySet(
|
||||
graph_function.graph.internal_captures)
|
||||
# pylint: disable=protected-access
|
||||
graph_function._arg_keywords = []
|
||||
prefix_counts = {}
|
||||
# pylint: enable=protected-access
|
||||
num_positional = 0
|
||||
for arg in graph_function.graph.inputs:
|
||||
if arg in captured:
|
||||
break
|
||||
num_positional += 1
|
||||
user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
|
||||
proposal = user_arg_name
|
||||
while proposal in seen_names:
|
||||
index = prefix_counts.get(user_arg_name, 1)
|
||||
proposal = "{}_{}".format(user_arg_name, index)
|
||||
prefix_counts[user_arg_name] = index + 1
|
||||
seen_names.add(proposal)
|
||||
graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access
|
||||
# Anything can be a positional argument, in the same order as .inputs
|
||||
graph_function._num_positional_args = num_positional # pylint: disable=protected-access
|
||||
return graph_function
|
||||
with self._lock:
|
||||
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
||||
if self.input_signature:
|
||||
args = self.input_signature
|
||||
kwargs = {}
|
||||
seen_names = set()
|
||||
captured = object_identity.ObjectIdentitySet(
|
||||
graph_function.graph.internal_captures)
|
||||
# pylint: disable=protected-access
|
||||
graph_function._arg_keywords = []
|
||||
prefix_counts = {}
|
||||
# pylint: enable=protected-access
|
||||
num_positional = 0
|
||||
for arg in graph_function.graph.inputs:
|
||||
if arg in captured:
|
||||
break
|
||||
num_positional += 1
|
||||
user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
|
||||
proposal = user_arg_name
|
||||
while proposal in seen_names:
|
||||
index = prefix_counts.get(user_arg_name, 1)
|
||||
proposal = "{}_{}".format(user_arg_name, index)
|
||||
prefix_counts[user_arg_name] = index + 1
|
||||
seen_names.add(proposal)
|
||||
graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access
|
||||
# Anything can be a positional argument, in the same order as .inputs
|
||||
graph_function._num_positional_args = num_positional # pylint: disable=protected-access
|
||||
return graph_function
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""Makes it possible to defun instance methods."""
|
||||
@ -2595,6 +2600,8 @@ class Function(object):
|
||||
`args` and `kwargs` can be None if this `Function` was created with an
|
||||
`input_signature`.
|
||||
|
||||
Caller must hold self._lock.
|
||||
|
||||
Args:
|
||||
args: The varargs for the Python function.
|
||||
kwargs: The keyword args for the Python function.
|
||||
@ -2622,43 +2629,42 @@ class Function(object):
|
||||
"Arguments supplied to `defun`-generated functions must be"
|
||||
" hashable. Original error: %s" % e)
|
||||
|
||||
with self._lock:
|
||||
graph_function = self._function_cache.primary.get(cache_key, None)
|
||||
if graph_function is not None:
|
||||
return graph_function, args, kwargs
|
||||
|
||||
logging.vlog(1,
|
||||
"Creating new FuncGraph for Python function %r (key: %r)",
|
||||
self._python_function, cache_key)
|
||||
logging.vlog(2,
|
||||
"Python function signature [args: %s] [kwargs: %s]",
|
||||
args,
|
||||
kwargs)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
call_context_key = cache_key._replace(input_signature=None)
|
||||
# pylint: disable=protected-access
|
||||
|
||||
ag_status = (
|
||||
ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED)
|
||||
with ag_ctx.ControlStatusCtx(
|
||||
status=ag_status, options=self._autograph_options):
|
||||
|
||||
# Build a function with shape relaxation retracing if:
|
||||
# 1. shape relaxation is explicitly enabled
|
||||
# and 2. there's no provided input signature
|
||||
# and 3. there's been a cache miss for this calling context
|
||||
if (self._experimental_relax_shapes
|
||||
and self.input_signature is None
|
||||
and call_context_key in self._function_cache.missed):
|
||||
return self._define_function_with_shape_relaxation(args, kwargs)
|
||||
|
||||
self._function_cache.missed.add(call_context_key)
|
||||
graph_function = self._function_cache.primary.get(cache_key, None)
|
||||
if graph_function is not None:
|
||||
return graph_function, args, kwargs
|
||||
|
||||
logging.vlog(1,
|
||||
"Creating new FuncGraph for Python function %r (key: %r)",
|
||||
self._python_function, cache_key)
|
||||
logging.vlog(2,
|
||||
"Python function signature [args: %s] [kwargs: %s]",
|
||||
args,
|
||||
kwargs)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
call_context_key = cache_key._replace(input_signature=None)
|
||||
# pylint: disable=protected-access
|
||||
|
||||
ag_status = (
|
||||
ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED)
|
||||
with ag_ctx.ControlStatusCtx(
|
||||
status=ag_status, options=self._autograph_options):
|
||||
|
||||
# Build a function with shape relaxation retracing if:
|
||||
# 1. shape relaxation is explicitly enabled
|
||||
# and 2. there's no provided input signature
|
||||
# and 3. there's been a cache miss for this calling context
|
||||
if (self._experimental_relax_shapes
|
||||
and self.input_signature is None
|
||||
and call_context_key in self._function_cache.missed):
|
||||
return self._define_function_with_shape_relaxation(args, kwargs)
|
||||
|
||||
self._function_cache.missed.add(call_context_key)
|
||||
graph_function = self._function_cache.primary.get(cache_key, None)
|
||||
if graph_function is None:
|
||||
graph_function = self._create_graph_function(args, kwargs)
|
||||
self._function_cache.primary[cache_key] = graph_function
|
||||
return graph_function, args, kwargs
|
||||
if graph_function is None:
|
||||
graph_function = self._create_graph_function(args, kwargs)
|
||||
self._function_cache.primary[cache_key] = graph_function
|
||||
return graph_function, args, kwargs
|
||||
|
||||
|
||||
def register(func, *args, **kwargs):
|
||||
|
@ -525,6 +525,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertLen(set(concrete_functions), 1)
|
||||
|
||||
def testGetConcreteFunctionThreadSafetyWithArgs(self):
|
||||
@def_function.function
|
||||
def add_100(*args):
|
||||
return math_ops.add_n(args)
|
||||
|
||||
p = multiprocessing.pool.ThreadPool(2)
|
||||
args = (constant_op.constant(1.),) * 100
|
||||
f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
|
||||
# I see about len(args) + max(0, len(args) - 3) arguments expected.
|
||||
f1(*args)
|
||||
del f2
|
||||
|
||||
def testInputSpecGraphFunction(self):
|
||||
matmul = def_function.function(math_ops.matmul)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user