Rolls back cl/338294004 and part of cl/338299497

PiperOrigin-RevId: 338384800
Change-Id: I7cf4d230da1a9de064552b464055ee8675aa7fa4
Allen Lavoie 2020-10-21 18:59:04 -07:00 committed by TensorFlower Gardener
parent 37f7c75a71
commit fb22dff317
8 changed files with 88 additions and 238 deletions

View File

@@ -43,7 +43,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -1714,35 +1713,6 @@ class JacobianTest(test.TestCase):
dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))
def test_nested_batch_jacobian_foldl(self):
def _grad(f):
def _grad_function(primal):
with backprop.GradientTape() as tape:
tape.watch(primal)
primal_out = f(primal)
return tape.batch_jacobian(primal_out, primal)
return _grad_function
def _func(x):
return array_ops.reshape(
functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
array_ops.transpose(x)),
[1, 1])
f = _func
x = constant_op.constant([[1., 2.]])
for _ in range(2):
theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
self.assertAllClose(theoretical, numerical, rtol=1e-3)
f = _grad(f)
expected_flat = array_ops.reshape(numerical, [-1])
self.assertAllClose(expected_flat,
array_ops.reshape(f(x), [-1]),
rtol=1e-3)
self.assertAllClose(expected_flat,
array_ops.reshape(def_function.function(f)(x), [-1]),
rtol=1e-3)
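The removed test above checks nested batch Jacobians of a foldl_v2 reduction against gradient_checker_v2. As a minimal sketch of the batch_jacobian pattern it relies on, using only the public tf.* API rather than this test file's internal imports:

import tensorflow as tf

x = tf.constant([[1., 2.]])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.cos(x)                    # per-example outputs, shape [1, 2]
jac = tape.batch_jacobian(y, x)    # one Jacobian per batch row, shape [1, 2, 2]
print(jac.numpy())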
@test_util.run_in_graph_and_eager_modes
def test_indexed_slices(self):
with backprop.GradientTape(persistent=True) as g:

View File

@@ -19,35 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import handle_data_util
def _DTypeFromTensor(tensor):
"""Extract either `tensor.dtype` or the unanimous sub-type of a variant."""
dtype = tensor.dtype
if dtype.base_dtype == dtypes.variant:
# If we know statically that the data a variant points to is non-trainable
# then the variant itself is non-trainable.
if isinstance(tensor, ops.EagerTensor):
handle_data = tensor._handle_data # pylint: disable=protected-access
else:
handle_data = handle_data_util.get_resource_handle_data(tensor)
if (handle_data is not None
and handle_data.is_set
and handle_data.shape_and_type):
first_type = handle_data.shape_and_type[0].dtype
if all(shape_and_type.dtype == first_type
for shape_and_type in handle_data.shape_and_type):
return first_type
return dtype
def IsTrainable(tensor_or_dtype):
"""Determines whether a tensor or dtype supports infinitesimal changes."""
if tensor_util.is_tensor(tensor_or_dtype):
dtype = _DTypeFromTensor(tensor_or_dtype)
dtype = tensor_or_dtype.dtype
else:
dtype = tensor_or_dtype
dtype = dtypes.as_dtype(dtype)
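For orientation, IsTrainable reports whether a dtype admits infinitesimal changes; a small hedged sketch of the expected behaviour on plain (non-variant) dtypes, assuming the usual trainable-dtype set:

from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import dtypes

print(backprop_util.IsTrainable(dtypes.float32))  # True: floating point is differentiable
print(backprop_util.IsTrainable(dtypes.int32))    # False: integer dtypes are not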

View File

@@ -1387,7 +1387,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
gradients with respect to the inputs.
"""
outputs = []
iteration_count = 0
# First we need to figure out how many side outputs from the forward pass
# will be required. We do this in a temporary graph to avoid actually
# running multiple copies of the backward pass (one per _GradientsHelper
@@ -1402,42 +1401,15 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
# all of the forward op's outputs: symbolic gradients with tf.gradients
# instead rely on regenerating backward functions when higher-order
# gradients are requested.
while (len(outputs) < len(self._func_graph.outputs)
# It's possible for gradient generation to add new ops to the forward
# pass. If all of the new outputs are non-trainable, there's no
# reason to continue.
and any(backprop_util.IsTrainable(output)
for output in self._func_graph.outputs[len(outputs):])):
iteration_count += 1
if iteration_count >= 20 and iteration_count % 5 == 0:
new_op_with_trainable_output = None
num_new_trainable_outputs = 0
for output in self._func_graph.outputs[len(outputs):]:
if backprop_util.IsTrainable(output):
num_new_trainable_outputs += 1
new_op_with_trainable_output = output.op
logging.warning(
("Determining side outputs for the function '{}' is taking longer "
"than expected ({} iterations, typically this converges in 5 or "
"so). This could indicate that a gradient registration is adding "
"new ops to the forward pass every time gradients are generated. "
"{} new trainable output(s) were added this iteration, one from "
"the following op:\n {}\nThis may indicate a TensorFlow bug, or "
"an issue in a tf.custom_gradient.")
.format(
self._func_graph.name, iteration_count,
num_new_trainable_outputs, new_op_with_trainable_output))
while len(outputs) < len(self._func_graph.outputs):
outputs = list(self._func_graph.outputs)
self._build_functions_for_outputs(
outputs, inference_args, input_tangents)
(forward_function, forward_graph,
backward_function, output_indices, num_output_tangents) = (
self._build_functions_for_outputs(
outputs, inference_args, input_tangents))
if (len(self._func_graph.outputs) > len(outputs)
and any(backprop_util.IsTrainable(output)
for output in self._func_graph.outputs[len(outputs):])):
if len(self._func_graph.outputs) != len(outputs):
raise AssertionError(
("Unexpectedly added new outputs to the forward function when "
"building the backward function: {}").format(

View File

@@ -442,11 +442,6 @@ class FuncGraph(ops.Graph):
return self._fallback_outer_graph
return current
@outer_graph.setter
def outer_graph(self, new_outer_graph):
"""Sets `outer_graph` to `new_outer_graph`."""
self._weak_outer_graph = weakref.ref(new_outer_graph)
@property
def output_types(self):
return [t.dtype for t in self.outputs]
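The removed setter stored the new outer graph behind weakref.ref, so the branch graph does not keep its outer graph alive. A generic sketch of that pattern with a hypothetical Holder class (not FuncGraph itself):

import weakref

class Holder:
  """Points at another object without keeping it alive."""

  def set_outer(self, obj):
    self._weak_outer = weakref.ref(obj)

  def outer(self):
    return self._weak_outer()  # the referent, or None once it has been collected

class Thing:
  pass

h, t = Holder(), Thing()
h.set_outer(t)
print(h.outer() is t)  # True while `t` is still alive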

View File

@@ -42,7 +42,6 @@ from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
@@ -172,57 +171,6 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
self.assertAllEqual(fnWithLoop(), 4.0)
def checkIteratedGradients(self, func):
with context.eager_mode():
def _Grad(f):
def _GradFunction(primal):
with backprop.GradientTape() as tape:
tape.watch(primal)
primal_out = f(primal)
return tape.gradient(primal_out, primal)
return _GradFunction
f = func
one = constant_op.constant(1.)
for _ in range(3):
theoretical, numerical = gradient_checker_v2.compute_gradient(
def_function.function(f), [one])
self.assertAllClose(theoretical, numerical, rtol=1e-3)
f = _Grad(f)
self.assertAllClose(array_ops.reshape(numerical, []),
def_function.function(f)(one),
rtol=1e-3)
def testIteratedGradients(self):
def _Func(x):
_, z = while_loop_v2(
lambda i, _: i < 2,
lambda i, y: (i + 1, math_ops.cos(y)),
[0, x])
return z
self.checkIteratedGradients(_Func)
def testIteratedGradientsWithList(self):
def _Func(x):
results = list_ops.empty_tensor_list(
element_shape=[], element_dtype=dtypes.float32)
def _LoopBody(i, y, handle):
return (i + 1, math_ops.cos(y),
list_ops.tensor_list_push_back(handle, y))
_, z, results = while_loop_v2(
lambda i, _, h: i < 2, _LoopBody, [0, x, results])
return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
results, dtypes.float32))
self.checkIteratedGradients(_Func)
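The removed tests exercise iterated (higher-order) gradients by repeatedly wrapping a function in a gradient of itself. A minimal public-API sketch of that wrapping pattern with plain tf.* names and no while loop:

import tensorflow as tf

def grad(f):
  def grad_fn(x):
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f(x)
    return tape.gradient(y, x)
  return grad_fn

f = tf.cos
x = tf.constant(1.)
for _ in range(3):
  print(f(x).numpy())  # cos(1), then -sin(1), then -cos(1)
  f = grad(f)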
def testDeviceLabelsInherited(self):
def _LoopBody(i, y):
result = math_ops.cos(y)

View File

@@ -26,6 +26,7 @@ from __future__ import print_function
import collections
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
return [None] + outputs
def _run_as_function_for_tape_gradients(make_op, cond_inputs):
"""Fix higher-order tape gradients by wrapping `make_op` in a function."""
# GradientTapes created inside a function currently don't work well with
# un-wrapped control flow ops in that same function. Wrapping in an extra
# layer of intermediate function means we run extra logic in the function
# gradient code to record the correct intermediates on the tape.
#
# The function attribute inputs to cond/case ops are not hashable, so we pass
# everything as a capture to bypass defun's caching.
if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
== gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
# We only need one function between the tape and the cond; if we've
# already wrapped once, we stop wrapping to avoid infinite recursion.
and not (ops.get_default_graph().building_function
and "cond_gradient_wrapper" in ops.get_default_graph().name)):
op = None
def _run_make_and_extract_op():
# Post-processing happens on the cond op, not the function call op.
nonlocal op
tensors = make_op()
op, tensors = _get_op_and_outputs(tensors) # pylint: disable=unused-variable
return tensors
return op, function.defun_with_attributes(
_run_make_and_extract_op,
attributes=dict(func_name="cond_gradient_wrapper"))()
else:
return _get_op_and_outputs(make_op())
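This wrapper exists so that tapes created inside a function can see through a cond op when computing higher-order gradients. A hedged public-API illustration of the scenario it targets, assuming the wrapping above is in place:

import tensorflow as tf

@tf.function
def second_order(x):
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      y = tf.cond(x > 1.0, lambda: x * x, lambda: x)
    dy = inner.gradient(y, x)   # 2x on the true branch
  return dy, outer.gradient(dy, x)

print(second_order(tf.constant(2.0)))  # (4.0, 2.0)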
def _build_cond(pred,
true_graph,
false_graph,
@@ -268,35 +300,28 @@ def _build_cond(pred,
else:
op_fn = gen_functional_ops.stateless_if
def _make_op(inputs):
if_op, tensors = util.get_op_and_outputs(op_fn(
def make_op():
return op_fn(
pred,
inputs, [t.dtype for t in true_graph.outputs],
cond_inputs, [t.dtype for t in true_graph.outputs],
util.create_new_tf_function(true_graph),
util.create_new_tf_function(false_graph),
output_shapes=_get_output_shapes(true_graph.outputs,
false_graph.outputs),
name=name))
_copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
# `if_op` is None if this is a `StatelessIf` op with no outputs.
if if_op is not None:
# The true and false graphs have already been created, and we need that
# to happen before we know which tensors will be captured and so whether
# to wrap the cond in a tf.function. Post-hoc mutation of the branch
# `outer_graph` properties seems like the only option if we want to
# conditionally wrap in a function.
true_graph.outer_graph = ops.get_default_graph()
false_graph.outer_graph = ops.get_default_graph()
if_op._true_graph = true_graph
if_op._false_graph = false_graph
util.maybe_set_lowering_attr(if_op)
util.maybe_propagate_compile_time_consts_in_xla(if_op)
_set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
# Prevent fetching since the variant outputs can't be fetched directly.
if_op.graph.prevent_fetching(if_op)
return tensors
tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)
name=name)
if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
# `if_op` is None if this is a `StatelessIf` op with no outputs.
if if_op is not None:
if_op._true_graph = true_graph
if_op._false_graph = false_graph
util.maybe_set_lowering_attr(if_op)
util.maybe_propagate_compile_time_consts_in_xla(if_op)
_set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
# Prevent fetching since the variant outputs can't be fetched directly.
if_op.graph.prevent_fetching(if_op)
_copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
# Return identities for each output of the If op, rather than the output of
# the If op directly. This makes pruning work if the output of cond() is
# fetched: the lowering pass converts the If outputs into IdentityN outputs,
@@ -693,6 +718,15 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
branch_graph.structured_outputs, branch_graph.outputs)
def _get_op_and_outputs(op_or_outputs):
if isinstance(op_or_outputs, ops.Operation):
return op_or_outputs, []
elif not op_or_outputs: # Empty list.
return None, []
else:
return op_or_outputs[0].op, op_or_outputs
def _pack_sequence_as(structured_outputs, op_outputs):
"""Packs the outputs of the gradient If/Case op.
@@ -1156,23 +1190,24 @@ def _build_case(branch_index,
with ops.control_dependencies(
sum((list(bg.control_captures) for bg in branch_graphs), [])):
def _make_op(inputs):
case_op, tensors = util.get_op_and_outputs(op_fn(
def _make_op():
return op_fn(
branch_index,
inputs, [t.dtype for t in branch_graphs[0].outputs],
case_inputs, [t.dtype for t in branch_graphs[0].outputs],
[util.create_new_tf_function(g) for g in branch_graphs],
output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
name=name))
_copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
if case_op is not None:
util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
util.maybe_propagate_compile_time_consts_in_xla(case_op)
_set_read_only_resource_inputs_attr(case_op, branch_graphs)
# Prevent fetching since the variant outputs can't be fetched directly.
case_op.graph.prevent_fetching(case_op)
return tensors
tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)
name=name)
case_op, tensors = _run_as_function_for_tape_gradients(
_make_op, case_inputs)
if case_op is not None:
util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
util.maybe_propagate_compile_time_consts_in_xla(case_op)
_set_read_only_resource_inputs_attr(case_op, branch_graphs)
# Prevent fetching since the variant outputs can't be fetched directly.
case_op.graph.prevent_fetching(case_op)
_copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
# Return identities for each output of the Case op, rather than the output of
# the Case op directly. This makes pruning work if the output of switch_case()
# is fetched: the lowering pass converts the Case outputs into IdentityN

View File

@@ -28,7 +28,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib
@@ -313,48 +312,3 @@ def get_func_graph(op, input_shapes, func_name):
func_graph = function_def_to_graph.function_def_to_graph(
fdef, input_shapes)
return func_graph
def get_op_and_outputs(op_or_outputs):
if isinstance(op_or_outputs, ops.Operation):
return op_or_outputs, []
elif not op_or_outputs: # Empty list.
return None, []
else:
return op_or_outputs[0].op, op_or_outputs
def run_as_function_for_tape_gradients(make_op, inputs):
"""Fix higher-order tape gradients by wrapping `make_op` in a function.
Args:
make_op: A function that takes a list of inputs and returns a list of output
tensors. This function should set any handle data relevant to its outputs
before returning.
inputs: A list of tensors to check for tape gradients and pass to
`make_op`. These should include all tensors used in `make_op`.
Returns:
Tensors corresponding to `make_op`'s output.
"""
# GradientTapes created inside a function currently don't work well with
# un-wrapped control flow ops in that same function. Wrapping in an extra
# layer of intermediate function means we run extra logic in the function
# gradient code to record the correct intermediates on the tape.
#
# The function attribute inputs to control flow ops are not hashable, so we
# pass everything as a capture to bypass defun's caching.
if (gradients_util.PossibleTapeGradientTypes(inputs)
== gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
# We only need one function between the tape and the op; if we've already
# wrapped once, we stop wrapping to avoid infinite recursion.
and not (ops.get_default_graph().building_function
and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
results = function.defun_with_attributes(
make_op,
autograph=False,
attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
return results
else:
return make_op(inputs)

View File

@@ -429,20 +429,19 @@ def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
else:
op_fn = gen_functional_ops.stateless_while
def _make_op(inputs):
while_op, tensors = util.get_op_and_outputs(op_fn(
inputs,
util.create_new_tf_function(cond_graph),
util.create_new_tf_function(body_graph),
output_shapes=output_shapes,
parallel_iterations=parallel_iterations,
name=name))
_copy_handle_data(body_graph.outputs, tensors)
util.maybe_set_lowering_attr(while_op)
util.maybe_propagate_compile_time_consts_in_xla(while_op)
_set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
return tensors
return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
outputs = op_fn(
loop_vars,
util.create_new_tf_function(cond_graph),
util.create_new_tf_function(body_graph),
output_shapes=output_shapes,
parallel_iterations=parallel_iterations,
name=name)
while_op = outputs[0].op
_copy_handle_data(body_graph.outputs, outputs)
util.maybe_set_lowering_attr(while_op)
util.maybe_propagate_compile_time_consts_in_xla(while_op)
_set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
return outputs
def _get_intermediates(func_graph):
@@ -823,7 +822,7 @@ def _get_accumulator(tensor):
# tf.defun adds an Identity for each output, check whether that is the case.
identity_op = t.consumers()[0]
if (identity_op.type == "Identity" and
any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
identity_op.outputs[0] in tensor.graph.outputs):
return identity_op.outputs[0]
return None
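A side note on the membership check above: `x in seq` falls back to `==`, which classes may overload, while `any(x is t for t in seq)` compares object identity only. A pure-Python illustration with a hypothetical AlwaysEqual class:

class AlwaysEqual:
  def __eq__(self, other):
    return True  # matches anything under `==`

probe = AlwaysEqual()
others = [object(), object()]
print(probe in others)                  # True, via the __eq__ fallback
print(any(probe is t for t in others))  # False, identity comparison only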