Rollback of rollback: [TF/XLA] Only force retracing for non-unique XLA context ID for TPUReplicatedContext
Fixes https://github.com/tensorflow/tensorflow/issues/39872 PiperOrigin-RevId: 317137904 Change-Id: Id287e10a0ab2494b11427435d8f89a383eeaf392
This commit is contained in:
parent
2ff1c5a31b
commit
b8bb250ebb
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
|
|||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
@ -385,6 +386,24 @@ class DefFunctionTest(test.TestCase):
|
|||
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
||||
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
||||
|
||||
def testNoExcessiveRetracing(self):
|
||||
inner_retracings = 0
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def inner(a, b):
|
||||
nonlocal inner_retracings
|
||||
inner_retracings += 1
|
||||
return a * b + a
|
||||
|
||||
def outer(a, b):
|
||||
return inner(a, b)
|
||||
|
||||
func_input = random_ops.random_normal([10, 10])
|
||||
for _ in range(2):
|
||||
def_function.function(outer)(func_input, func_input)
|
||||
|
||||
self.assertEqual(inner_retracings, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
|
|
@ -2981,9 +2981,10 @@ class Function(object):
|
|||
if not executing_eagerly:
|
||||
# We want to force function retracing for each different
|
||||
# XLAControlFlowContext, so add `xla_context_id` to the cache key.
|
||||
tpu_context = _enclosing_xla_context()
|
||||
if tpu_context is not None:
|
||||
xla_context_id = id(tpu_context)
|
||||
xla_context = _enclosing_xla_context()
|
||||
if xla_context is not None and \
|
||||
xla_context.RequiresUniqueFunctionRetracing():
|
||||
xla_context_id = id(xla_context)
|
||||
|
||||
with ops.init_scope():
|
||||
# The graph, or whether we're executing eagerly, should be a part of the
|
||||
|
|
|
@ -3682,6 +3682,11 @@ class XLAControlFlowContext(ControlFlowContext):
|
|||
def AddValue(self, x):
|
||||
return x
|
||||
|
||||
def RequiresUniqueFunctionRetracing(self):
|
||||
"""Returns whether the tf.function should be retraced if the context changes.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def from_control_flow_context_def(context_def, import_scope=None):
|
||||
"""Deserializes `context_def` into the appropriate ControlFlowContext.
|
||||
|
|
|
@ -639,6 +639,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||
def GetControlPivot(self):
|
||||
return self._pivot
|
||||
|
||||
def RequiresUniqueFunctionRetracing(self):
|
||||
# More context: b/158152827. TPU stack uses the TPUReplicateContext to
|
||||
# create replicated variable handles and cluster TPU computations, thus we
|
||||
# always retrace a tf.function when the wrapped TPUReplicateContext changes.
|
||||
return True
|
||||
|
||||
|
||||
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
|
||||
"""The context for outside compilation in Tensorflow 2.0.
|
||||
|
|
Loading…
Reference in New Issue