Rollback of rollback: [TF/XLA] Only force retracing for non-unique XLA context ID for TPUReplicatedContext

Fixes https://github.com/tensorflow/tensorflow/issues/39872

PiperOrigin-RevId: 317137904
Change-Id: Id287e10a0ab2494b11427435d8f89a383eeaf392
This commit is contained in:
George Karpenkov 2020-06-18 10:59:05 -07:00 committed by TensorFlower Gardener
parent 2ff1c5a31b
commit b8bb250ebb
4 changed files with 34 additions and 3 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -385,6 +386,24 @@ class DefFunctionTest(test.TestCase):
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
def testNoExcessiveRetracing(self):
  """An XLA-compiled inner function must be traced only once, even when the
  outer (uncompiled) function is re-wrapped and re-called several times."""
  trace_count = 0

  @def_function.function(experimental_compile=True)
  def compiled(x, y):
    # Count how many times Python tracing actually runs for the inner fn.
    nonlocal trace_count
    trace_count += 1
    return x * y + x

  def wrapper(x, y):
    return compiled(x, y)

  arg = random_ops.random_normal([10, 10])
  for _ in (0, 1):
    # A fresh tf.function wrapper each iteration; the inner compiled
    # function's cache must still be hit on the second pass.
    def_function.function(wrapper)(arg, arg)
  self.assertEqual(trace_count, 1)
if __name__ == '__main__':
  ops.enable_eager_execution()

View File

@ -2981,9 +2981,10 @@ class Function(object):
     if not executing_eagerly:
       # We want to force function retracing for each different
       # XLAControlFlowContext, so add `xla_context_id` to the cache key.
-      tpu_context = _enclosing_xla_context()
-      if tpu_context is not None:
-        xla_context_id = id(tpu_context)
+      xla_context = _enclosing_xla_context()
+      if xla_context is not None and \
+          xla_context.RequiresUniqueFunctionRetracing():
+        xla_context_id = id(xla_context)

     with ops.init_scope():
       # The graph, or whether we're executing eagerly, should be a part of the

View File

@ -3682,6 +3682,11 @@ class XLAControlFlowContext(ControlFlowContext):
def AddValue(self, x):
  return x
def RequiresUniqueFunctionRetracing(self):
  """Whether a tf.function must be retraced when this context changes.

  The base XLA control-flow context allows traces to be shared, so this
  returns False; subclasses that need one trace per context instance
  override it.
  """
  return False
def from_control_flow_context_def(context_def, import_scope=None):
  """Deserializes `context_def` into the appropriate ControlFlowContext.

View File

@ -639,6 +639,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def GetControlPivot(self):
  return self._pivot
def RequiresUniqueFunctionRetracing(self):
  """Force a fresh tf.function trace for each TPUReplicateContext.

  The TPU stack uses this context to create replicated variable handles
  and to cluster TPU computations, so a trace must never be shared across
  distinct context instances (see b/158152827 for background).
  """
  return True
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
  """The context for outside compilation in Tensorflow 2.0.