Ensure that if an op is outside compiled, its gradient gets outside compiled as well in TF2 (similar to TF1) when colocate_gradients_with_ops=True.

A couple changes were needed for this:
1. EnterGradientColocation needs to be called on the enclosing XlaContext in the parent graph hierarchy since FuncGraphs do not inherit control flow contexts.
2. Make the behavior of the gradients codepath in outside compilation consistent with tf.tpu.outside_compilation by directly executing the gradient ops under an OutsideCompilationV2Context.

PiperOrigin-RevId: 337397123
Change-Id: Ia47298da90d7094bee920398e6a0b4342469c9d8
This commit is contained in:
Saurabh Saxena 2020-10-15 15:39:04 -07:00 committed by TensorFlower Gardener
parent 0434f45502
commit 2d2b112517
3 changed files with 79 additions and 5 deletions
tensorflow/python

View File

@ -4332,12 +4332,16 @@ class Graph(object):
def _colocate_with_for_gradient(self, op, gradient_uid,
ignore_existing=False):
with self.colocate_with(op, ignore_existing):
if gradient_uid is not None and self._control_flow_context is not None:
self._control_flow_context.EnterGradientColocation(op, gradient_uid)
try:
if gradient_uid is not None:
ctx = _get_enclosing_context(self)
if ctx is not None:
ctx.EnterGradientColocation(op, gradient_uid)
try:
yield
finally:
ctx.ExitGradientColocation(op, gradient_uid)
else:
yield
finally:
self._control_flow_context.ExitGradientColocation(op, gradient_uid)
else:
yield
@ -6955,3 +6959,15 @@ def set_int_list_attr(op, attr_name, ints):
"""TF internal method used to set a list(int) attribute in the node_def."""
ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list)) # pylint:disable=protected-access
def _get_enclosing_context(graph):
# pylint: disable=protected-access
if graph is None:
return None
if graph._control_flow_context is not None:
return graph._control_flow_context
if graph.building_function and hasattr(graph, "outer_graph"):
return _get_enclosing_context(graph.outer_graph)

View File

@ -288,6 +288,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._outer_device_function_stack = None
self._oc_dev_fn_stack = None
self._outside_compilation_cluster = None
self._outside_compilation_v2_context = None
self._outside_compilation_counter = 0
self._in_gradient_colocation = None
self._gradient_colocation_stack = []
@ -379,6 +380,21 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def EnterGradientColocation(self, op, gradient_uid):
if op is not None:
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
# If we are in TF 2 functions (control flow V2 functions, or
# tf.function()), we need to attach _xla_outside_compilation attribute
# directly because we are not in TPUReplicateContext.
try:
outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
except ValueError:
# The attr was not present: do nothing.
return
parts = outside_attr.split(".")
cluster = parts[0] + "." + gradient_uid
self._outside_compilation_v2_context = OutsideCompilationV2Context(
cluster)
self._outside_compilation_v2_context.Enter()
return
self._gradient_colocation_stack.append(op)
if not self._outside_compilation_cluster:
try:
@ -418,6 +434,17 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def ExitGradientColocation(self, op, gradient_uid):
if op is not None:
if ops.get_default_graph()._control_flow_context is None: # pylint: disable=protected-access
# Inside a TF2 tf.function or control flow graph and `op` was not
# marked to be outside compiled.
assert self._outside_compilation_v2_context is None
return
if self._outside_compilation_v2_context is not None:
# Inside a TF2 tf.function or control flow graph and `op` was
# marked to be outside compiled.
self._outside_compilation_v2_context.Exit()
self._outside_compilation_v2_context = None
return
if not self._gradient_colocation_stack:
raise errors.InternalError(
op.node_def, op,

View File

@ -34,6 +34,7 @@ from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
@ -450,6 +451,36 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
strategy.experimental_local_results(train_step()),
constant_op.constant(2916., shape=(strategy.num_replicas_in_sync)))
def testColocateGradientWithOutsideCompiledOp(self):
  """Gradient of an outside-compiled op is itself outside compiled.

  With colocate_gradients_with_ops=True, the SqrtGrad op generated for an
  outside-compiled Sqrt must carry the outside-compilation attribute,
  with a cluster name derived from the forward op's cluster plus the
  gradient uid.
  """
  strategy = get_tpu_strategy()

  @def_function.function
  def train_step():

    @def_function.function
    def tpu_fn(x):
      # Run sqrt on the host (outside compilation) rather than on the TPU.
      x1 = tpu.outside_compilation(math_ops.sqrt, x)
      grad = gradients_impl.gradients([x1], [x],
                                      colocate_gradients_with_ops=True)[0]
      # Locate the forward op and its generated gradient op in the graph.
      sqrt = [
          op for op in ops.get_default_graph().get_operations()
          if op.type == "Sqrt"
      ][0]
      sqrt_grad = [
          op for op in ops.get_default_graph().get_operations()
          if op.type == "SqrtGrad"
      ][0]
      # Forward op lives in outside-compilation cluster "0"; the gradient
      # op's cluster is that name suffixed with the gradient uid.
      assert sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) == b"0"
      assert (sqrt_grad.get_attr(
          tpu._OUTSIDE_COMPILATION_ATTR) == b"0.gradients/uid")
      return grad

    return strategy.run(tpu_fn, args=(25.0,))

  # d/dx sqrt(x) = 1 / (2 * sqrt(x)); at x = 25 this is 0.1 per replica.
  self.assertAllEqual(
      strategy.experimental_local_results(train_step()),
      constant_op.constant(.1, shape=(strategy.num_replicas_in_sync)))
class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
parameterized.TestCase):