Rolls back cl/338294004 and part of cl/338299497
PiperOrigin-RevId: 338384800
Change-Id: I7cf4d230da1a9de064552b464055ee8675aa7fa4

commit fb22dff317
parent 37f7c75a71
@@ -43,7 +43,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
@@ -1714,35 +1713,6 @@ class JacobianTest(test.TestCase):
     dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
     self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))
 
-  def test_nested_batch_jacobian_foldl(self):
-    def _grad(f):
-      def _grad_function(primal):
-        with backprop.GradientTape() as tape:
-          tape.watch(primal)
-          primal_out = f(primal)
-        return tape.batch_jacobian(primal_out, primal)
-      return _grad_function
-
-    def _func(x):
-      return array_ops.reshape(
-          functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
-                                  array_ops.transpose(x)),
-          [1, 1])
-
-    f = _func
-    x = constant_op.constant([[1., 2.]])
-    for _ in range(2):
-      theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
-      self.assertAllClose(theoretical, numerical, rtol=1e-3)
-      f = _grad(f)
-      expected_flat = array_ops.reshape(numerical, [-1])
-      self.assertAllClose(expected_flat,
-                          array_ops.reshape(f(x), [-1]),
-                          rtol=1e-3)
-      self.assertAllClose(expected_flat,
-                          array_ops.reshape(def_function.function(f)(x), [-1]),
-                          rtol=1e-3)
-
   @test_util.run_in_graph_and_eager_modes
   def test_indexed_slices(self):
     with backprop.GradientTape(persistent=True) as g:
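Note (illustrative, outside the diff): the removed test exercised nested `tape.batch_jacobian` calls through `functional_ops.foldl_v2`. A minimal sketch of just the nesting pattern, using the public TF 2.x API and omitting the foldl case that motivated the test:

    import tensorflow as tf

    def make_grad_fn(f):
      # Wrap `f` so that calling the wrapper returns the batch Jacobian of `f`.
      def grad_fn(x):
        with tf.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
        return tape.batch_jacobian(y, x)
      return grad_fn

    f = lambda x: tf.reshape(tf.cos(tf.reduce_sum(x, axis=1)), [1, 1])
    x = tf.constant([[1., 2.]])
    print(make_grad_fn(f)(x))                 # first-order batch Jacobian
    print(make_grad_fn(make_grad_fn(f))(x))   # second order, by re-wrapping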
@@ -19,35 +19,12 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import handle_data_util
 
 
-def _DTypeFromTensor(tensor):
-  """Extract either `tensor.dtype` or the unanimous sub-type of a variant."""
-  dtype = tensor.dtype
-  if dtype.base_dtype == dtypes.variant:
-    # If we know statically that the data a variant points to is non-trainable
-    # then the variant itself is non-trainable.
-    if isinstance(tensor, ops.EagerTensor):
-      handle_data = tensor._handle_data  # pylint: disable=protected-access
-    else:
-      handle_data = handle_data_util.get_resource_handle_data(tensor)
-    if (handle_data is not None
-        and handle_data.is_set
-        and handle_data.shape_and_type):
-      first_type = handle_data.shape_and_type[0].dtype
-      if all(shape_and_type.dtype == first_type
-             for shape_and_type in handle_data.shape_and_type):
-        return first_type
-  return dtype
-
-
 def IsTrainable(tensor_or_dtype):
-  """Determines whether a tensor or dtype supports infinitesimal changes."""
   if tensor_util.is_tensor(tensor_or_dtype):
-    dtype = _DTypeFromTensor(tensor_or_dtype)
+    dtype = tensor_or_dtype.dtype
   else:
     dtype = tensor_or_dtype
   dtype = dtypes.as_dtype(dtype)
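Note (illustrative, outside the diff): with `_DTypeFromTensor` removed, trainability is decided from the tensor's surface dtype alone rather than from a variant's handle data. The check itself is roughly a dtype whitelist; a sketch that approximately mirrors it with the public API:

    import tensorflow as tf

    def is_trainable_dtype(tensor_or_dtype):
      # Approximate mirror of IsTrainable: floating, complex, resource, and
      # variant dtypes are treated as differentiable.
      if tf.is_tensor(tensor_or_dtype):
        dtype = tensor_or_dtype.dtype
      else:
        dtype = tf.dtypes.as_dtype(tensor_or_dtype)
      return dtype.base_dtype in (tf.float16, tf.bfloat16, tf.float32,
                                  tf.float64, tf.complex64, tf.complex128,
                                  tf.dtypes.resource, tf.dtypes.variant)

    print(is_trainable_dtype(tf.constant(1.0)))  # True  (float32)
    print(is_trainable_dtype(tf.constant(1)))    # False (int32)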
@@ -1387,7 +1387,6 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
       gradients with respect to the inputs.
     """
     outputs = []
-    iteration_count = 0
     # First we need to figure out how many side outputs from the forward pass
     # will be required. We do this in a temporary graph to avoid actually
     # running multiple copies of the backward pass (one per _GradientsHelper
@@ -1402,42 +1401,15 @@ class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
     # all of the forward op's outputs: symbolic gradients with tf.gradients
     # instead rely on regenerating backward functions when higher-order
     # gradients are requested.
-    while (len(outputs) < len(self._func_graph.outputs)
-           # It's possible for gradient generation to add new ops to the forward
-           # pass. If all of the new outputs are non-trainable, there's no
-           # reason to continue.
-           and any(backprop_util.IsTrainable(output)
-                   for output in self._func_graph.outputs[len(outputs):])):
-      iteration_count += 1
-      if iteration_count >= 20 and iteration_count % 5 == 0:
-        new_op_with_trainable_output = None
-        num_new_trainable_outputs = 0
-        for output in self._func_graph.outputs[len(outputs):]:
-          if backprop_util.IsTrainable(output):
-            num_new_trainable_outputs += 1
-            new_op_with_trainable_output = output.op
-        logging.warning(
-            ("Determining side outputs for the function '{}' is taking longer "
-             "than expected ({} iterations, typically this converges in 5 or "
-             "so). This could indicate that a gradient registration is adding "
-             "new ops to the forward pass every time gradients are generated. "
-             "{} new trainable output(s) were added this iteration, one from "
-             "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
-             "an issue in a tf.custom_gradient.")
-            .format(
-                self._func_graph.name, iteration_count,
-                num_new_trainable_outputs, new_op_with_trainable_output))
+    while len(outputs) < len(self._func_graph.outputs):
       outputs = list(self._func_graph.outputs)
       self._build_functions_for_outputs(
           outputs, inference_args, input_tangents)
-
     (forward_function, forward_graph,
      backward_function, output_indices, num_output_tangents) = (
          self._build_functions_for_outputs(
              outputs, inference_args, input_tangents))
-    if (len(self._func_graph.outputs) > len(outputs)
-        and any(backprop_util.IsTrainable(output)
-                for output in self._func_graph.outputs[len(outputs):])):
+    if len(self._func_graph.outputs) != len(outputs):
       raise AssertionError(
           ("Unexpectedly added new outputs to the forward function when "
            "building the backward function: {}").format(
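Note (illustrative, outside the diff): both versions of the loop above compute a fixed point — rebuild the backward function until doing so stops adding side outputs to the forward graph. A minimal sketch of that pattern with hypothetical names:

    def build_until_stable(func_graph_outputs, build_backward):
      # `func_graph_outputs` is a list that `build_backward` may extend with
      # side outputs needed by the gradient; loop until no new outputs appear.
      # (Hypothetical helper, for illustration only.)
      outputs = []
      while len(outputs) < len(func_graph_outputs):
        outputs = list(func_graph_outputs)
        build_backward(outputs)
      return build_backward(outputs)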
@@ -442,11 +442,6 @@ class FuncGraph(ops.Graph):
       return self._fallback_outer_graph
     return current
 
-  @outer_graph.setter
-  def outer_graph(self, new_outer_graph):
-    """Sets `outer_graph` to `new_outer_graph`."""
-    self._weak_outer_graph = weakref.ref(new_outer_graph)
-
   @property
   def output_types(self):
     return [t.dtype for t in self.outputs]
@@ -42,7 +42,6 @@ from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
@@ -172,57 +171,6 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 
     self.assertAllEqual(fnWithLoop(), 4.0)
 
-  def checkIteratedGradients(self, func):
-    with context.eager_mode():
-
-      def _Grad(f):
-        def _GradFunction(primal):
-          with backprop.GradientTape() as tape:
-            tape.watch(primal)
-            primal_out = f(primal)
-          return tape.gradient(primal_out, primal)
-        return _GradFunction
-
-      f = func
-      one = constant_op.constant(1.)
-
-      for _ in range(3):
-        theoretical, numerical = gradient_checker_v2.compute_gradient(
-            def_function.function(f), [one])
-        self.assertAllClose(theoretical, numerical, rtol=1e-3)
-        f = _Grad(f)
-        self.assertAllClose(array_ops.reshape(numerical, []),
-                            def_function.function(f)(one),
-                            rtol=1e-3)
-
-  def testIteratedGradients(self):
-
-    def _Func(x):
-      _, z = while_loop_v2(
-          lambda i, _: i < 2,
-          lambda i, y: (i + 1, math_ops.cos(y)),
-          [0, x])
-      return z
-
-    self.checkIteratedGradients(_Func)
-
-  def testIteratedGradientsWithList(self):
-
-    def _Func(x):
-      results = list_ops.empty_tensor_list(
-          element_shape=[], element_dtype=dtypes.float32)
-
-      def _LoopBody(i, y, handle):
-        return (i + 1, math_ops.cos(y),
-                list_ops.tensor_list_push_back(handle, y))
-
-      _, z, results = while_loop_v2(
-          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
-      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
-          results, dtypes.float32))
-
-    self.checkIteratedGradients(_Func)
-
   def testDeviceLabelsInherited(self):
     def _LoopBody(i, y):
       result = math_ops.cos(y)
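Note (illustrative, outside the diff): the removed checkIteratedGradients helper repeatedly re-wrapped a while-loop function in a GradientTape gradient. A condensed sketch of that pattern with the public TF 2.x API; only the first-order step is shown, since whether further iterations succeed depends on the higher-order tape support this change rolls back:

    import tensorflow as tf

    def grad_fn(f):
      def g(x):
        with tf.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
        return tape.gradient(y, x)
      return g

    def func(x):
      # cos applied twice via a v2 while loop.
      _, z = tf.while_loop(lambda i, _: i < 2,
                           lambda i, y: (i + 1, tf.cos(y)),
                           [0, x])
      return z

    x = tf.constant(1.)
    print(tf.function(func)(x))            # forward value
    print(tf.function(grad_fn(func))(x))   # first-order gradient through the loop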
@@ -26,6 +26,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python.eager import backprop_util
+from tensorflow.python.eager import function
 from tensorflow.python.framework import auto_control_deps
 from tensorflow.python.framework import auto_control_deps_utils as acd
 from tensorflow.python.framework import constant_op
@@ -192,6 +193,37 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   return [None] + outputs
 
 
+def _run_as_function_for_tape_gradients(make_op, cond_inputs):
+  """Fix higher-order tape gradients by wrapping `make_op` in a function."""
+  # GradientTapes created inside a function currently don't work well with
+  # un-wrapped control flow ops in that same function. Wrapping in an extra
+  # layer of intermediate function means we run extra logic in the function
+  # gradient code to record the correct intermediates on the tape.
+  #
+  # The function attribute inputs to cond/case ops are not hashable, so we pass
+  # everything as a capture to bypass defun's caching.
+  if (gradients_util.PossibleTapeGradientTypes(cond_inputs)
+      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
+      # We only need one function between the tape and the cond; if we've
+      # already wrapped once, we stop wrapping to avoid infinite recursion.
+      and not (ops.get_default_graph().building_function
+               and "cond_gradient_wrapper" in ops.get_default_graph().name)):
+
+    op = None
+    def _run_make_and_extract_op():
+      # Post-processing happens on the cond op, not the function call op.
+      nonlocal op
+      tensors = make_op()
+      op, tensors = _get_op_and_outputs(tensors)  # pylint: disable=unused-variable
+      return tensors
+
+    return op, function.defun_with_attributes(
+        _run_make_and_extract_op,
+        attributes=dict(func_name="cond_gradient_wrapper"))()
+  else:
+    return _get_op_and_outputs(make_op())
+
+
 def _build_cond(pred,
                 true_graph,
                 false_graph,
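Note (illustrative, outside the diff): the wrapper above only matters when a GradientTape needs higher-order gradients of a cond built inside a traced function. In plain eager execution the pattern it protects looks like this (a sketch with the public API; eager tracing makes it work here regardless of the wrapper):

    import tensorflow as tf

    x = tf.constant(2.)
    with tf.GradientTape() as outer:
      outer.watch(x)
      with tf.GradientTape() as inner:
        inner.watch(x)
        y = tf.cond(x > 1., lambda: x * x, lambda: x)
      dy = inner.gradient(y, x)    # 2*x = 4.0 on the taken branch
    d2y = outer.gradient(dy, x)    # 2.0
    print(dy.numpy(), d2y.numpy())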
@@ -268,25 +300,19 @@ def _build_cond(pred,
   else:
     op_fn = gen_functional_ops.stateless_if
 
-  def _make_op(inputs):
-    if_op, tensors = util.get_op_and_outputs(op_fn(
+  def make_op():
+    return op_fn(
         pred,
-        inputs, [t.dtype for t in true_graph.outputs],
+        cond_inputs, [t.dtype for t in true_graph.outputs],
         util.create_new_tf_function(true_graph),
         util.create_new_tf_function(false_graph),
         output_shapes=_get_output_shapes(true_graph.outputs,
                                          false_graph.outputs),
-        name=name))
-    _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
-    # `if_op` is None if this is a `StatelessIf` op with no outputs.
-    if if_op is not None:
-      # The true and false graphs have already been created, and we need that
-      # to happen before we know which tensors will be captured and so whether
-      # to wrap the cond in a tf.function. Post-hoc mutation of the branch
-      # `outer_graph` properties seems like the only option if we want to
-      # conditionally wrap in a function.
-      true_graph.outer_graph = ops.get_default_graph()
-      false_graph.outer_graph = ops.get_default_graph()
-      if_op._true_graph = true_graph
-      if_op._false_graph = false_graph
-      util.maybe_set_lowering_attr(if_op)
+        name=name)
+  if_op, tensors = _run_as_function_for_tape_gradients(make_op, cond_inputs)
+
+  # `if_op` is None if this is a `StatelessIf` op with no outputs.
+  if if_op is not None:
+    if_op._true_graph = true_graph
+    if_op._false_graph = false_graph
+    util.maybe_set_lowering_attr(if_op)
@@ -294,9 +320,8 @@ def _build_cond(pred,
     _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
     # Prevent fetching since the variant outputs can't be fetched directly.
     if_op.graph.prevent_fetching(if_op)
-      return tensors
-  tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)
 
+  _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
   # Return identities for each output of the If op, rather than the output of
   # the If op directly. This makes pruning work if the output of cond() is
   # fetched: the lowering pass converts the If outputs into IdentityN outputs,
@@ -693,6 +718,15 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
         branch_graph.structured_outputs, branch_graph.outputs)
 
 
+def _get_op_and_outputs(op_or_outputs):
+  if isinstance(op_or_outputs, ops.Operation):
+    return op_or_outputs, []
+  elif not op_or_outputs:  # Empty list.
+    return None, []
+  else:
+    return op_or_outputs[0].op, op_or_outputs
+
+
 def _pack_sequence_as(structured_outputs, op_outputs):
   """Packs the outputs of the gradient If/Case op.
 
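Note (illustrative, outside the diff): the helper normalizes "whatever the op builder returned" into an (op, outputs) pair, covering the stateful-op, empty-output, and normal cases. A usage sketch with the public API:

    import tensorflow as tf

    def get_op_and_outputs(op_or_outputs):
      # Same contract as the private helper above.
      if isinstance(op_or_outputs, tf.Operation):
        return op_or_outputs, []      # stateful op with no outputs
      elif not op_or_outputs:
        return None, []               # empty output list
      else:
        return op_or_outputs[0].op, op_or_outputs  # recover the op from an output

    g = tf.Graph()
    with g.as_default():
      c = tf.constant(1.)
      op, outs = get_op_and_outputs([c])
      assert op is c.op and outs[0] is c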
@@ -1156,23 +1190,24 @@ def _build_case(branch_index,
   with ops.control_dependencies(
       sum((list(bg.control_captures) for bg in branch_graphs), [])):
 
-    def _make_op(inputs):
-      case_op, tensors = util.get_op_and_outputs(op_fn(
+    def _make_op():
+      return op_fn(
           branch_index,
-          inputs, [t.dtype for t in branch_graphs[0].outputs],
+          case_inputs, [t.dtype for t in branch_graphs[0].outputs],
          [util.create_new_tf_function(g) for g in branch_graphs],
          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
-          name=name))
-      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
+          name=name)
+    case_op, tensors = _run_as_function_for_tape_gradients(
+        _make_op, case_inputs)
+
     if case_op is not None:
       util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
       util.maybe_propagate_compile_time_consts_in_xla(case_op)
       _set_read_only_resource_inputs_attr(case_op, branch_graphs)
       # Prevent fetching since the variant outputs can't be fetched directly.
       case_op.graph.prevent_fetching(case_op)
-      return tensors
-    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)
 
+  _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
   # Return identities for each output of the Case op, rather than the output of
   # the Case op directly. This makes pruning work if the output of switch_case()
   # is fetched: the lowering pass converts the Case outputs into IdentityN
@@ -28,7 +28,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework.func_graph import FuncGraph
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_v2_func_graphs
-from tensorflow.python.ops import gradients_util
 from tensorflow.python.util import keras_deps
 from tensorflow.python.util import tf_contextlib
 
@@ -313,48 +312,3 @@ def get_func_graph(op, input_shapes, func_name):
     func_graph = function_def_to_graph.function_def_to_graph(
         fdef, input_shapes)
   return func_graph
-
-
-def get_op_and_outputs(op_or_outputs):
-  if isinstance(op_or_outputs, ops.Operation):
-    return op_or_outputs, []
-  elif not op_or_outputs:  # Empty list.
-    return None, []
-  else:
-    return op_or_outputs[0].op, op_or_outputs
-
-
-def run_as_function_for_tape_gradients(make_op, inputs):
-  """Fix higher-order tape gradients by wrapping `make_op` in a function.
-
-  Args:
-    make_op: A function that takes a list of inputs and returns a list of output
-      tensors. This function should set any handle data relevant to its outputs
-      before returning.
-    inputs: A list of tensors to check for tape gradients and pass to
-      `make_op`. These should include all tensors used in `make_op`.
-
-  Returns:
-    Tensors corresponding to `make_op`'s output.
-  """
-  # GradientTapes created inside a function currently don't work well with
-  # un-wrapped control flow ops in that same function. Wrapping in an extra
-  # layer of intermediate function means we run extra logic in the function
-  # gradient code to record the correct intermediates on the tape.
-  #
-  # The function attribute inputs to control flow ops are not hashable, so we
-  # pass everything as a capture to bypass defun's caching.
-  if (gradients_util.PossibleTapeGradientTypes(inputs)
-      == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
-      # We only need one function between the tape and the op; if we've already
-      # wrapped once, we stop wrapping to avoid infinite recursion.
-      and not (ops.get_default_graph().building_function
-               and "cflow_gradient_wrapper" in ops.get_default_graph().name)):
-    results = function.defun_with_attributes(
-        make_op,
-        autograph=False,
-        attributes=dict(func_name="cflow_gradient_wrapper"))(inputs)
-    return results
-  else:
-    return make_op(inputs)
-
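Note (illustrative, outside the diff): the deleted utility wraps op construction in a traced function so that a surrounding GradientTape records one differentiable function call instead of raw control-flow ops. The real code uses the internal `function.defun_with_attributes`; a rough public-API analogue of the idea:

    import tensorflow as tf

    def make_op(x):
      # Stand-in for an op-building callback; hypothetical example computation.
      return x * tf.cos(x)

    wrapped = tf.function(make_op)  # the tape sees a single function call

    x = tf.constant(1.)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = wrapped(x)
    print(tape.gradient(y, x))  # cos(1) - 1*sin(1)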
@@ -429,20 +429,19 @@ def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
   else:
     op_fn = gen_functional_ops.stateless_while
 
-  def _make_op(inputs):
-    while_op, tensors = util.get_op_and_outputs(op_fn(
-        inputs,
-        util.create_new_tf_function(cond_graph),
-        util.create_new_tf_function(body_graph),
-        output_shapes=output_shapes,
-        parallel_iterations=parallel_iterations,
-        name=name))
-    _copy_handle_data(body_graph.outputs, tensors)
-    util.maybe_set_lowering_attr(while_op)
-    util.maybe_propagate_compile_time_consts_in_xla(while_op)
-    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
-    return tensors
-  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
+  outputs = op_fn(
+      loop_vars,
+      util.create_new_tf_function(cond_graph),
+      util.create_new_tf_function(body_graph),
+      output_shapes=output_shapes,
+      parallel_iterations=parallel_iterations,
+      name=name)
+  while_op = outputs[0].op
+  _copy_handle_data(body_graph.outputs, outputs)
+  util.maybe_set_lowering_attr(while_op)
+  util.maybe_propagate_compile_time_consts_in_xla(while_op)
+  _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
+  return outputs
 
 
 def _get_intermediates(func_graph):
@@ -823,7 +822,7 @@ def _get_accumulator(tensor):
   # tf.defun adds an Identity for each output, check whether that is the case.
   identity_op = t.consumers()[0]
   if (identity_op.type == "Identity" and
-      any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
+      identity_op.outputs[0] in tensor.graph.outputs):
     return identity_op.outputs[0]
   return None
 
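Note (illustrative, outside the diff): the `any(... is ...)` form being removed avoids relying on Tensor `==`, which TensorFlow overloads to build an elementwise comparison rather than an identity test. A quick sketch of the difference:

    import tensorflow as tf

    a = tf.constant([1., 2.])
    b = tf.constant([1., 2.])
    print(a is b)             # False: distinct Tensor objects
    print((a == b).numpy())   # [ True  True ]: `==` is elementwise, not identity
    # `a in [b]` would need bool(a == b), which is ambiguous for multi-element
    # tensors; membership by identity (`any(x is t for t in ...)`) sidesteps that.
    print(any(a is t for t in [a, b]))  # True: `a` itself is in the list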