Support taking gradients of tf.cond and tf.while_loop using LookupTable.
For tf.cond, this required not creating a default zeros output gradient when the output gradient for all branch functions is None. For example, since LookupTable ops are marked non-differentiable, the output gradient with respect to the LookupTable resource tensor is always None; previously we tried to convert that to a zeros tensor, which is not supported. This change also adds support for tf.cond v2 branch functions with no outputs, which is necessary now that we may build gradient If ops with no outputs.

For tf.while_loop, a captured LookupTable resource is also a loop output (because input and output signatures must match), so gradients_util tried to create a default gradient for the LookupTable, which is not supported. gradients_util now checks whether the resource is a differentiable resource before building the default gradient. Hopefully we can avoid this once While has explicit captures.

PiperOrigin-RevId: 277099963
Change-Id: Ib1e87fe42213bd10294d63c6ed4e77859489f1ce
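For illustration only, here is a minimal sketch of the two scenarios this change enables, written against the public tf.* API rather than the internal modules used by the tests in the diff below. It assumes TF 2.x with control flow v2 enabled (the default); the function names cond_loss and while_loss are invented for the example.

import tensorflow as tf

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant([2, 5], dtype=tf.int64),
        tf.constant([-10.0, 1.0], dtype=tf.float32)),
    default_value=-1.0)
beta = tf.Variable(1.0, trainable=True)

@tf.function
def cond_loss():
  # The lookup happens inside a tf.cond branch; the table's resource tensor
  # is captured by the branch function but is non-differentiable.
  return beta * tf.cond(
      tf.constant(True),
      lambda: table.lookup(tf.constant(2, dtype=tf.int64)),
      lambda: tf.constant(0.0))

with tf.GradientTape() as tape:
  loss = cond_loss()
print(tape.gradient(loss, beta))  # -10.0

@tf.function
def while_loss(unused_beta):
  # map_fn builds a tf.while_loop that captures the table's resource handle,
  # which also becomes a loop output because input and output signatures
  # must match.
  return tf.map_fn(
      table.lookup,
      tf.constant([2, 3], dtype=tf.int64),
      dtype=tf.float32)

with tf.GradientTape() as tape:
  loss = while_loss(beta)
# The loss does not depend on beta, so the gradient is None; the point is
# that building the gradient no longer fails on the captured table resource.
print(tape.gradient(loss, beta))  # None

Before this change, both gradient computations failed while trying to construct a default zeros gradient for the table's non-differentiable resource tensor.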
@@ -36,13 +36,7 @@ class ControlFlowTransformer(converter.Base):
   def _create_cond_branch(self, body_name, aliased_orig_names,
                           aliased_new_names, body, returns):
-    if not returns:
-      # TODO(b/110167197): Replace with a plain return.
-      template = """
-        return 1
-      """
-      return_stmt = templates.replace(template)
-    elif len(returns) == 1:
+    if len(returns) == 1:
       template = """
         return retval
       """
@@ -716,13 +716,6 @@ def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias,
           time_major=time_major,
           go_backwards=go_backwards,
           sequence_lengths=sequence_lengths)
-    # Note that mask is a boolean tensor, which doesn't need to do gradient
-    # calculation, when using tf.cond, a default gradient is added for it,
-    # which then cause the backward function to have a signature mismatch.
-    # Force the mask to not generate gradient to allow implementation_selector
-    # to work properly.
-    # TODO(b/80444525): Remove the stop_gradient().
-    mask = array_ops.stop_gradient(mask)

     def input_right_padded():
       return cudnn_gru(
@@ -1467,13 +1460,6 @@ def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
           time_major=time_major,
           go_backwards=go_backwards,
           sequence_lengths=sequence_lengths)
-    # Note that mask is a boolean tensor, which doesn't need to do gradient
-    # calculation, when using tf.cond, a default gradient is added for it,
-    # which then cause the backward function to have a signature mismatch.
-    # Force the mask to not generate gradient to allow implementation_selector
-    # to work properly.
-    # TODO(b/80444525): Remove the stop_gradient().
-    mask = array_ops.stop_gradient(mask)

     def input_right_padded():
       return cudnn_lstm(
@@ -26,7 +26,9 @@ from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.data.experimental.ops import counter
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import wrap_function
 from tensorflow.python.framework import constant_op
@@ -36,6 +38,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import variables
@@ -397,6 +400,53 @@ class StaticHashTableTest(BaseLookupTableTest):
     self.assertAllEqual([10, -1, 5], self.evaluate(result1))
     self.assertAllEqual([10, -1, 5], self.evaluate(result2))

+  @test_util.enable_control_flow_v2
+  def testLookupTableInWhileV2(self):
+    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
+        constant_op.constant([2, 5], dtype=dtypes.int64),
+        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
+
+    beta = variables.Variable(1.0, trainable=True)
+
+    @def_function.function
+    def get_loss(unused_beta):
+      return map_fn.map_fn(
+          lookup.lookup,
+          constant_op.constant([2, 3], dtype=dtypes.int64),
+          dtype=dtypes.float32)
+
+    with backprop.GradientTape() as tape:
+      loss = get_loss(beta)
+
+    self.assertIsNone(tape.gradient(loss, beta))
+
+  @test_util.enable_control_flow_v2
+  def testLookupTableInCondV2(self):
+    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
+        constant_op.constant([2, 5], dtype=dtypes.int64),
+        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
+
+    beta = variables.Variable(1.0, trainable=True)
+
+    @def_function.function
+    def get_loss(beta):
+
+      def true_fn():
+        return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64))
+
+      def false_fn():
+        return constant_op.constant(0, dtype=dtypes.float32)
+
+      return beta * control_flow_ops.cond(
+          constant_op.constant(True), true_fn=true_fn, false_fn=false_fn)
+
+    with backprop.GradientTape() as tape:
+      loss = get_loss(beta)
+    grad = tape.gradient(loss, beta)
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(lookup_ops.tables_initializer())
+    self.assertAllEqual(grad, -10.)
+
+
 class KeyValueTensorInitializerTest(BaseLookupTableTest):
@@ -38,7 +38,6 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import nest
@@ -121,6 +120,11 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_graph = _create_grad_func(
       false_graph, grads, util.unique_grad_fn_name(false_graph.name))

+  # Replaces output None grads with zeros if atleast one branch has non-None
+  # grad at that index.
+  _create_zeros_for_none_grads([true_graph, false_graph],
+                               [true_grad_graph, false_grad_graph])
+
   if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
     # Modify 'op' to output the intermediates needed by the grad functions. Note
     # that all needed intermediates are wrapped in optionals. Each optional
@@ -219,8 +223,6 @@ def _build_cond(pred,
   # this modifies true_graph and false_graph.
   cond_inputs = _make_inputs_match([true_graph, false_graph],
                                    [true_inputs, false_inputs])
-  # Save the original number of outputs to return to the caller.
-  num_cond_outputs = len(true_graph.outputs)
   # We do not output intermediates of the gradient If op since this is just
   # for backwards compatibility with existing code.
   if not building_gradient and util.output_all_intermediates():
@@ -270,12 +272,15 @@ def _build_cond(pred,
                                        false_graph.outputs),
       name=name)

-  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
-  if_op = tensors[0].op
-  if_op._true_graph = true_graph
-  if_op._false_graph = false_graph
-  util.maybe_set_lowering_attr(if_op)
-  util.maybe_propagate_compile_time_consts_in_xla(if_op)
+  if_op, tensors = _get_op_and_outputs(tensors)
+  # `if_op` is None if this is a `StatelessIf` op with no outputs.
+  if if_op is not None:
+    if_op._true_graph = true_graph
+    if_op._false_graph = false_graph
+    util.maybe_set_lowering_attr(if_op)
+    util.maybe_propagate_compile_time_consts_in_xla(if_op)
+    # Prevent fetching since the variant outputs can't be fetched directly.
+    if_op.graph.prevent_fetching(if_op)

   # Return identities for each output of the If op, rather than the output of
   # the If op directly. This makes pruning work if the output of cond() is
@@ -287,10 +292,7 @@ def _build_cond(pred,
   # correct output structure
   tensors = [array_ops.identity(t) for t in tensors]

-  # Prevent fetching since the variant outputs can't be fetched directly.
-  if_op.graph.prevent_fetching(if_op)
-  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
-                                            tensors[:num_cond_outputs])
+  return _pack_sequence_as(true_graph.structured_outputs, tensors)


 def get_func_graphs(op):
@@ -368,18 +370,6 @@ def _grad_fn(func_graph, grads):
       ys, func_graph.inputs, grad_ys=grad_ys,
       src_graph=func_graph)

-  # Functions can't return None; replace Nones with zero tensors.
-  # TODO(b/80444525): don't return anything here and make _IfGrad return None if
-  # both branches have zero gradient.
-  for i in range(len(result)):
-    if result[i] is None:
-      if func_graph.inputs[i].dtype == dtypes.resource:
-        result[i] = array_ops.zeros(
-            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]),
-            dtype=default_gradient.get_zeros_dtype(func_graph.inputs[i]))
-      else:
-        result[i] = array_ops.zeros_like(func_graph.inputs[i])
-
   return result
@@ -546,6 +536,34 @@ def _make_inputs_match(branch_graphs, branch_inputs):
   return new_inputs


+def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
+  """Creates zeros for None out grads if atleast one branch has non-None grad.
+
+  Args:
+    forward_graphs: List of forward FuncGraphs.
+    grad_graphs: List of grad FuncGraphs.
+  """
+  assert len(forward_graphs) == len(grad_graphs)
+  branch_outputs = [g.structured_outputs for g in grad_graphs]
+  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
+  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
+  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
+    if (any(t is None for t in branch_outs) and
+        any(t is not None for t in branch_outs)):
+      for branch_index, t in enumerate(branch_outs):
+        if t is None:
+          with grad_graphs[branch_index].as_default():
+            zeros = default_gradient.zeros_like(
+                forward_graphs[branch_index].inputs[output_idx])
+            grad_graphs[branch_index].structured_outputs[output_idx] = zeros
+
+  for grad_graph in grad_graphs:
+    grad_graph.outputs = [
+        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
+        if t is not None
+    ]
+
+
 def _make_output_composite_tensors_match(op_type, branch_graphs):
   """Modifies each branch_graph's outputs to have the same output signature.

@@ -591,7 +609,9 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):

   for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
     branch_graph.structured_outputs = branch_outs
-    branch_graph.outputs = func_graph_module.flatten(branch_outs)
+    branch_graph.outputs = [
+        t for t in func_graph_module.flatten(branch_outs) if t is not None
+    ]


 def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
@@ -646,10 +666,46 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
           branch_graph.outputs[index], dtypes.int64)

   for branch_graph in branch_graphs:
-    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
+    branch_graph.structured_outputs = _pack_sequence_as(
         branch_graph.structured_outputs, branch_graph.outputs)


+def _get_op_and_outputs(op_or_outputs):
+  if isinstance(op_or_outputs, ops.Operation):
+    return op_or_outputs, []
+  elif not op_or_outputs:  # Empty list.
+    return None, []
+  else:
+    return op_or_outputs[0].op, op_or_outputs
+
+
+def _pack_sequence_as(structured_outputs, op_outputs):
+  """Packs the outputs of the gradient If/Case op.
+
+  The branch functions may contain None's in the list of `structured_outputs`.
+  `op_outputs` has those outputs missing. So we need to add those Nones to the
+  list of `op_outputs` and then pack it in the same structure as
+  `structured_outputs`.
+
+  Args:
+    structured_outputs: structured_outputs from one of the branch functions.
+    op_outputs: List of output tensors of the op.
+
+  Returns:
+    `op_outputs` packed like `structured_outputs`.
+  """
+  outputs_with_nones = []
+  counter = 0
+  for output in nest.flatten(structured_outputs, expand_composites=True):
+    if output is None:
+      outputs_with_nones.append(None)
+    else:
+      outputs_with_nones.append(op_outputs[counter])
+      counter += 1
+  return func_graph_module.pack_sequence_as(structured_outputs,
+                                            outputs_with_nones)
+
+
 def _wrap_intermediates(func_graph, intermediates):
   with func_graph.as_default():
     return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
@@ -933,6 +989,9 @@ def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
     branch_grad_graphs.append(
         _create_grad_func(branch_graph, grads,
                           util.unique_grad_fn_name(branch_graph.name)))
+  # Replaces output None grads with zeros if atleast one branch has non-None
+  # grad at that index.
+  _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)

   if any(g.op_needs_rewrite for g in branch_grad_graphs):
     # Modify 'op' to output the intermediates needed by the grad functions. Note
@@ -1033,10 +1092,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
       output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
       name=name)

-  # TODO(b/110167197): this requires Case to have at least 1 output
-  case_op = tensors[0].op
-  util.maybe_set_lowering_attr(case_op)
-  util.maybe_propagate_compile_time_consts_in_xla(case_op)
+  case_op, tensors = _get_op_and_outputs(tensors)
+
+  if case_op is not None:
+    util.maybe_set_lowering_attr(case_op)
+    util.maybe_propagate_compile_time_consts_in_xla(case_op)
+    # Prevent fetching since the variant outputs can't be fetched directly.
+    case_op.graph.prevent_fetching(case_op)

   # Return identities for each output of the Case op, rather than the output of
   # the Case op directly. This makes pruning work if the output of switch_case()
@@ -1048,7 +1110,4 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
   # correct output structure
   tensors = [array_ops.identity(t) for t in tensors]

-  # Prevent fetching since the variant outputs can't be fetched directly.
-  case_op.graph.prevent_fetching(case_op)
-  return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs,
-                                            tensors)
+  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
@@ -63,3 +63,22 @@ def ones_like(t):
     return array_ops.ones(*shape_and_dtype(t))
   else:
     return array_ops.ones_like(t)
+
+
+def supports_default_grad(t):
+  """Whether tensor `t` supports creating a default gradient.
+
+  This function assumes that `t` is of a trainable type.
+
+  Args:
+    t: Tensor
+
+  Returns:
+    Bool
+  """
+  if t.dtype == dtypes.resource:
+    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
+    if (handle_data is None or not handle_data.is_set or
+        len(handle_data.shape_and_type) != 1):
+      return False
+  return True
@@ -653,7 +653,10 @@ def _GradientsHelper(ys,
                 # issue here because of zeros.
                 if loop_state:
                   out_grads[i] = loop_state.ZerosLike(op, i)
-                else:
+                elif default_gradient.supports_default_grad(op.outputs[i]):
+                  # TODO(b/143286622): The supports_default_grad check is needed
+                  # because While op emits non-differentiable resource tensors
+                  # as outputs. Remove this check when that is not the case.
                   out_grads[i] = control_flow_state.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
@@ -497,8 +497,11 @@ def _preprocess_grad(grad, body_graph_output, while_op_output):
   # GradientTape initializes resource and variant grads as None instead of
   # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
   # returning None.
-  if (while_op_output.dtype in (dtypes.resource, dtypes.variant)
-      and grad is None):
+  # TODO(b/143286622): The supports_default_grad check is needed
+  # because While op emits non-differentiable resource tensors
+  # as outputs. Remove this check when that is not the case.
+  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
+      default_gradient.supports_default_grad(while_op_output) and grad is None):
     return _zeros_like(while_op_output)

   return grad