Rolls back cl/338294004 and part of cl/338299497
PiperOrigin-RevId: 338384800
Change-Id: I7cf4d230da1a9de064552b464055ee8675aa7fa4
parent 37f7c75a71
commit fb22dff317
@@ -43,7 +43,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -1714,35 +1713,6 @@ class JacobianTest(test.TestCase):
    dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
    self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))

  def test_nested_batch_jacobian_foldl(self):
    def _grad(f):
      def _grad_function(primal):
        with backprop.GradientTape() as tape:
          tape.watch(primal)
          primal_out = f(primal)
        return tape.batch_jacobian(primal_out, primal)
      return _grad_function

    def _func(x):
      return array_ops.reshape(
          functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
                                  array_ops.transpose(x)),
          [1, 1])

    f = _func
    x = constant_op.constant([[1., 2.]])
    for _ in range(2):
      theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
      self.assertAllClose(theoretical, numerical, rtol=1e-3)
      f = _grad(f)
      expected_flat = array_ops.reshape(numerical, [-1])
      self.assertAllClose(expected_flat,
                          array_ops.reshape(f(x), [-1]),
                          rtol=1e-3)
      self.assertAllClose(expected_flat,
                          array_ops.reshape(def_function.function(f)(x), [-1]),
                          rtol=1e-3)

  @test_util.run_in_graph_and_eager_modes
  def test_indexed_slices(self):
    with backprop.GradientTape(persistent=True) as g:
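For reference, a minimal sketch of the pattern the removed test exercises, written against the public TF 2.x API and not part of this diff: a batch Jacobian taken twice through nested tapes. The foldl over cos(a + b) is replaced by a plain reduction so the snippet stays self-contained.

import tensorflow as tf

x = tf.constant([[1., 2.]])
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    # Stand-in for foldl_v2 over cos(a + b): a per-example scalar output.
    y = tf.reshape(tf.math.cos(tf.reduce_sum(x, axis=1)), [1, 1])
  jac = inner.batch_jacobian(y, x)    # first order, shape (1, 1, 2)
jac2 = outer.batch_jacobian(jac, x)   # second order, shape (1, 1, 2, 2)
print(jac.shape, jac2.shape)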
@@ -19,35 +19,12 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import handle_data_util


def _DTypeFromTensor(tensor):
  """Extract either `tensor.dtype` or the unanimous sub-type of a variant."""
  dtype = tensor.dtype
  if dtype.base_dtype == dtypes.variant:
    # If we know statically that the data a variant points to is non-trainable
    # then the variant itself is non-trainable.
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data # pylint: disable=protected-access
    else:
      handle_data = handle_data_util.get_resource_handle_data(tensor)
    if (handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type):
      first_type = handle_data.shape_and_type[0].dtype
      if all(shape_and_type.dtype == first_type
             for shape_and_type in handle_data.shape_and_type):
        return first_type
  return dtype


def IsTrainable(tensor_or_dtype):
  """Determines whether a tensor or dtype supports infinitesimal changes."""
  if tensor_util.is_tensor(tensor_or_dtype):
    dtype = _DTypeFromTensor(tensor_or_dtype)
    dtype = tensor_or_dtype.dtype
  else:
    dtype = tensor_or_dtype
  dtype = dtypes.as_dtype(dtype)
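The hunk above swaps the variant-aware `_DTypeFromTensor(tensor_or_dtype)` lookup for a plain `tensor_or_dtype.dtype`. A minimal sketch of what `IsTrainable` distinguishes, using the TF-internal modules that appear in this diff (internal paths, so subject to change between versions):

from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops

print(backprop_util.IsTrainable(constant_op.constant(1.0)))  # float32: expect True
print(backprop_util.IsTrainable(constant_op.constant(1)))    # int32: expect False
print(backprop_util.IsTrainable(dtypes.int32))               # plain dtypes also work

# A TensorList handle is a variant tensor. With `_DTypeFromTensor` removed, its
# trainability no longer depends on the element dtype recorded in handle data.
handle = list_ops.empty_tensor_list(element_shape=[], element_dtype=dtypes.float32)
print(backprop_util.IsTrainable(handle))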
@@ -1387,7 +1387,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
      gradients with respect to the inputs.
    """
    outputs = []
    iteration_count = 0
    # First we need to figure out how many side outputs from the forward pass
    # will be required. We do this in a temporary graph to avoid actually
    # running multiple copies of the backward pass (one per _GradientsHelper
@@ -1402,42 +1401,15 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
    # all of the forward op's outputs: symbolic gradients with tf.gradients
    # instead rely on regenerating backward functions when higher-order
    # gradients are requested.
    while (len(outputs) < len(self._func_graph.outputs)
           # It's possible for gradient generation to add new ops to the forward
           # pass. If all of the new outputs are non-trainable, there's no
           # reason to continue.
           and any(backprop_util.IsTrainable(output)
                   for output in self._func_graph.outputs[len(outputs):])):
      iteration_count += 1
      if iteration_count >= 20 and iteration_count % 5 == 0:
        new_op_with_trainable_output = None
        num_new_trainable_outputs = 0
        for output in self._func_graph.outputs[len(outputs):]:
          if backprop_util.IsTrainable(output):
            num_new_trainable_outputs += 1
            new_op_with_trainable_output = output.op
        logging.warning(
            ("Determining side outputs for the function '{}' is taking longer "
             "than expected ({} iterations, typically this converges in 5 or "
             "so). This could indicate that a gradient registration is adding "
             "new ops to the forward pass every time gradients are generated. "
             "{} new trainable output(s) were added this iteration, one from "
             "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
             "an issue in a tf.custom_gradient.")
            .format(
                self._func_graph.name, iteration_count,
                num_new_trainable_outputs, new_op_with_trainable_output))
    while len(outputs) < len(self._func_graph.outputs):
      outputs = list(self._func_graph.outputs)
      self._build_functions_for_outputs(
          outputs, inference_args, input_tangents)

    (forward_function, forward_graph,
     backward_function, output_indices, num_output_tangents) = (
         self._build_functions_for_outputs(
             outputs, inference_args, input_tangents))
    if (len(self._func_graph.outputs) > len(outputs)
        and any(backprop_util.IsTrainable(output)
                for output in self._func_graph.outputs[len(outputs):])):
    if len(self._func_graph.outputs) != len(outputs):
      raise AssertionError(
          ("Unexpectedly added new outputs to the forward function when "
           "building the backward function: {}").format(
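Both versions of the loop interleaved above chase the same fixed point: building the backward function can append new side outputs to the forward graph, so the build is repeated until no new (trainable) outputs appear. A framework-free sketch of that loop, with illustrative names that are not TF APIs:

def build_until_stable(forward_outputs, build_backward, is_trainable):
  """Rebuild the backward pass until no new trainable forward outputs appear."""
  seen = []
  while (len(seen) < len(forward_outputs)
         and any(is_trainable(o) for o in forward_outputs[len(seen):])):
    seen = list(forward_outputs)
    # Building the backward pass may append side outputs to forward_outputs.
    build_backward(seen)
  return seen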
@@ -442,11 +442,6 @@ class FuncGraph(ops.Graph):
      return self._fallback_outer_graph
    return current

  @outer_graph.setter
  def outer_graph(self, new_outer_graph):
    """Sets `outer_graph` to `new_outer_graph`."""
    self._weak_outer_graph = weakref.ref(new_outer_graph)

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]
@@ -42,7 +42,6 @@ from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
@@ -172,57 +171,6 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):

    self.assertAllEqual(fnWithLoop(), 4.0)

  def checkIteratedGradients(self, func):
    with context.eager_mode():

      def _Grad(f):
        def _GradFunction(primal):
          with backprop.GradientTape() as tape:
            tape.watch(primal)
            primal_out = f(primal)
          return tape.gradient(primal_out, primal)
        return _GradFunction

      f = func
      one = constant_op.constant(1.)

      for _ in range(3):
        theoretical, numerical = gradient_checker_v2.compute_gradient(
            def_function.function(f), [one])
        self.assertAllClose(theoretical, numerical, rtol=1e-3)
        f = _Grad(f)
        self.assertAllClose(array_ops.reshape(numerical, []),
                            def_function.function(f)(one),
                            rtol=1e-3)

  def testIteratedGradients(self):

    def _Func(x):
      _, z = while_loop_v2(
          lambda i, _: i < 2,
          lambda i, y: (i + 1, math_ops.cos(y)),
          [0, x])
      return z

    self.checkIteratedGradients(_Func)

  def testIteratedGradientsWithList(self):

    def _Func(x):
      results = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32)

      def _LoopBody(i, y, handle):
        return (i + 1, math_ops.cos(y),
                list_ops.tensor_list_push_back(handle, y))

      _, z, results = while_loop_v2(
          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
          results, dtypes.float32))

    self.checkIteratedGradients(_Func)

  def testDeviceLabelsInherited(self):
    def _LoopBody(i, y):
      result = math_ops.cos(y)
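The removed checkIteratedGradients helper repeatedly wraps a function in a GradientTape and compares against numeric gradients. A rough public-API sketch of the same idea, a second-order gradient through a loop; whether it runs cleanly depends on the TensorFlow build, since support for exactly this is what the rolled-back change touched:

import tensorflow as tf

x = tf.constant(1.)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    # Two iterations of y -> cos(y), i.e. y = cos(cos(x)).
    _, y = tf.while_loop(lambda i, _: i < 2,
                         lambda i, y: (i + 1, tf.math.cos(y)),
                         [0, x])
  dy_dx = inner.gradient(y, x)
d2y_dx2 = outer.gradient(dy_dx, x)
print(float(dy_dx), float(d2y_dx2))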
@@ -26,6 +26,7 @@ from __future__ import print_function
import collections

from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import function
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
  return [None] + outputs


def _run_as_function_for_tape_gradients(make_op, cond_inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function."""
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping in an extra
  # layer of intermediate function means we run extra logic in the function
  # gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to cond/case ops are not hashable, so we pass
  # everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the cond; if we've
      # already wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cond_gradient_wrapper" in ops.get_default_graph().name)):

    op = None
    def _run_make_and_extract_op():
      # Post-processing happens on the cond op, not the function call op.
      nonlocal op
      tensors = make_op()
      op, tensors = _get_op_and_outputs(tensors) # pylint: disable=unused-variable
      return tensors

    return op, function.defun_with_attributes(
        _run_make_and_extract_op,
        attributes=dict(func_name="cond_gradient_wrapper"))()
  else:
    return _get_op_and_outputs(make_op())


def _build_cond(pred,
                true_graph,
                false_graph,
@@ -268,35 +300,28 @@ def _build_cond(pred,
  else:
    op_fn = gen_functional_ops.stateless_if

  def _make_op(inputs):
    if_op, tensors = util.get_op_and_outputs(op_fn(
  def make_op():
    return op_fn(
        pred,
        inputs, [t.dtype for t in true_graph.outputs],
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name))
    _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
    # `if_op` is None if this is a `StatelessIf` op with no outputs.
    if if_op is not None:
      # The true and false graphs have already been created, and we need that
      # to happen before we know which tensors will be captured and so whether
      # to wrap the cond in a tf.function. Post-hoc mutation of the branch
      # `outer_graph` properties seems like the only option if we want to
      # conditionally wrap in a function.
      true_graph.outer_graph = ops.get_default_graph()
      false_graph.outer_graph = ops.get_default_graph()
      if_op._true_graph = true_graph
      if_op._false_graph = false_graph
      util.maybe_set_lowering_attr(if_op)
      util.maybe_propagate_compile_time_consts_in_xla(if_op)
      _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
      # Prevent fetching since the variant outputs can't be fetched directly.
      if_op.graph.prevent_fetching(if_op)
    return tensors
  tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)
        name=name)
  if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)

  # `if_op` is None if this is a `StatelessIf` op with no outputs.
  if if_op is not None:
    if_op._true_graph = true_graph
    if_op._false_graph = false_graph
    util.maybe_set_lowering_attr(if_op)
    util.maybe_propagate_compile_time_consts_in_xla(if_op)
    _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)

  _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
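For context on the wrapper restored above: per the comments, GradientTapes created inside a function do not interact well with un-wrapped control flow ops in that same function, so the cond is built inside an extra defun whenever higher-order tape gradients are possible. A minimal public-API sketch of that situation, illustrative only and not part of the diff:

import tensorflow as tf

@tf.function
def second_derivative(x):
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      # Nested tapes are what makes higher-order gradients "possible", the
      # condition under which the cond gets wrapped in a function.
      y = tf.cond(x > 0., lambda: tf.math.square(x), lambda: -x)
    dy_dx = inner.gradient(y, x)
  return outer.gradient(dy_dx, x)

print(second_derivative(tf.constant(3.)))  # expect 2.0, i.e. d2(x^2)/dx2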
@@ -693,6 +718,15 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
        branch_graph.structured_outputs, branch_graph.outputs)


def _get_op_and_outputs(op_or_outputs):
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs: # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs


def _pack_sequence_as(structured_outputs, op_outputs):
  """Packs the outputs of the gradient If/Case op.
@@ -1156,23 +1190,24 @@ def _build_case(branch_index,
  with ops.control_dependencies(
      sum((list(bg.control_captures) for bg in branch_graphs), [])):

    def _make_op(inputs):
      case_op, tensors = util.get_op_and_outputs(op_fn(
    def _make_op():
      return op_fn(
          branch_index,
          inputs, [t.dtype for t in branch_graphs[0].outputs],
          case_inputs, [t.dtype for t in branch_graphs[0].outputs],
          [util.create_new_tf_function(g) for g in branch_graphs],
          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
          name=name))
      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
      if case_op is not None:
        util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
        util.maybe_propagate_compile_time_consts_in_xla(case_op)
        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
        # Prevent fetching since the variant outputs can't be fetched directly.
        case_op.graph.prevent_fetching(case_op)
      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)
          name=name)
    case_op, tensors = _run_as_function_for_tape_gradients(
        _make_op, case_inputs)

    if case_op is not None:
      util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
      util.maybe_propagate_compile_time_consts_in_xla(case_op)
      _set_read_only_resource_inputs_attr(case_op, branch_graphs)
      # Prevent fetching since the variant outputs can't be fetched directly.
      case_op.graph.prevent_fetching(case_op)

    _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of switch_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
@@ -28,7 +28,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import gradients_util
from tensorflow.python.util import keras_deps
from tensorflow.python.util import tf_contextlib
@@ -313,48 +312,3 @@ def get_func_graph(op, input_shapes, func_name):
    func_graph = function_def_to_graph.function_def_to_graph(
        fdef, input_shapes)
  return func_graph


def get_op_and_outputs(op_or_outputs):
  if isinstance(op_or_outputs, ops.Operation):
    return op_or_outputs, []
  elif not op_or_outputs: # Empty list.
    return None, []
  else:
    return op_or_outputs[0].op, op_or_outputs


def run_as_function_for_tape_gradients(make_op, inputs):
  """Fix higher-order tape gradients by wrapping `make_op` in a function.

  Args:
    make_op: A function that takes a list of inputs and returns a list of output
      tensors. This function should set any handle data relevant to its outputs
      before returning.
    inputs: A list of tensors to check for tape gradients and pass to
      `make_op`. These should include all tensors used in `make_op`.

  Returns:
    Tensors corresponding to `make_op`'s output.
  """
  # GradientTapes created inside a function currently don't work well with
  # un-wrapped control flow ops in that same function. Wrapping in an extra
  # layer of intermediate function means we run extra logic in the function
  # gradient code to record the correct intermediates on the tape.
  #
  # The function attribute inputs to control flow ops are not hashable, so we
  # pass everything as a capture to bypass defun's caching.
  if (gradients_util.PossibleTapeGradientTypes(inputs)
      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
      # We only need one function between the tape and the op; if we've already
      # wrapped once, we stop wrapping to avoid infinite recursion.
      and not (ops.get_default_graph().building_function
               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
    results = function.defun_with_attributes(
        make_op,
        autograph=False,
        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
    return results
  else:
    return make_op(inputs)
@@ -429,20 +429,19 @@ def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
  else:
    op_fn = gen_functional_ops.stateless_while

  def _make_op(inputs):
    while_op, tensors = util.get_op_and_outputs(op_fn(
        inputs,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=output_shapes,
        parallel_iterations=parallel_iterations,
        name=name))
    _copy_handle_data(body_graph.outputs, tensors)
    util.maybe_set_lowering_attr(while_op)
    util.maybe_propagate_compile_time_consts_in_xla(while_op)
    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
    return tensors
  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
  outputs = op_fn(
      loop_vars,
      util.create_new_tf_function(cond_graph),
      util.create_new_tf_function(body_graph),
      output_shapes=output_shapes,
      parallel_iterations=parallel_iterations,
      name=name)
  while_op = outputs[0].op
  _copy_handle_data(body_graph.outputs, outputs)
  util.maybe_set_lowering_attr(while_op)
  util.maybe_propagate_compile_time_consts_in_xla(while_op)
  _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
  return outputs


def _get_intermediates(func_graph):
@@ -823,7 +822,7 @@ def _get_accumulator(tensor):
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
        identity_op.outputs[0] in tensor.graph.outputs):
      return identity_op.outputs[0]
    return None