Ensure that if an op is outside compiled, its gradient gets outside compiled as well in TF2 (similar to TF1) when colocate_gradients_with_ops=True.
A couple changes were needed for this: 1. EnterGradientColocation needs to be called on the enclosing XlaContext in the parent graph hierarchy since FuncGraphs do not inherit control flow contexts. 2. Make behavior of gradients codepath in outside compilation consistent with tf.tpu.outside_compilation by directly executing the gradients ops under a OutsideCompilationV2Context. PiperOrigin-RevId: 337397123 Change-Id: Ia47298da90d7094bee920398e6a0b4342469c9d8
This commit is contained in:
parent
0434f45502
commit
2d2b112517
tensorflow/python
@ -4332,12 +4332,16 @@ class Graph(object):
|
||||
def _colocate_with_for_gradient(self, op, gradient_uid,
|
||||
ignore_existing=False):
|
||||
with self.colocate_with(op, ignore_existing):
|
||||
if gradient_uid is not None and self._control_flow_context is not None:
|
||||
self._control_flow_context.EnterGradientColocation(op, gradient_uid)
|
||||
try:
|
||||
if gradient_uid is not None:
|
||||
ctx = _get_enclosing_context(self)
|
||||
if ctx is not None:
|
||||
ctx.EnterGradientColocation(op, gradient_uid)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
ctx.ExitGradientColocation(op, gradient_uid)
|
||||
else:
|
||||
yield
|
||||
finally:
|
||||
self._control_flow_context.ExitGradientColocation(op, gradient_uid)
|
||||
else:
|
||||
yield
|
||||
|
||||
@ -6955,3 +6959,15 @@ def set_int_list_attr(op, attr_name, ints):
|
||||
"""TF internal method used to set a list(int) attribute in the node_def."""
|
||||
ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
|
||||
op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list)) # pylint:disable=protected-access
|
||||
|
||||
|
||||
def _get_enclosing_context(graph):
|
||||
# pylint: disable=protected-access
|
||||
if graph is None:
|
||||
return None
|
||||
|
||||
if graph._control_flow_context is not None:
|
||||
return graph._control_flow_context
|
||||
|
||||
if graph.building_function and hasattr(graph, "outer_graph"):
|
||||
return _get_enclosing_context(graph.outer_graph)
|
||||
|
@ -288,6 +288,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
self._outer_device_function_stack = None
|
||||
self._oc_dev_fn_stack = None
|
||||
self._outside_compilation_cluster = None
|
||||
self._outside_compilation_v2_context = None
|
||||
self._outside_compilation_counter = 0
|
||||
self._in_gradient_colocation = None
|
||||
self._gradient_colocation_stack = []
|
||||
@ -379,6 +380,21 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
|
||||
def EnterGradientColocation(self, op, gradient_uid):
|
||||
if op is not None:
|
||||
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
|
||||
# If we are in TF 2 functions (control flow V2 functions, or
|
||||
# tf.function()), we need to attach _xla_outside_compilation attribute
|
||||
# directly because we are not in TPUReplicateContext.
|
||||
try:
|
||||
outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
|
||||
except ValueError:
|
||||
# The attr was not present: do nothing.
|
||||
return
|
||||
parts = outside_attr.split(".")
|
||||
cluster = parts[0] + "." + gradient_uid
|
||||
self._outside_compilation_v2_context = OutsideCompilationV2Context(
|
||||
cluster)
|
||||
self._outside_compilation_v2_context.Enter()
|
||||
return
|
||||
self._gradient_colocation_stack.append(op)
|
||||
if not self._outside_compilation_cluster:
|
||||
try:
|
||||
@ -418,6 +434,17 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
|
||||
def ExitGradientColocation(self, op, gradient_uid):
|
||||
if op is not None:
|
||||
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
|
||||
# Inside a TF2 tf.function or control flow graph and `op` was not
|
||||
# marked to be outside compiled.
|
||||
assert self._outside_compilation_v2_context is None
|
||||
return
|
||||
if self._outside_compilation_v2_context is not None:
|
||||
# Inside a TF2 tf.function or control flow graph and `op` was
|
||||
# marked to be outside compiled.
|
||||
self._outside_compilation_v2_context.Exit()
|
||||
self._outside_compilation_v2_context = None
|
||||
return
|
||||
if not self._gradient_colocation_stack:
|
||||
raise errors.InternalError(
|
||||
op.node_def, op,
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -450,6 +451,36 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
|
||||
strategy.experimental_local_results(train_step()),
|
||||
constant_op.constant(2916., shape=(strategy.num_replicas_in_sync)))
|
||||
|
||||
def testColocateGradientWithOutsideCompiledOp(self):
|
||||
strategy = get_tpu_strategy()
|
||||
|
||||
@def_function.function
|
||||
def train_step():
|
||||
|
||||
@def_function.function
|
||||
def tpu_fn(x):
|
||||
x1 = tpu.outside_compilation(math_ops.sqrt, x)
|
||||
grad = gradients_impl.gradients([x1], [x],
|
||||
colocate_gradients_with_ops=True)[0]
|
||||
sqrt = [
|
||||
op for op in ops.get_default_graph().get_operations()
|
||||
if op.type == "Sqrt"
|
||||
][0]
|
||||
sqrt_grad = [
|
||||
op for op in ops.get_default_graph().get_operations()
|
||||
if op.type == "SqrtGrad"
|
||||
][0]
|
||||
assert sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) == b"0"
|
||||
assert (sqrt_grad.get_attr(
|
||||
tpu._OUTSIDE_COMPILATION_ATTR) == b"0.gradients/uid")
|
||||
return grad
|
||||
|
||||
return strategy.run(tpu_fn, args=(25.0,))
|
||||
|
||||
self.assertAllEqual(
|
||||
strategy.experimental_local_results(train_step()),
|
||||
constant_op.constant(.1, shape=(strategy.num_replicas_in_sync)))
|
||||
|
||||
|
||||
class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user