[TF/XLA] Only force retracing for non-unique XLA context ID for TPUReplicatedContext
Fixes https://github.com/tensorflow/tensorflow/issues/39872 PiperOrigin-RevId: 316503485 Change-Id: Ice63983fcdf2fdedca60a9054f3b76ac60e1ff15
This commit is contained in:
parent
b7d66ef926
commit
ba658404f2
@ -29,7 +29,6 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import control_flow_util
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -386,24 +385,6 @@ class DefFunctionTest(test.TestCase):
|
|||||||
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
||||||
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
||||||
|
|
||||||
def testNoExcessiveRetracing(self):
|
|
||||||
inner_retracings = 0
|
|
||||||
|
|
||||||
@def_function.function(experimental_compile=True)
|
|
||||||
def inner(a, b):
|
|
||||||
nonlocal inner_retracings
|
|
||||||
inner_retracings += 1
|
|
||||||
return a * b + a
|
|
||||||
|
|
||||||
def outer(a, b):
|
|
||||||
return inner(a, b)
|
|
||||||
|
|
||||||
func_input = random_ops.random_normal([10, 10])
|
|
||||||
for _ in range(2):
|
|
||||||
def_function.function(outer)(func_input, func_input)
|
|
||||||
|
|
||||||
self.assertEqual(inner_retracings, 1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@ -2981,10 +2981,9 @@ class Function(object):
|
|||||||
if not executing_eagerly:
|
if not executing_eagerly:
|
||||||
# We want to force function retracing for each different
|
# We want to force function retracing for each different
|
||||||
# XLAControlFlowContext, so add `xla_context_id` to the cache key.
|
# XLAControlFlowContext, so add `xla_context_id` to the cache key.
|
||||||
xla_context = _enclosing_xla_context()
|
tpu_context = _enclosing_xla_context()
|
||||||
if xla_context is not None and \
|
if tpu_context is not None:
|
||||||
xla_context.RequiresUniqueFunctionRetracing():
|
xla_context_id = id(tpu_context)
|
||||||
xla_context_id = id(xla_context)
|
|
||||||
|
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
# The graph, or whether we're executing eagerly, should be a part of the
|
# The graph, or whether we're executing eagerly, should be a part of the
|
||||||
|
@ -3682,11 +3682,6 @@ class XLAControlFlowContext(ControlFlowContext):
|
|||||||
def AddValue(self, x):
|
def AddValue(self, x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def RequiresUniqueFunctionRetracing(self):
|
|
||||||
"""Returns whether the tf.function should be retraced if the context changes.
|
|
||||||
"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def from_control_flow_context_def(context_def, import_scope=None):
|
def from_control_flow_context_def(context_def, import_scope=None):
|
||||||
"""Deserializes `context_def` into the appropriate ControlFlowContext.
|
"""Deserializes `context_def` into the appropriate ControlFlowContext.
|
||||||
|
@ -639,12 +639,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
|||||||
def GetControlPivot(self):
|
def GetControlPivot(self):
|
||||||
return self._pivot
|
return self._pivot
|
||||||
|
|
||||||
def RequiresUniqueFunctionRetracing(self):
|
|
||||||
# More context: b/158152827. TPU stack uses the TPUReplicateContext to
|
|
||||||
# create replicated variable handles and cluster TPU computations, thus we
|
|
||||||
# always retrace a tf.function when the wrapped TPUReplicateContext changes.
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
|
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
|
||||||
"""The context for outside compilation in Tensorflow 2.0.
|
"""The context for outside compilation in Tensorflow 2.0.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user