Makes experimental_implements work correctly with GradientTape. Previously, the experimental_implements attribute was erroneously kept on the forward and backward functions of the eager versions of the tf.function, which caused a runtime crash.

This change unifies the code path used for both eager and non-eager functions.
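Rough illustration of the user-facing pattern this fixes (a minimal sketch; the composite-op name 'MyAdd' is just a placeholder, not part of this change):

import tensorflow as tf

@tf.function(experimental_implements='MyAdd')  # 'MyAdd' is an illustrative name
def add(x, y):
  return x + y ** 2

x = tf.Variable(3.0)
y = tf.Variable(2.0)
with tf.GradientTape() as tape:
  g = add(x, y)
# Previously this crashed at runtime because the "_implements" attribute leaked
# onto the generated forward/backward functions; now it yields dg/dy = 4.0, dg/dx = 1.0.
dg_dy, dg_dx = tape.gradient(g, [y, x])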

PiperOrigin-RevId: 352033823
Change-Id: I98ae4d07420826482faf9e0ace3f53193964a816
Mark Sandler 2021-01-15 10:14:52 -08:00 committed by TensorFlower Gardener
parent fbd8dfede8
commit b7d254ae61
2 changed files with 46 additions and 50 deletions
tensorflow/python/eager


@@ -602,6 +602,32 @@ class _EagerDefinedFunction(object):
    return outputs


def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
  """Creates forward and backward functions from the function graphs."""
  forward_function_name = _forward_name(forward_graph.name)
  common_attributes = dict(attrs)
  # NB: The forward and backward functions need to drop the "_implements"
  # attribute, because their signature contains all the intermediate tensors
  # that they compute. Thus they don't have a stable signature which can
  # be directly optimized downstream.
  # See for more details:
  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
  common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
  backward_function_attr = _parse_func_attrs(
      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
  backward_function_attr.update(common_attributes)
  backward_function = ConcreteFunction(
      backwards_graph, attrs=backward_function_attr)
  forward_function_attr = _parse_func_attrs({
      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
          backward_function.name})
  forward_function_attr.update(common_attributes)
  forward_function = _EagerDefinedFunction(
      forward_function_name, forward_graph, forward_graph.inputs,
      forward_graph.outputs, forward_function_attr)
  return forward_function, backward_function
class _DelayedRewriteGradientFunctions(object):
  """Caches forward/backward functions with a delayed forward rewrite."""
@@ -683,39 +709,15 @@ class _DelayedRewriteGradientFunctions(object):
          c for c in backwards_graph_captures if
          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

      forward_function_name = _forward_name(self._func_graph.name)
      # NB: forward and backward function have their "_implements"
      # attribute set to None if it was present. This is because we don't
      # support replacing those functions. If we do want for those functions
      # to have implements function we need to provide a mechanism that
      # would allow to identify all functions that call this one
      # and trace and update their signatures as well. At the moment
      # we disable this, until the tooling for doing this becomes available.
      # See:
      # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
      common_attributes = dict(self._attrs)
      common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)
      backward_function_attr = _parse_func_attrs(
          {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
      backward_function_attr.update(common_attributes)
      backward_function = ConcreteFunction(
          backwards_graph, attrs=backward_function_attr)
      forward_function_attr = _parse_func_attrs({
          BACKWARD_FUNCTION_ATTRIBUTE_NAME:
              backward_function.name})
      forward_function_attr.update(common_attributes)
      forward_function = _EagerDefinedFunction(
          forward_function_name, self._func_graph, self._func_graph.inputs,
          self._func_graph.outputs, forward_function_attr)
      forward_function, backward_function = _create_forward_backward_with_graph(
          self._attrs, self._func_graph, backwards_graph)
      return forward_function, backward_function

  def _rewrite_forward_and_call_backward(self, op, *doutputs):
@@ -928,11 +930,6 @@ class _TapeGradientFunctions(object):
        existing_outputs.add(capture)
        self._func_graph.outputs.append(capture)

    forward_function_name = _forward_name(self._func_graph.name)
    backward_function_attr = _parse_func_attrs(
        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
    backward_function_attr.update(self._attrs)
    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `backward_function` correspond to outputs (including
    # side outputs) of `self._tape_forward_function`.
@@ -943,18 +940,9 @@ class _TapeGradientFunctions(object):
        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
        if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs

    backward_function = ConcreteFunction(
        backwards_graph, attrs=backward_function_attr)
    forward_function_attr = _parse_func_attrs({
        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
            backward_function.name})
    forward_function_attr.update(self._attrs)
    forward_function = _EagerDefinedFunction(
        forward_function_name, self._func_graph, self._func_graph.inputs,
        self._func_graph.outputs,
        forward_function_attr)
    forward_function, backward_function = _create_forward_backward_with_graph(
        self._attrs, self._func_graph, backwards_graph)

    if not input_tangents:
      # There is no need to special-case forwardprop, so we can return the
@@ -971,14 +959,9 @@ class _TapeGradientFunctions(object):
    # are in the same order the backward function expects them to be in:
    # [inference outputs] + [jvps] + [side outputs] + [captures].
    forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)

    wrapped_forward_function = _EagerDefinedFunction(
        _forward_name(self._func_graph.name), forward_wrapper.graph,
        forward_wrapper.graph.inputs, forward_wrapper.graph.outputs,
        forward_function_attr)
    wrapped_backward_function = ConcreteFunction(
        wrapped_backwards_graph, attrs=backward_function_attr)
    (wrapped_forward_function,
     wrapped_backward_function) = _create_forward_backward_with_graph(
         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)

    if (len(inference_args) + len(input_tangents)
        != len(forward_wrapper.graph.inputs)):
      raise AssertionError(


@@ -272,6 +272,19 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
    functions = ops.get_default_graph().as_graph_def().library.function
    self.assertEmpty(functions)

  def testImplementsAttributeWorksWithGradientTape(self):
    add = lambda x, y: x + y ** 2
    add = def_function.function(experimental_implements='MyFunc')(add)
    x = variables.Variable(3.0)
    y = variables.Variable(2.0)

    with backprop.GradientTape() as tape:
      g = add(x, y)
    dg_dy, dg_dx = tape.gradient(g, [y, x])

    self.assertEqual(dg_dy.numpy(), 4.0)
    self.assertEqual(dg_dx.numpy(), 1.0)

  def testImplementsAttributeWorksOnVariables(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(