diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index c4b0e14e00b..28fbe63a020 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -36,13 +36,7 @@ class ControlFlowTransformer(converter.Base):
 
   def _create_cond_branch(self, body_name, aliased_orig_names,
                           aliased_new_names, body, returns):
-    if not returns:
-      # TODO(b/110167197): Replace with a plain return.
-      template = """
-        return 1
-      """
-      return_stmt = templates.replace(template)
-    elif len(returns) == 1:
+    if len(returns) == 1:
       template = """
         return retval
       """
diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py
index 49b6dbfec97..a62e3fc8600 100644
--- a/tensorflow/python/keras/layers/recurrent_v2.py
+++ b/tensorflow/python/keras/layers/recurrent_v2.py
@@ -716,13 +716,6 @@ def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias,
           time_major=time_major,
           go_backwards=go_backwards,
           sequence_lengths=sequence_lengths)
-    # Note that mask is a boolean tensor, which doesn't need to do gradient
-    # calculation, when using tf.cond, a default gradient is added for it,
-    # which then cause the backward function to have a signature mismatch.
-    # Force the mask to not generate gradient to allow implementation_selector
-    # to work properly.
-    # TODO(b/80444525): Remove the stop_gradient().
-    mask = array_ops.stop_gradient(mask)
 
     def input_right_padded():
       return cudnn_gru(
@@ -1467,13 +1460,6 @@ def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
           time_major=time_major,
           go_backwards=go_backwards,
           sequence_lengths=sequence_lengths)
-    # Note that mask is a boolean tensor, which doesn't need to do gradient
-    # calculation, when using tf.cond, a default gradient is added for it,
-    # which then cause the backward function to have a signature mismatch.
-    # Force the mask to not generate gradient to allow implementation_selector
-    # to work properly.
-    # TODO(b/80444525): Remove the stop_gradient().
-    mask = array_ops.stop_gradient(mask)
 
     def input_right_padded():
       return cudnn_lstm(
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 888cadd5344..8129a59251b 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -26,7 +26,9 @@ from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.data.experimental.ops import counter
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import wrap_function
 from tensorflow.python.framework import constant_op
@@ -36,6 +38,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import variables
@@ -397,6 +400,53 @@ class StaticHashTableTest(BaseLookupTableTest):
     self.assertAllEqual([10, -1, 5], self.evaluate(result1))
     self.assertAllEqual([10, -1, 5], self.evaluate(result2))
 
+  @test_util.enable_control_flow_v2
+  def testLookupTableInWhileV2(self):
+    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
+        constant_op.constant([2, 5], dtype=dtypes.int64),
+        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
+
+    beta = variables.Variable(1.0, trainable=True)
+
+    @def_function.function
+    def get_loss(unused_beta):
+      return map_fn.map_fn(
+          lookup.lookup,
+          constant_op.constant([2, 3], dtype=dtypes.int64),
+          dtype=dtypes.float32)
+
+    with backprop.GradientTape() as tape:
+      loss = get_loss(beta)
+
+    self.assertIsNone(tape.gradient(loss, beta))
+
+  @test_util.enable_control_flow_v2
+  def testLookupTableInCondV2(self):
+    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
+        constant_op.constant([2, 5], dtype=dtypes.int64),
+        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
+
+    beta = variables.Variable(1.0, trainable=True)
+
+    @def_function.function
+    def get_loss(beta):
+
+      def true_fn():
+        return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64))
+
+      def false_fn():
+        return constant_op.constant(0, dtype=dtypes.float32)
+
+      return beta * control_flow_ops.cond(
+          constant_op.constant(True), true_fn=true_fn, false_fn=false_fn)
+
+    with backprop.GradientTape() as tape:
+      loss = get_loss(beta)
+    grad = tape.gradient(loss, beta)
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(lookup_ops.tables_initializer())
+    self.assertAllEqual(grad, -10.)
+
 
 class KeyValueTensorInitializerTest(BaseLookupTableTest):
 
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index b3d5000798e..6a15a6c6b09 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -38,7 +38,6 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import nest
@@ -121,6 +120,11 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_graph = _create_grad_func(
       false_graph, grads, util.unique_grad_fn_name(false_graph.name))
 
+  # Replaces output None grads with zeros if at least one branch has non-None
+  # grad at that index.
+  _create_zeros_for_none_grads([true_graph, false_graph],
+                               [true_grad_graph, false_grad_graph])
+
   if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
     # Modify 'op' to output the intermediates needed by the grad functions. Note
     # that all needed intermediates are wrapped in optionals. Each optional
@@ -219,8 +223,6 @@ def _build_cond(pred,
   # this modifies true_graph and false_graph.
   cond_inputs = _make_inputs_match([true_graph, false_graph],
                                    [true_inputs, false_inputs])
-  # Save the original number of outputs to return to the caller.
-  num_cond_outputs = len(true_graph.outputs)
   # We do not output intermediates of the gradient If op since this is just
   # for backwards compatibility with existing code.
   if not building_gradient and util.output_all_intermediates():
@@ -270,12 +272,15 @@ def _build_cond(pred,
                                        false_graph.outputs),
       name=name)
 
-  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
-  if_op = tensors[0].op
-  if_op._true_graph = true_graph
-  if_op._false_graph = false_graph
-  util.maybe_set_lowering_attr(if_op)
-  util.maybe_propagate_compile_time_consts_in_xla(if_op)
+  if_op, tensors = _get_op_and_outputs(tensors)
+  # `if_op` is None if this is a `StatelessIf` op with no outputs.
+  if if_op is not None:
+    if_op._true_graph = true_graph
+    if_op._false_graph = false_graph
+    util.maybe_set_lowering_attr(if_op)
+    util.maybe_propagate_compile_time_consts_in_xla(if_op)
+    # Prevent fetching since the variant outputs can't be fetched directly.
+    if_op.graph.prevent_fetching(if_op)
 
   # Return identities for each output of the If op, rather than the output of
   # the If op directly. This makes pruning work if the output of cond() is
@@ -287,10 +292,7 @@ def _build_cond(pred,
   # correct output structure
   tensors = [array_ops.identity(t) for t in tensors]
 
-  # Prevent fetching since the variant outputs can't be fetched directly.
-  if_op.graph.prevent_fetching(if_op)
-  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
-                                            tensors[:num_cond_outputs])
+  return _pack_sequence_as(true_graph.structured_outputs, tensors)
 
 
 def get_func_graphs(op):
@@ -368,18 +370,6 @@ def _grad_fn(func_graph, grads):
       ys, func_graph.inputs, grad_ys=grad_ys,
       src_graph=func_graph)
 
-  # Functions can't return None; replace Nones with zero tensors.
-  # TODO(b/80444525): don't return anything here and make _IfGrad return None if
-  # both branches have zero gradient.
-  for i in range(len(result)):
-    if result[i] is None:
-      if func_graph.inputs[i].dtype == dtypes.resource:
-        result[i] = array_ops.zeros(
-            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]),
-            dtype=default_gradient.get_zeros_dtype(func_graph.inputs[i]))
-      else:
-        result[i] = array_ops.zeros_like(func_graph.inputs[i])
-
   return result
 
 
@@ -546,6 +536,34 @@ def _make_inputs_match(branch_graphs, branch_inputs):
   return new_inputs
 
 
+def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
+  """Creates zeros for None out grads if at least one branch has non-None grad.
+
+  Args:
+    forward_graphs: List of forward FuncGraphs.
+    grad_graphs: List of grad FuncGraphs.
+  """
+  assert len(forward_graphs) == len(grad_graphs)
+  branch_outputs = [g.structured_outputs for g in grad_graphs]
+  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
+  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
+  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
+    if (any(t is None for t in branch_outs) and
+        any(t is not None for t in branch_outs)):
+      for branch_index, t in enumerate(branch_outs):
+        if t is None:
+          with grad_graphs[branch_index].as_default():
+            zeros = default_gradient.zeros_like(
+                forward_graphs[branch_index].inputs[output_idx])
+            grad_graphs[branch_index].structured_outputs[output_idx] = zeros
+
+  for grad_graph in grad_graphs:
+    grad_graph.outputs = [
+        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
+        if t is not None
+    ]
+
+
 def _make_output_composite_tensors_match(op_type, branch_graphs):
   """Modifies each branch_graph's outputs to have the same output signature.
 
@@ -591,7 +609,9 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
 
   for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
     branch_graph.structured_outputs = branch_outs
-    branch_graph.outputs = func_graph_module.flatten(branch_outs)
+    branch_graph.outputs = [
+        t for t in func_graph_module.flatten(branch_outs) if t is not None
+    ]
 
 
 def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
@@ -646,10 +666,46 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
           branch_graph.outputs[index], dtypes.int64)
 
   for branch_graph in branch_graphs:
-    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
+    branch_graph.structured_outputs = _pack_sequence_as(
         branch_graph.structured_outputs, branch_graph.outputs)
 
 
+def _get_op_and_outputs(op_or_outputs):
+  if isinstance(op_or_outputs, ops.Operation):
+    return op_or_outputs, []
+  elif not op_or_outputs:  # Empty list.
+    return None, []
+  else:
+    return op_or_outputs[0].op, op_or_outputs
+
+
+def _pack_sequence_as(structured_outputs, op_outputs):
+  """Packs the outputs of the gradient If/Case op.
+
+  The branch functions may contain None's in the list of `structured_outputs`.
+  `op_outputs` has those outputs missing. So we need to add those Nones to the
+  list of `op_outputs` and then pack it in the same structure as
+  `structured_outputs`.
+
+  Args:
+    structured_outputs: structured_outputs from one of the branch functions.
+    op_outputs: List of output tensors of the op.
+
+  Returns:
+    `op_outputs` packed like `structured_outputs`.
+ """ + outputs_with_nones = [] + counter = 0 + for output in nest.flatten(structured_outputs, expand_composites=True): + if output is None: + outputs_with_nones.append(None) + else: + outputs_with_nones.append(op_outputs[counter]) + counter += 1 + return func_graph_module.pack_sequence_as(structured_outputs, + outputs_with_nones) + + def _wrap_intermediates(func_graph, intermediates): with func_graph.as_default(): return [gen_dataset_ops.optional_from_value([t]) for t in intermediates] @@ -933,6 +989,9 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name branch_grad_graphs.append( _create_grad_func(branch_graph, grads, util.unique_grad_fn_name(branch_graph.name))) + # Replaces output None grads with zeros if atleast one branch has non-None + # grad at that index. + _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs) if any(g.op_needs_rewrite for g in branch_grad_graphs): # Modify 'op' to output the intermediates needed by the grad functions. Note @@ -1033,10 +1092,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), name=name) - # TODO(b/110167197): this requires Case to have at least 1 output - case_op = tensors[0].op - util.maybe_set_lowering_attr(case_op) - util.maybe_propagate_compile_time_consts_in_xla(case_op) + case_op, tensors = _get_op_and_outputs(tensors) + + if case_op is not None: + util.maybe_set_lowering_attr(case_op) + util.maybe_propagate_compile_time_consts_in_xla(case_op) + # Prevent fetching since the variant outputs can't be fetched directly. + case_op.graph.prevent_fetching(case_op) # Return identities for each output of the Case op, rather than the output of # the Case op directly. This makes pruning work if the output of switch_case() @@ -1048,7 +1110,4 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): # correct output structure tensors = [array_ops.identity(t) for t in tensors] - # Prevent fetching since the variant outputs can't be fetched directly. - case_op.graph.prevent_fetching(case_op) - return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs, - tensors) + return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors) diff --git a/tensorflow/python/ops/default_gradient.py b/tensorflow/python/ops/default_gradient.py index 3323370b89b..2c18e9993e3 100644 --- a/tensorflow/python/ops/default_gradient.py +++ b/tensorflow/python/ops/default_gradient.py @@ -63,3 +63,22 @@ def ones_like(t): return array_ops.ones(*shape_and_dtype(t)) else: return array_ops.ones_like(t) + + +def supports_default_grad(t): + """Whether tensor `t` supports creating a default gradient. + + This function assumes that `t` is of a trainable type. + + Args: + t: Tensor + + Returns: + Bool + """ + if t.dtype == dtypes.resource: + handle_data = resource_variable_ops.get_eager_safe_handle_data(t) + if (handle_data is None or not handle_data.is_set or + len(handle_data.shape_and_type) != 1): + return False + return True diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index 4c22da7f8d6..797f106a365 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -653,7 +653,10 @@ def _GradientsHelper(ys, # issue here because of zeros. 
                 if loop_state:
                   out_grads[i] = loop_state.ZerosLike(op, i)
-                else:
+                elif default_gradient.supports_default_grad(op.outputs[i]):
+                  # TODO(b/143286622): The supports_default_grad check is needed
+                  # because While op emits non-differentiable resource tensors
+                  # as outputs. Remove this check when that is not the case.
                   out_grads[i] = control_flow_state.ZerosLikeOutsideLoop(op, i)
         with ops.name_scope(op.name + "_grad"):
           # pylint: disable=protected-access
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index d240885c1a9..d5cad732b8a 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -497,8 +497,11 @@ def _preprocess_grad(grad, body_graph_output, while_op_output):
   # GradientTape initializes resource and variant grads as None instead of
   # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
   # returning None.
-  if (while_op_output.dtype in (dtypes.resource, dtypes.variant)
-      and grad is None):
+  # TODO(b/143286622): The supports_default_grad check is needed
+  # because While op emits non-differentiable resource tensors
+  # as outputs. Remove this check when that is not the case.
+  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
+      default_gradient.supports_default_grad(while_op_output) and grad is None):
     return _zeros_like(while_op_output)
 
   return grad
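
The user-visible behavior that the new testLookupTableInCondV2 exercises can be reproduced with public APIs: a cond branch that only reads a non-differentiable resource (a static hash table) no longer breaks the gradient of the surrounding computation. A minimal sketch, assuming a TensorFlow build that includes this patch; tf.lookup/tf.cond/tf.GradientTape are the public counterparts of the test's internal modules, not what the test itself imports.

import tensorflow as tf

# Public-API analogue of testLookupTableInCondV2: the table lookup is a
# non-differentiable resource read inside one cond branch.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant([2, 5], dtype=tf.int64),
        tf.constant([-10.0, 1.0], dtype=tf.float32)),
    default_value=-1.0)
beta = tf.Variable(1.0, trainable=True)

@tf.function
def loss_fn():
  looked_up = tf.cond(
      tf.constant(True),
      true_fn=lambda: table.lookup(tf.constant(2, dtype=tf.int64)),
      false_fn=lambda: tf.constant(0.0))
  return beta * looked_up

with tf.GradientTape() as tape:
  loss = loss_fn()
print(tape.gradient(loss, beta))  # tf.Tensor(-10.0, shape=(), dtype=float32)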
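The rule behind the new _create_zeros_for_none_grads helper can be stated without FuncGraphs: at each output position, a branch's missing gradient is filled with zeros only when some other branch produced a real gradient there; if every branch returned None, the None is kept so _IfGrad/_CaseGrad can report no gradient at all. A rough pure-Python sketch of that decision rule; the list-of-lists representation and the function name are illustrative, not the helper's actual signature.

from typing import List, Optional

def fill_zeros_for_none(branch_grads: List[List[Optional[float]]]):
  """branch_grads[b][i] is branch b's grad for input i; None means no grad."""
  num_outputs = len(branch_grads[0])
  assert all(len(g) == num_outputs for g in branch_grads)
  for i in range(num_outputs):
    column = [g[i] for g in branch_grads]
    if any(x is not None for x in column) and any(x is None for x in column):
      for g in branch_grads:
        if g[i] is None:
          g[i] = 0.0  # stand-in for default_gradient.zeros_like(...)
  return branch_grads

# Position 0: one branch has a real grad, so the other branch gets zeros.
# Position 1: every branch returned None, so the None is preserved.
print(fill_zeros_for_none([[2.5, None], [None, None]]))
# -> [[2.5, None], [0.0, None]]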
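Because the grad graphs now drop None entries from their flat outputs, the gradient If/Case op emits fewer tensors than structured_outputs has slots, and the new _pack_sequence_as splices the Nones back in before repacking. A standalone approximation with tf.nest; pack_with_nones and the toy inputs below are illustrative names, not part of the patch.

import tensorflow as tf

def pack_with_nones(structured_outputs, op_outputs):
  """Reinserts the Nones of `structured_outputs` into the flat `op_outputs`."""
  flat = tf.nest.flatten(structured_outputs, expand_composites=True)
  it = iter(op_outputs)
  with_nones = [None if t is None else next(it) for t in flat]
  return tf.nest.pack_sequence_as(
      structured_outputs, with_nones, expand_composites=True)

# Two structured grad slots, but the op only emits the single non-None grad.
structured = {"bias": None, "kernel": tf.constant(1.0)}
op_outputs = [tf.constant(3.0)]
print(pack_with_nones(structured, op_outputs))
# -> {'bias': None, 'kernel': <tf.Tensor: ... numpy=3.0>}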