diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index b63a3b434d4..78d44a81b0b 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -385,6 +386,24 @@ class DefFunctionTest(test.TestCase): f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64) self.assertAllClose([1.1, 3.3, 6.6], f(f64_input)) + def testNoExcessiveRetracing(self): + inner_retracings = 0 + + @def_function.function(experimental_compile=True) + def inner(a, b): + nonlocal inner_retracings + inner_retracings += 1 + return a * b + a + + def outer(a, b): + return inner(a, b) + + func_input = random_ops.random_normal([10, 10]) + for _ in range(2): + def_function.function(outer)(func_input, func_input) + + self.assertEqual(inner_retracings, 1) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a40eaf886b3..c02318cb814 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2981,9 +2981,10 @@ class Function(object): if not executing_eagerly: # We want to force function retracing for each different # XLAControlFlowContext, so add `xla_context_id` to the cache key. - tpu_context = _enclosing_xla_context() - if tpu_context is not None: - xla_context_id = id(tpu_context) + xla_context = _enclosing_xla_context() + if xla_context is not None and \ + xla_context.RequiresUniqueFunctionRetracing(): + xla_context_id = id(xla_context) with ops.init_scope(): # The graph, or whether we're executing eagerly, should be a part of the diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 3398308d42e..748f842a9e0 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -3682,6 +3682,11 @@ class XLAControlFlowContext(ControlFlowContext): def AddValue(self, x): return x + def RequiresUniqueFunctionRetracing(self): + """Returns whether the tf.function should be retraced if the context changes. + """ + return False + def from_control_flow_context_def(context_def, import_scope=None): """Deserializes `context_def` into the appropriate ControlFlowContext. diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 28eba69b7da..ce3aaa8a058 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -639,6 +639,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): def GetControlPivot(self): return self._pivot + def RequiresUniqueFunctionRetracing(self): + # More context: b/158152827. TPU stack uses the TPUReplicateContext to + # create replicated variable handles and cluster TPU computations, thus we + # always retrace a tf.function when the wrapped TPUReplicateContext changes. + return True + class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext): """The context for outside compilation in Tensorflow 2.0.