Support taking gradients of tf.cond and tf.while_loop that use a LookupTable.

For tf.cond, this required that we not create a default zeros output grad when the output grad for all branch functions is None. For example, since LookupTable ops are marked non-differentiable, the output gradient with respect to the LookupTable resource tensor is always None; without this change we would try to convert that None to a zeros tensor, which is not supported for the table resource.
This change also adds support for tf.cond v2 branch functions with no outputs, which is necessary now that gradient If ops may have no outputs.
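
As an illustration, here is a minimal sketch of the pattern this enables, adapted from the testLookupTableInCondV2 case added in this change. It uses the public TF 2.x API (tf.lookup.StaticHashTable, tf.GradientTape) instead of the internal test helpers, so treat it as an approximation of the new test rather than a copy of it:

import tensorflow as tf

# Table with illustrative keys/values; the table resource is captured by the
# true branch below.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant([2, 5], dtype=tf.int64),
        tf.constant([-10.0, 1.0])),
    default_value=-1.0)
beta = tf.Variable(1.0, trainable=True)

@tf.function
def get_loss(beta):
  def true_fn():
    # The output grad w.r.t. the captured table resource is always None.
    return table.lookup(tf.constant(2, dtype=tf.int64))
  def false_fn():
    return tf.constant(0.0)
  return beta * tf.cond(tf.constant(True), true_fn, false_fn)

with tf.GradientTape() as tape:
  loss = get_loss(beta)
print(tape.gradient(loss, beta))  # -10.0; without the fix, building a zeros grad for the table resource fails.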

In tf.while_loop, a captured LookupTable resource also appears as a loop output because input and output signatures must match, so gradients_util tries to create a default gradient for the LookupTable resource, which is not supported. gradients_util therefore now checks whether a resource supports a default gradient before building one. Hopefully we can avoid this once While supports explicit captures.
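
A similar sketch for the while_loop case, adapted from the new testLookupTableInWhileV2 test (with control flow v2, tf.map_fn builds a While loop that captures the table resource). Again the public TF 2.x API is assumed:

import tensorflow as tf

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant([2, 5], dtype=tf.int64),
        tf.constant([-10.0, 1.0])),
    default_value=-1.0)
beta = tf.Variable(1.0, trainable=True)

@tf.function
def get_loss(unused_beta):
  # The captured table resource also becomes a While output because the
  # loop's input and output signatures must match.
  return tf.map_fn(table.lookup,
                   tf.constant([2, 3], dtype=tf.int64),
                   dtype=tf.float32)

with tf.GradientTape() as tape:
  loss = get_loss(beta)
# The loss does not depend on beta, and no default grad is built for the
# non-differentiable resource output, so this is None instead of an error.
print(tape.gradient(loss, beta))  # None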

PiperOrigin-RevId: 277099963
Change-Id: Ib1e87fe42213bd10294d63c6ed4e77859489f1ce
Saurabh Saxena 2019-10-28 11:01:39 -07:00 committed by TensorFlower Gardener
parent 2b4c48b2a4
commit 6d7211299d
7 changed files with 173 additions and 59 deletions

View File

@ -36,13 +36,7 @@ class ControlFlowTransformer(converter.Base):
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if not returns:
# TODO(b/110167197): Replace with a plain return.
template = """
return 1
"""
return_stmt = templates.replace(template)
elif len(returns) == 1:
if len(returns) == 1:
template = """
return retval
"""

View File

@ -716,13 +716,6 @@ def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias,
time_major=time_major,
go_backwards=go_backwards,
sequence_lengths=sequence_lengths)
# Note that mask is a boolean tensor, which doesn't need to do gradient
# calculation, when using tf.cond, a default gradient is added for it,
# which then cause the backward function to have a signature mismatch.
# Force the mask to not generate gradient to allow implementation_selector
# to work properly.
# TODO(b/80444525): Remove the stop_gradient().
mask = array_ops.stop_gradient(mask)
def input_right_padded():
return cudnn_gru(
@ -1467,13 +1460,6 @@ def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
time_major=time_major,
go_backwards=go_backwards,
sequence_lengths=sequence_lengths)
# Note that mask is a boolean tensor, which doesn't need to do gradient
# calculation, when using tf.cond, a default gradient is added for it,
# which then cause the backward function to have a signature mismatch.
# Force the mask to not generate gradient to allow implementation_selector
# to work properly.
# TODO(b/80444525): Remove the stop_gradient().
mask = array_ops.stop_gradient(mask)
def input_right_padded():
return cudnn_lstm(

View File

@ -26,7 +26,9 @@ from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
@ -36,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import variables
@ -397,6 +400,53 @@ class StaticHashTableTest(BaseLookupTableTest):
self.assertAllEqual([10, -1, 5], self.evaluate(result1))
self.assertAllEqual([10, -1, 5], self.evaluate(result2))
@test_util.enable_control_flow_v2
def testLookupTableInWhileV2(self):
lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
constant_op.constant([2, 5], dtype=dtypes.int64),
constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
beta = variables.Variable(1.0, trainable=True)
@def_function.function
def get_loss(unused_beta):
return map_fn.map_fn(
lookup.lookup,
constant_op.constant([2, 3], dtype=dtypes.int64),
dtype=dtypes.float32)
with backprop.GradientTape() as tape:
loss = get_loss(beta)
self.assertIsNone(tape.gradient(loss, beta))
@test_util.enable_control_flow_v2
def testLookupTableInCondV2(self):
lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
constant_op.constant([2, 5], dtype=dtypes.int64),
constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
beta = variables.Variable(1.0, trainable=True)
@def_function.function
def get_loss(beta):
def true_fn():
return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64))
def false_fn():
return constant_op.constant(0, dtype=dtypes.float32)
return beta * control_flow_ops.cond(
constant_op.constant(True), true_fn=true_fn, false_fn=false_fn)
with backprop.GradientTape() as tape:
loss = get_loss(beta)
grad = tape.gradient(loss, beta)
self.evaluate(variables.global_variables_initializer())
self.evaluate(lookup_ops.tables_initializer())
self.assertAllEqual(grad, -10.)
class KeyValueTensorInitializerTest(BaseLookupTableTest):

View File

@ -38,7 +38,6 @@ from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
@ -121,6 +120,11 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
false_grad_graph = _create_grad_func(
false_graph, grads, util.unique_grad_fn_name(false_graph.name))
# Replaces output None grads with zeros if at least one branch has non-None
# grad at that index.
_create_zeros_for_none_grads([true_graph, false_graph],
[true_grad_graph, false_grad_graph])
if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
# Modify 'op' to output the intermediates needed by the grad functions. Note
# that all needed intermediates are wrapped in optionals. Each optional
@ -219,8 +223,6 @@ def _build_cond(pred,
# this modifies true_graph and false_graph.
cond_inputs = _make_inputs_match([true_graph, false_graph],
[true_inputs, false_inputs])
# Save the original number of outputs to return to the caller.
num_cond_outputs = len(true_graph.outputs)
# We do not output intermediates of the gradient If op since this is just
# for backwards compatibility with existing code.
if not building_gradient and util.output_all_intermediates():
@ -270,12 +272,15 @@ def _build_cond(pred,
false_graph.outputs),
name=name)
# TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
if_op = tensors[0].op
if_op._true_graph = true_graph
if_op._false_graph = false_graph
util.maybe_set_lowering_attr(if_op)
util.maybe_propagate_compile_time_consts_in_xla(if_op)
if_op, tensors = _get_op_and_outputs(tensors)
# `if_op` is None if this is a `StatelessIf` op with no outputs.
if if_op is not None:
if_op._true_graph = true_graph
if_op._false_graph = false_graph
util.maybe_set_lowering_attr(if_op)
util.maybe_propagate_compile_time_consts_in_xla(if_op)
# Prevent fetching since the variant outputs can't be fetched directly.
if_op.graph.prevent_fetching(if_op)
# Return identities for each output of the If op, rather than the output of
# the If op directly. This makes pruning work if the output of cond() is
@ -287,10 +292,7 @@ def _build_cond(pred,
# correct output structure
tensors = [array_ops.identity(t) for t in tensors]
# Prevent fetching since the variant outputs can't be fetched directly.
if_op.graph.prevent_fetching(if_op)
return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
tensors[:num_cond_outputs])
return _pack_sequence_as(true_graph.structured_outputs, tensors)
def get_func_graphs(op):
@ -368,18 +370,6 @@ def _grad_fn(func_graph, grads):
ys, func_graph.inputs, grad_ys=grad_ys,
src_graph=func_graph)
# Functions can't return None; replace Nones with zero tensors.
# TODO(b/80444525): don't return anything here and make _IfGrad return None if
# both branches have zero gradient.
for i in range(len(result)):
if result[i] is None:
if func_graph.inputs[i].dtype == dtypes.resource:
result[i] = array_ops.zeros(
gen_resource_variable_ops.variable_shape(func_graph.inputs[i]),
dtype=default_gradient.get_zeros_dtype(func_graph.inputs[i]))
else:
result[i] = array_ops.zeros_like(func_graph.inputs[i])
return result
@ -546,6 +536,34 @@ def _make_inputs_match(branch_graphs, branch_inputs):
return new_inputs
def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
"""Creates zeros for None out grads if atleast one branch has non-None grad.
Args:
forward_graphs: List of forward FuncGraphs.
grad_graphs: List of grad FuncGraphs.
"""
assert len(forward_graphs) == len(grad_graphs)
branch_outputs = [g.structured_outputs for g in grad_graphs]
num_outputs_per_branch = [len(outs) for outs in branch_outputs]
assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
if (any(t is None for t in branch_outs) and
any(t is not None for t in branch_outs)):
for branch_index, t in enumerate(branch_outs):
if t is None:
with grad_graphs[branch_index].as_default():
zeros = default_gradient.zeros_like(
forward_graphs[branch_index].inputs[output_idx])
grad_graphs[branch_index].structured_outputs[output_idx] = zeros
for grad_graph in grad_graphs:
grad_graph.outputs = [
t for t in func_graph_module.flatten(grad_graph.structured_outputs)
if t is not None
]
def _make_output_composite_tensors_match(op_type, branch_graphs):
"""Modifies each branch_graph's outputs to have the same output signature.
@ -591,7 +609,9 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
branch_graph.structured_outputs = branch_outs
branch_graph.outputs = func_graph_module.flatten(branch_outs)
branch_graph.outputs = [
t for t in func_graph_module.flatten(branch_outs) if t is not None
]
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
@ -646,10 +666,46 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
branch_graph.outputs[index], dtypes.int64)
for branch_graph in branch_graphs:
branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
branch_graph.structured_outputs = _pack_sequence_as(
branch_graph.structured_outputs, branch_graph.outputs)
def _get_op_and_outputs(op_or_outputs):
if isinstance(op_or_outputs, ops.Operation):
return op_or_outputs, []
elif not op_or_outputs: # Empty list.
return None, []
else:
return op_or_outputs[0].op, op_or_outputs
def _pack_sequence_as(structured_outputs, op_outputs):
"""Packs the outputs of the gradient If/Case op.
The branch functions may contain None's in the list of `structured_outputs`.
`op_outputs` has those outputs missing. So we need to add those Nones to the
list of `op_outputs` and then pack it in the same structure as
`structured_outputs`.
Args:
structured_outputs: structured_outputs from one of the branch functions.
op_outputs: List of output tensors of the op.
Returns:
`op_outputs` packed like `structured_outputs`.
"""
outputs_with_nones = []
counter = 0
for output in nest.flatten(structured_outputs, expand_composites=True):
if output is None:
outputs_with_nones.append(None)
else:
outputs_with_nones.append(op_outputs[counter])
counter += 1
return func_graph_module.pack_sequence_as(structured_outputs,
outputs_with_nones)
def _wrap_intermediates(func_graph, intermediates):
with func_graph.as_default():
return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
@ -933,6 +989,9 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
branch_grad_graphs.append(
_create_grad_func(branch_graph, grads,
util.unique_grad_fn_name(branch_graph.name)))
# Replaces output None grads with zeros if at least one branch has non-None
# grad at that index.
_create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)
if any(g.op_needs_rewrite for g in branch_grad_graphs):
# Modify 'op' to output the intermediates needed by the grad functions. Note
@ -1033,10 +1092,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
name=name)
# TODO(b/110167197): this requires Case to have at least 1 output
case_op = tensors[0].op
util.maybe_set_lowering_attr(case_op)
util.maybe_propagate_compile_time_consts_in_xla(case_op)
case_op, tensors = _get_op_and_outputs(tensors)
if case_op is not None:
util.maybe_set_lowering_attr(case_op)
util.maybe_propagate_compile_time_consts_in_xla(case_op)
# Prevent fetching since the variant outputs can't be fetched directly.
case_op.graph.prevent_fetching(case_op)
# Return identities for each output of the Case op, rather than the output of
# the Case op directly. This makes pruning work if the output of switch_case()
@ -1048,7 +1110,4 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
# correct output structure
tensors = [array_ops.identity(t) for t in tensors]
# Prevent fetching since the variant outputs can't be fetched directly.
case_op.graph.prevent_fetching(case_op)
return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs,
tensors)
return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)

View File

@ -63,3 +63,22 @@ def ones_like(t):
return array_ops.ones(*shape_and_dtype(t))
else:
return array_ops.ones_like(t)
def supports_default_grad(t):
"""Whether tensor `t` supports creating a default gradient.
This function assumes that `t` is of a trainable type.
Args:
t: Tensor
Returns:
Bool
"""
if t.dtype == dtypes.resource:
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
if (handle_data is None or not handle_data.is_set or
len(handle_data.shape_and_type) != 1):
return False
return True

View File

@ -653,7 +653,10 @@ def _GradientsHelper(ys,
# issue here because of zeros.
if loop_state:
out_grads[i] = loop_state.ZerosLike(op, i)
else:
elif default_gradient.supports_default_grad(op.outputs[i]):
# TODO(b/143286622): The supports_default_grad check is needed
# because While op emits non-differentiable resource tensors
# as outputs. Remove this check when that is not the case.
out_grads[i] = control_flow_state.ZerosLikeOutsideLoop(op, i)
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access

View File

@ -497,8 +497,11 @@ def _preprocess_grad(grad, body_graph_output, while_op_output):
# GradientTape initializes resource and variant grads as None instead of
# zeros. Set to zeros so _GradientsHelper computes the gradients instead of
# returning None.
if (while_op_output.dtype in (dtypes.resource, dtypes.variant)
and grad is None):
# TODO(b/143286622): The supports_default_grad check is needed
# because While op emits non-differentiable resource tensors
# as outputs. Remove this check when that is not the case.
if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
default_gradient.supports_default_grad(while_op_output) and grad is None):
return _zeros_like(while_op_output)
return grad