Rollback of rollback: [TF/XLA] Only force retracing for non-unique XLA context ID for TPUReplicatedContext
Fixes https://github.com/tensorflow/tensorflow/issues/39872 PiperOrigin-RevId: 317137904 Change-Id: Id287e10a0ab2494b11427435d8f89a383eeaf392
This commit is contained in:
parent
2ff1c5a31b
commit
b8bb250ebb
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import control_flow_util
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
@ -385,6 +386,24 @@ class DefFunctionTest(test.TestCase):
|
||||||
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
|
||||||
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))
|
||||||
|
|
||||||
|
def testNoExcessiveRetracing(self):
|
||||||
|
inner_retracings = 0
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def inner(a, b):
|
||||||
|
nonlocal inner_retracings
|
||||||
|
inner_retracings += 1
|
||||||
|
return a * b + a
|
||||||
|
|
||||||
|
def outer(a, b):
|
||||||
|
return inner(a, b)
|
||||||
|
|
||||||
|
func_input = random_ops.random_normal([10, 10])
|
||||||
|
for _ in range(2):
|
||||||
|
def_function.function(outer)(func_input, func_input)
|
||||||
|
|
||||||
|
self.assertEqual(inner_retracings, 1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
|
|
@ -2981,9 +2981,10 @@ class Function(object):
|
||||||
if not executing_eagerly:
|
if not executing_eagerly:
|
||||||
# We want to force function retracing for each different
|
# We want to force function retracing for each different
|
||||||
# XLAControlFlowContext, so add `xla_context_id` to the cache key.
|
# XLAControlFlowContext, so add `xla_context_id` to the cache key.
|
||||||
tpu_context = _enclosing_xla_context()
|
xla_context = _enclosing_xla_context()
|
||||||
if tpu_context is not None:
|
if xla_context is not None and \
|
||||||
xla_context_id = id(tpu_context)
|
xla_context.RequiresUniqueFunctionRetracing():
|
||||||
|
xla_context_id = id(xla_context)
|
||||||
|
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
# The graph, or whether we're executing eagerly, should be a part of the
|
# The graph, or whether we're executing eagerly, should be a part of the
|
||||||
|
|
|
@ -3682,6 +3682,11 @@ class XLAControlFlowContext(ControlFlowContext):
|
||||||
def AddValue(self, x):
|
def AddValue(self, x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def RequiresUniqueFunctionRetracing(self):
|
||||||
|
"""Returns whether the tf.function should be retraced if the context changes.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def from_control_flow_context_def(context_def, import_scope=None):
|
def from_control_flow_context_def(context_def, import_scope=None):
|
||||||
"""Deserializes `context_def` into the appropriate ControlFlowContext.
|
"""Deserializes `context_def` into the appropriate ControlFlowContext.
|
||||||
|
|
|
@ -639,6 +639,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||||
def GetControlPivot(self):
|
def GetControlPivot(self):
|
||||||
return self._pivot
|
return self._pivot
|
||||||
|
|
||||||
|
def RequiresUniqueFunctionRetracing(self):
|
||||||
|
# More context: b/158152827. TPU stack uses the TPUReplicateContext to
|
||||||
|
# create replicated variable handles and cluster TPU computations, thus we
|
||||||
|
# always retrace a tf.function when the wrapped TPUReplicateContext changes.
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
|
class OutsideCompilationV2Context(control_flow_ops.ControlFlowContext):
|
||||||
"""The context for outside compilation in Tensorflow 2.0.
|
"""The context for outside compilation in Tensorflow 2.0.
|
||||||
|
|
Loading…
Reference in New Issue