diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 4c37f261fa5..2c6dbb5d849 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2264,6 +2264,8 @@ class Function(object): (defining and calling) is thread-safe, but if users call other methods or invoke the base `python_function` themselves, external synchronization is necessary. + In addition, Function is not reentrant, so recursive functions need to call + the wrapped function, not the wrapper. """ def __init__(self, @@ -2322,7 +2324,8 @@ class Function(object): def __call__(self, *args, **kwargs): """Calls a graph function specialized to the inputs.""" - graph_function, args, kwargs = self._maybe_define_function(args, kwargs) + with self._lock: + graph_function, args, kwargs = self._maybe_define_function(args, kwargs) return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access @property @@ -2348,7 +2351,8 @@ class Function(object): """Returns a concrete function which cleans up its graph function.""" if self.input_signature: args, kwargs = None, None - graph_function, _, _ = self._maybe_define_function(args, kwargs) + with self._lock: + graph_function, _, _ = self._maybe_define_function(args, kwargs) return graph_function def _get_concrete_function_internal(self, *args, **kwargs): @@ -2391,33 +2395,34 @@ class Function(object): "inputs (%s), input_signature (%s)" % (str(args), str(self.input_signature))) args, kwargs = None, None - graph_function, args, kwargs = self._maybe_define_function(args, kwargs) - if self.input_signature: - args = self.input_signature - kwargs = {} - seen_names = set() - captured = object_identity.ObjectIdentitySet( - graph_function.graph.internal_captures) - # pylint: disable=protected-access - graph_function._arg_keywords = [] - prefix_counts = {} - # pylint: enable=protected-access - num_positional = 0 - for arg in graph_function.graph.inputs: - if arg in captured: - break - num_positional += 1 - user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name")) - proposal = user_arg_name - while proposal in seen_names: - index = prefix_counts.get(user_arg_name, 1) - proposal = "{}_{}".format(user_arg_name, index) - prefix_counts[user_arg_name] = index + 1 - seen_names.add(proposal) - graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access - # Anything can be a positional argument, in the same order as .inputs - graph_function._num_positional_args = num_positional # pylint: disable=protected-access - return graph_function + with self._lock: + graph_function, args, kwargs = self._maybe_define_function(args, kwargs) + if self.input_signature: + args = self.input_signature + kwargs = {} + seen_names = set() + captured = object_identity.ObjectIdentitySet( + graph_function.graph.internal_captures) + # pylint: disable=protected-access + graph_function._arg_keywords = [] + prefix_counts = {} + # pylint: enable=protected-access + num_positional = 0 + for arg in graph_function.graph.inputs: + if arg in captured: + break + num_positional += 1 + user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name")) + proposal = user_arg_name + while proposal in seen_names: + index = prefix_counts.get(user_arg_name, 1) + proposal = "{}_{}".format(user_arg_name, index) + prefix_counts[user_arg_name] = index + 1 + seen_names.add(proposal) + graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access + # Anything can be a positional argument, in the same order as .inputs + graph_function._num_positional_args = num_positional # pylint: disable=protected-access + return graph_function def __get__(self, instance, owner): """Makes it possible to defun instance methods.""" @@ -2595,6 +2600,8 @@ class Function(object): `args` and `kwargs` can be None if this `Function` was created with an `input_signature`. + Caller must hold self._lock. + Args: args: The varargs for the Python function. kwargs: The keyword args for the Python function. @@ -2622,43 +2629,42 @@ class Function(object): "Arguments supplied to `defun`-generated functions must be" " hashable. Original error: %s" % e) - with self._lock: + graph_function = self._function_cache.primary.get(cache_key, None) + if graph_function is not None: + return graph_function, args, kwargs + + logging.vlog(1, + "Creating new FuncGraph for Python function %r (key: %r)", + self._python_function, cache_key) + logging.vlog(2, + "Python function signature [args: %s] [kwargs: %s]", + args, + kwargs) + + # pylint: disable=protected-access + call_context_key = cache_key._replace(input_signature=None) + # pylint: disable=protected-access + + ag_status = ( + ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED) + with ag_ctx.ControlStatusCtx( + status=ag_status, options=self._autograph_options): + + # Build a function with shape relaxation retracing if: + # 1. shape relaxation is explicitly enabled + # and 2. there's no provided input signature + # and 3. there's been a cache miss for this calling context + if (self._experimental_relax_shapes + and self.input_signature is None + and call_context_key in self._function_cache.missed): + return self._define_function_with_shape_relaxation(args, kwargs) + + self._function_cache.missed.add(call_context_key) graph_function = self._function_cache.primary.get(cache_key, None) - if graph_function is not None: - return graph_function, args, kwargs - - logging.vlog(1, - "Creating new FuncGraph for Python function %r (key: %r)", - self._python_function, cache_key) - logging.vlog(2, - "Python function signature [args: %s] [kwargs: %s]", - args, - kwargs) - - # pylint: disable=protected-access - call_context_key = cache_key._replace(input_signature=None) - # pylint: disable=protected-access - - ag_status = ( - ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED) - with ag_ctx.ControlStatusCtx( - status=ag_status, options=self._autograph_options): - - # Build a function with shape relaxation retracing if: - # 1. shape relaxation is explicitly enabled - # and 2. there's no provided input signature - # and 3. there's been a cache miss for this calling context - if (self._experimental_relax_shapes - and self.input_signature is None - and call_context_key in self._function_cache.missed): - return self._define_function_with_shape_relaxation(args, kwargs) - - self._function_cache.missed.add(call_context_key) - graph_function = self._function_cache.primary.get(cache_key, None) - if graph_function is None: - graph_function = self._create_graph_function(args, kwargs) - self._function_cache.primary[cache_key] = graph_function - return graph_function, args, kwargs + if graph_function is None: + graph_function = self._create_graph_function(args, kwargs) + self._function_cache.primary[cache_key] = graph_function + return graph_function, args, kwargs def register(func, *args, **kwargs): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index c5671747736..c062fddd8d1 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -525,6 +525,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertLen(set(concrete_functions), 1) + def testGetConcreteFunctionThreadSafetyWithArgs(self): + @def_function.function + def add_100(*args): + return math_ops.add_n(args) + + p = multiprocessing.pool.ThreadPool(2) + args = (constant_op.constant(1.),) * 100 + f1, f2 = p.map(add_100.get_concrete_function, [args] * 2) + # I see about len(args) + max(0, len(args) - 3) arguments expected. + f1(*args) + del f2 + def testInputSpecGraphFunction(self): matmul = def_function.function(math_ops.matmul)