[TF/XLA] Only force retracing for non-unique XLA context ID for TPUReplicatedContext

Fixes https://github.com/tensorflow/tensorflow/issues/39872

PiperOrigin-RevId: 316503485
Change-Id: Ice63983fcdf2fdedca60a9054f3b76ac60e1ff15
This commit is contained in:
George Karpenkov 2020-06-15 11:10:33 -07:00 committed by TensorFlower Gardener
parent b7d66ef926
commit ba658404f2
4 changed files with 3 additions and 34 deletions

View File

@ -29,7 +29,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -386,24 +385,6 @@ class DefFunctionTest(test.TestCase):
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
def testNoExcessiveRetracing(self):
inner_retracings = 0
@def_function.function(experimental_compile=True)
def inner(a, b):
nonlocal inner_retracings
inner_retracings += 1
return a * b + a
def outer(a, b):
return inner(a, b)
func_input = random_ops.random_normal([10, 10])
for _ in range(2):
def_function.function(outer)(func_input, func_input)
self.assertEqual(inner_retracings, 1)
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -2981,10 +2981,9 @@ class Function(object):
if not executing_eagerly:
# We want to force function retracing for each different
# XLAControlFlowContext, so add `xla_context_id` to the cache key.
xla_context = _enclosing_xla_context()
if xla_context is not None and \
xla_context.RequiresUniqueFunctionRetracing():
xla_context_id = id(xla_context)
tpu_context = _enclosing_xla_context()
if tpu_context is not None:
xla_context_id = id(tpu_context)
with ops.init_scope():
# The graph, or whether we're executing eagerly, should be a part of the

View File

@ -3682,11 +3682,6 @@ class XLAControlFlowContext(ControlFlowContext):
def AddValue(self, x):
return x
def RequiresUniqueFunctionRetracing(self):
"""Returns whether the tf.function should be retraced if the context changes.
"""
return False
def from_control_flow_context_def(context_def, import_scope=None):
"""Deserializes `context_def` into the appropriate ControlFlowContext.

View File

@ -639,12 +639,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def GetControlPivot(self):
return self._pivot
def RequiresUniqueFunctionRetracing(self):
# More context: b/158152827. TPU stack uses the TPUReplicateContext to
# create replicated variable handles and cluster TPU computations, thus we
# always retrace a tf.function when the wrapped TPUReplicateContext changes.
return True
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
"""The context for outside compilation in Tensorflow 2.0.