Makes experimental_implements work correctly with GradientTape. Previously, the experimental_implements attribute was erroneously kept on the forward/backward functions of the eager version of the tf.function, which caused a runtime crash.

This change unifies the code path used for both eager and non-eager functions.

PiperOrigin-RevId: 352033823
Change-Id: I98ae4d07420826482faf9e0ace3f53193964a816
parent fbd8dfede8
commit b7d254ae61

tensorflow/python/eager
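
For context, here is a minimal sketch of the scenario this change fixes, mirroring the testImplementsAttributeWorksWithGradientTape test added below. It uses only the public TF 2.x API; "MyFunc" is just the placeholder implements name taken from that test.

import tensorflow as tf

# Before this fix, taking the gradient of a tf.function annotated with
# `experimental_implements` under GradientTape could crash at runtime, because
# the generated forward/backward functions kept the "_implements" attribute.
@tf.function(experimental_implements="MyFunc")
def add(x, y):
  return x + y ** 2

x = tf.Variable(3.0)
y = tf.Variable(2.0)

with tf.GradientTape() as tape:
  g = add(x, y)

dg_dy, dg_dx = tape.gradient(g, [y, x])
print(dg_dy.numpy(), dg_dx.numpy())  # Expected: 4.0 (d/dy = 2*y) and 1.0 (d/dx = 1).
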
@@ -602,6 +602,32 @@ class _EagerDefinedFunction(object):
       return outputs


+def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
+  """Creates forward and backward functions from the function graphs."""
+  forward_function_name = _forward_name(forward_graph.name)
+  common_attributes = dict(attrs)
+  # NB: forward and backward function need to drop "_implements".
+  # attribute, because their signature contains all the intermediate tensors
+  # that they compute. Thus they don't have a stable signature which can
+  # be directly optimized downstream.
+  # See for more details:
+  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
+  common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
+  backward_function_attr = _parse_func_attrs(
+      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+  backward_function_attr.update(common_attributes)
+  backward_function = ConcreteFunction(
+      backwards_graph, attrs=backward_function_attr)
+  forward_function_attr = _parse_func_attrs({
+      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+          backward_function.name})
+  forward_function_attr.update(common_attributes)
+  forward_function = _EagerDefinedFunction(
+      forward_function_name, forward_graph, forward_graph.inputs,
+      forward_graph.outputs, forward_function_attr)
+  return forward_function, backward_function
+
+
 class _DelayedRewriteGradientFunctions(object):
   """Caches forward/backward functions with a delayed forward rewrite."""

@@ -683,39 +709,15 @@ class _DelayedRewriteGradientFunctions(object):
           c for c in backwards_graph_captures if
           not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

-      forward_function_name = _forward_name(self._func_graph.name)
-
-      # NB: forward and backward function have their "_implements"
-      # attribute set to None if it was present. This is because we don't
-      # support replacing those functions. If we do want for those functions
-      # to have implements function we need to provide a mechanism that
-      # would allow to identify all functions that call this one
-      # and trace and update their signatures as well. At the moment
-      # we disable this, until the tooling for doing this becomes available.
-      # See:
-      # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
-      common_attributes = dict(self._attrs)
-      common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
-
       existing_outputs = object_identity.ObjectIdentitySet(
           self._func_graph.outputs)
       for capture in captures_from_forward:
         if capture not in existing_outputs:
           existing_outputs.add(capture)
           self._func_graph.outputs.append(capture)
-      backward_function_attr = _parse_func_attrs(
-          {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
-      backward_function_attr.update(common_attributes)
-
-      backward_function = ConcreteFunction(
-          backwards_graph, attrs=backward_function_attr)
-      forward_function_attr = _parse_func_attrs({
-          BACKWARD_FUNCTION_ATTRIBUTE_NAME:
-              backward_function.name})
-      forward_function_attr.update(common_attributes)
-      forward_function = _EagerDefinedFunction(
-          forward_function_name, self._func_graph, self._func_graph.inputs,
-          self._func_graph.outputs, forward_function_attr)
+
+      forward_function, backward_function = _create_forward_backward_with_graph(
+          self._attrs, self._func_graph, backwards_graph)
       return forward_function, backward_function

   def _rewrite_forward_and_call_backward(self, op, *doutputs):
@@ -928,11 +930,6 @@ class _TapeGradientFunctions(object):
         existing_outputs.add(capture)
         self._func_graph.outputs.append(capture)

-    forward_function_name = _forward_name(self._func_graph.name)
-    backward_function_attr = _parse_func_attrs(
-        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
-    backward_function_attr.update(self._attrs)
-
     # The ordering of `backwards_graph.inputs` is important: inputs of
     # `backward_function` correspond to outputs (including
     # side outputs) of `self._tape_forward_function`.
@@ -943,18 +940,9 @@ class _TapeGradientFunctions(object):
         for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
         if grad is not None)
     backwards_graph.structured_outputs = gradients_wrt_inputs
-    backward_function = ConcreteFunction(
-        backwards_graph, attrs=backward_function_attr)
-
-    forward_function_attr = _parse_func_attrs({
-        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
-            backward_function.name})
-    forward_function_attr.update(self._attrs)
-
-    forward_function = _EagerDefinedFunction(
-        forward_function_name, self._func_graph, self._func_graph.inputs,
-        self._func_graph.outputs,
-        forward_function_attr)
+
+    forward_function, backward_function = _create_forward_backward_with_graph(
+        self._attrs, self._func_graph, backwards_graph)

     if not input_tangents:
       # There is no need to special-case forwardprop, so we can return the
@@ -971,14 +959,9 @@ class _TapeGradientFunctions(object):
     # are in the same order the backward function expects them to be in:
     # [inference outputs] + [jvps] + [side outputs] + [captures].
     forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
-
-    wrapped_forward_function = _EagerDefinedFunction(
-        _forward_name(self._func_graph.name), forward_wrapper.graph,
-        forward_wrapper.graph.inputs, forward_wrapper.graph.outputs,
-        forward_function_attr)
-    wrapped_backward_function = ConcreteFunction(
-        wrapped_backwards_graph, attrs=backward_function_attr)
-
+    (wrapped_forward_function,
+     wrapped_backward_function) = _create_forward_backward_with_graph(
+         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
     if (len(inference_args) + len(input_tangents)
         != len(forward_wrapper.graph.inputs)):
       raise AssertionError(
@@ -272,6 +272,19 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       functions = ops.get_default_graph().as_graph_def().library.function
       self.assertEmpty(functions)

+  def testImplementsAttributeWorksWithGradientTape(self):
+    add = lambda x, y: x + y ** 2
+    add = def_function.function(experimental_implements='MyFunc')(add)
+    x = variables.Variable(3.0)
+    y = variables.Variable(2.0)
+
+    with backprop.GradientTape() as tape:
+      g = add(x, y)
+
+    dg_dy, dg_dx = tape.gradient(g, [y, x])
+    self.assertEqual(dg_dy.numpy(), 4.0)
+    self.assertEqual(dg_dx.numpy(), 1.0)
+
   def testImplementsAttributeWorksOnVariables(self):
     with context.graph_mode(), self.cached_session():
       v = def_function.function(
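
As a side note on the "non eager" path mentioned in the commit message: gradients of an experimental_implements function can also be taken while building a graph, which should exercise the delayed-rewrite path rather than the tape path. A rough sketch, not part of this commit, assuming only the public tf.compat.v1 Session API:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # Calling the annotated tf.function with symbolic tensors adds a function
  # call op to `graph`; tf.gradients then differentiates through that call.
  add = tf.function(experimental_implements="MyFunc")(lambda x, y: x + y ** 2)
  x = tf.compat.v1.placeholder(tf.float32, shape=[])
  y = tf.compat.v1.placeholder(tf.float32, shape=[])
  g = add(x, y)
  dg_dy, dg_dx = tf.gradients(g, [y, x])

with tf.compat.v1.Session(graph=graph) as sess:
  print(sess.run([dg_dy, dg_dx], {x: 3.0, y: 2.0}))  # Expected: [4.0, 1.0]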