diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index c4b0e14e00b..28fbe63a020 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -36,13 +36,7 @@ class ControlFlowTransformer(converter.Base):
 
   def _create_cond_branch(self, body_name, aliased_orig_names,
                           aliased_new_names, body, returns):
-    if not returns:
-      # TODO(b/110167197): Replace with a plain return.
-      template = """
-        return 1
-      """
-      return_stmt = templates.replace(template)
-    elif len(returns) == 1:
+    if len(returns) == 1:
       template = """
         return retval
       """
diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py
index 49b6dbfec97..a62e3fc8600 100644
--- a/tensorflow/python/keras/layers/recurrent_v2.py
+++ b/tensorflow/python/keras/layers/recurrent_v2.py
@@ -716,13 +716,6 @@ def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias,
           time_major=time_major,
           go_backwards=go_backwards,
           sequence_lengths=sequence_lengths)
-    # Note that mask is a boolean tensor, which doesn't need to do gradient
-    # calculation, when using tf.cond, a default gradient is added for it,
-    # which then cause the backward function to have a signature mismatch.
-    # Force the mask to not generate gradient to allow implementation_selector
-    # to work properly.
-    # TODO(b/80444525): Remove the stop_gradient().
-    mask = array_ops.stop_gradient(mask)
 
     def input_right_padded():
       return cudnn_gru(
@@ -1467,13 +1460,6 @@ def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
           time_major=time_major,
           go_backwards=go_backwards,
           sequence_lengths=sequence_lengths)
-    # Note that mask is a boolean tensor, which doesn't need to do gradient
-    # calculation, when using tf.cond, a default gradient is added for it,
-    # which then cause the backward function to have a signature mismatch.
-    # Force the mask to not generate gradient to allow implementation_selector
-    # to work properly.
-    # TODO(b/80444525): Remove the stop_gradient().
-    mask = array_ops.stop_gradient(mask)
 
     def input_right_padded():
       return cudnn_lstm(
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 888cadd5344..8129a59251b 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -26,7 +26,9 @@ from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.data.experimental.ops import counter
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import wrap_function
 from tensorflow.python.framework import constant_op
@@ -36,6 +38,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import variables
@@ -397,6 +400,53 @@ class StaticHashTableTest(BaseLookupTableTest):
     self.assertAllEqual([10, -1, 5], self.evaluate(result1))
     self.assertAllEqual([10, -1, 5], self.evaluate(result2))
 
+  @test_util.enable_control_flow_v2
+  def testLookupTableInWhileV2(self):
+    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
+        constant_op.constant([2, 5], dtype=dtypes.int64),
+        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
+
+    beta = variables.Variable(1.0, trainable=True)
+
+    @def_function.function
+    def get_loss(unused_beta):
+      return map_fn.map_fn(
+          lookup.lookup,
+          constant_op.constant([2, 3], dtype=dtypes.int64),
+          dtype=dtypes.float32)
+
+    with backprop.GradientTape() as tape:
+      loss = get_loss(beta)
+
+    self.assertIsNone(tape.gradient(loss, beta))
+
+  @test_util.enable_control_flow_v2
+  def testLookupTableInCondV2(self):
+    lookup = self.getHashTable()(lookup_ops.KeyValueTensorInitializer(
+        constant_op.constant([2, 5], dtype=dtypes.int64),
+        constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1)
+
+    beta = variables.Variable(1.0, trainable=True)
+
+    @def_function.function
+    def get_loss(beta):
+
+      def true_fn():
+        return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64))
+
+      def false_fn():
+        return constant_op.constant(0, dtype=dtypes.float32)
+
+      return beta * control_flow_ops.cond(
+          constant_op.constant(True), true_fn=true_fn, false_fn=false_fn)
+
+    with backprop.GradientTape() as tape:
+      loss = get_loss(beta)
+    grad = tape.gradient(loss, beta)
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(lookup_ops.tables_initializer())
+    self.assertAllEqual(grad, -10.)
+
 
 class KeyValueTensorInitializerTest(BaseLookupTableTest):
 
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index b3d5000798e..6a15a6c6b09 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -38,7 +38,6 @@ from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import default_gradient
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import nest
@@ -121,6 +120,11 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_graph = _create_grad_func(
       false_graph, grads, util.unique_grad_fn_name(false_graph.name))
 
+  # Replaces output None grads with zeros if at least one branch has non-None
+  # grad at that index.
+  _create_zeros_for_none_grads([true_graph, false_graph],
+                               [true_grad_graph, false_grad_graph])
+
   if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
     # Modify 'op' to output the intermediates needed by the grad functions. Note
     # that all needed intermediates are wrapped in optionals. Each optional
@@ -219,8 +223,6 @@ def _build_cond(pred,
   # this modifies true_graph and false_graph.
   cond_inputs = _make_inputs_match([true_graph, false_graph],
                                    [true_inputs, false_inputs])
-  # Save the original number of outputs to return to the caller.
-  num_cond_outputs = len(true_graph.outputs)
   # We do not output intermediates of the gradient If op since this is just
   # for backwards compatibility with existing code.
   if not building_gradient and util.output_all_intermediates():
@@ -270,12 +272,15 @@ def _build_cond(pred,
                                        false_graph.outputs),
       name=name)
 
-  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
-  if_op = tensors[0].op
-  if_op._true_graph = true_graph
-  if_op._false_graph = false_graph
-  util.maybe_set_lowering_attr(if_op)
-  util.maybe_propagate_compile_time_consts_in_xla(if_op)
+  if_op, tensors = _get_op_and_outputs(tensors)
+  # `if_op` is None if this is a `StatelessIf` op with no outputs.
+  if if_op is not None:
+    if_op._true_graph = true_graph
+    if_op._false_graph = false_graph
+    util.maybe_set_lowering_attr(if_op)
+    util.maybe_propagate_compile_time_consts_in_xla(if_op)
+    # Prevent fetching since the variant outputs can't be fetched directly.
+    if_op.graph.prevent_fetching(if_op)
 
   # Return identities for each output of the If op, rather than the output of
   # the If op directly. This makes pruning work if the output of cond() is
@@ -287,10 +292,7 @@ def _build_cond(pred,
   # correct output structure
   tensors = [array_ops.identity(t) for t in tensors]
 
-  # Prevent fetching since the variant outputs can't be fetched directly.
-  if_op.graph.prevent_fetching(if_op)
-  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
-                                            tensors[:num_cond_outputs])
+  return _pack_sequence_as(true_graph.structured_outputs, tensors)
 
 
 def get_func_graphs(op):
@@ -368,18 +370,6 @@ def _grad_fn(func_graph, grads):
       ys, func_graph.inputs, grad_ys=grad_ys,
       src_graph=func_graph)
 
-  # Functions can't return None; replace Nones with zero tensors.
-  # TODO(b/80444525): don't return anything here and make _IfGrad return None if
-  # both branches have zero gradient.
-  for i in range(len(result)):
-    if result[i] is None:
-      if func_graph.inputs[i].dtype == dtypes.resource:
-        result[i] = array_ops.zeros(
-            gen_resource_variable_ops.variable_shape(func_graph.inputs[i]),
-            dtype=default_gradient.get_zeros_dtype(func_graph.inputs[i]))
-      else:
-        result[i] = array_ops.zeros_like(func_graph.inputs[i])
-
   return result
 
 
@@ -546,6 +536,34 @@ def _make_inputs_match(branch_graphs, branch_inputs):
   return new_inputs
 
 
+def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
+  """Creates zeros for None out grads if at least one branch has non-None grad.
+
+  Args:
+    forward_graphs: List of forward FuncGraphs.
+    grad_graphs: List of grad FuncGraphs.
+  """
+  assert len(forward_graphs) == len(grad_graphs)
+  branch_outputs = [g.structured_outputs for g in grad_graphs]
+  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
+  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
+  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
+    if (any(t is None for t in branch_outs) and
+        any(t is not None for t in branch_outs)):
+      for branch_index, t in enumerate(branch_outs):
+        if t is None:
+          with grad_graphs[branch_index].as_default():
+            zeros = default_gradient.zeros_like(
+                forward_graphs[branch_index].inputs[output_idx])
+            grad_graphs[branch_index].structured_outputs[output_idx] = zeros
+
+  for grad_graph in grad_graphs:
+    grad_graph.outputs = [
+        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
+        if t is not None
+    ]
+
+
 def _make_output_composite_tensors_match(op_type, branch_graphs):
   """Modifies each branch_graph's outputs to have the same output signature.
 
@@ -591,7 +609,9 @@ def _make_output_composite_tensors_match(op_type, branch_graphs):
 
   for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
     branch_graph.structured_outputs = branch_outs
-    branch_graph.outputs = func_graph_module.flatten(branch_outs)
+    branch_graph.outputs = [
+        t for t in func_graph_module.flatten(branch_outs) if t is not None
+    ]
 
 
 def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
@@ -646,10 +666,46 @@ def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
           branch_graph.outputs[index], dtypes.int64)
 
   for branch_graph in branch_graphs:
-    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
+    branch_graph.structured_outputs = _pack_sequence_as(
         branch_graph.structured_outputs, branch_graph.outputs)
 
 
+def _get_op_and_outputs(op_or_outputs):
+  if isinstance(op_or_outputs, ops.Operation):
+    return op_or_outputs, []
+  elif not op_or_outputs:  # Empty list.
+    return None, []
+  else:
+    return op_or_outputs[0].op, op_or_outputs
+
+
+def _pack_sequence_as(structured_outputs, op_outputs):
+  """Packs the outputs of the gradient If/Case op.
+
+  The branch functions may contain None's in the list of `structured_outputs`.
+  `op_outputs` has those outputs missing. So we need to add those Nones to the
+  list of `op_outputs` and then pack it in the same structure as
+  `structured_outputs`.
+
+  Args:
+    structured_outputs: structured_outputs from one of the branch functions.
+    op_outputs: List of output tensors of the op.
+
+  Returns:
+    `op_outputs` packed like `structured_outputs`.
+ """ + outputs_with_nones = [] + counter = 0 + for output in nest.flatten(structured_outputs, expand_composites=True): + if output is None: + outputs_with_nones.append(None) + else: + outputs_with_nones.append(op_outputs[counter]) + counter += 1 + return func_graph_module.pack_sequence_as(structured_outputs, + outputs_with_nones) + + def _wrap_intermediates(func_graph, intermediates): with func_graph.as_default(): return [gen_dataset_ops.optional_from_value([t]) for t in intermediates] @@ -933,6 +989,9 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name branch_grad_graphs.append( _create_grad_func(branch_graph, grads, util.unique_grad_fn_name(branch_graph.name))) + # Replaces output None grads with zeros if atleast one branch has non-None + # grad at that index. + _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs) if any(g.op_needs_rewrite for g in branch_grad_graphs): # Modify 'op' to output the intermediates needed by the grad functions. Note @@ -1033,10 +1092,13 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), name=name) - # TODO(b/110167197): this requires Case to have at least 1 output - case_op = tensors[0].op - util.maybe_set_lowering_attr(case_op) - util.maybe_propagate_compile_time_consts_in_xla(case_op) + case_op, tensors = _get_op_and_outputs(tensors) + + if case_op is not None: + util.maybe_set_lowering_attr(case_op) + util.maybe_propagate_compile_time_consts_in_xla(case_op) + # Prevent fetching since the variant outputs can't be fetched directly. + case_op.graph.prevent_fetching(case_op) # Return identities for each output of the Case op, rather than the output of # the Case op directly. This makes pruning work if the output of switch_case() @@ -1048,7 +1110,4 @@ def _build_case(branch_index, branch_graphs, branch_inputs, name=None): # correct output structure tensors = [array_ops.identity(t) for t in tensors] - # Prevent fetching since the variant outputs can't be fetched directly. - case_op.graph.prevent_fetching(case_op) - return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs, - tensors) + return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors) diff --git a/tensorflow/python/ops/default_gradient.py b/tensorflow/python/ops/default_gradient.py index 3323370b89b..2c18e9993e3 100644 --- a/tensorflow/python/ops/default_gradient.py +++ b/tensorflow/python/ops/default_gradient.py @@ -63,3 +63,22 @@ def ones_like(t): return array_ops.ones(*shape_and_dtype(t)) else: return array_ops.ones_like(t) + + +def supports_default_grad(t): + """Whether tensor `t` supports creating a default gradient. + + This function assumes that `t` is of a trainable type. + + Args: + t: Tensor + + Returns: + Bool + """ + if t.dtype == dtypes.resource: + handle_data = resource_variable_ops.get_eager_safe_handle_data(t) + if (handle_data is None or not handle_data.is_set or + len(handle_data.shape_and_type) != 1): + return False + return True diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index 4c22da7f8d6..797f106a365 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -653,7 +653,10 @@ def _GradientsHelper(ys, # issue here because of zeros. 
                 if loop_state:
                   out_grads[i] = loop_state.ZerosLike(op, i)
-                else:
+                elif default_gradient.supports_default_grad(op.outputs[i]):
+                  # TODO(b/143286622): The supports_default_grad check is needed
+                  # because While op emits non-differentiable resource tensors
+                  # as outputs. Remove this check when that is not the case.
                   out_grads[i] = control_flow_state.ZerosLikeOutsideLoop(op, i)
         with ops.name_scope(op.name + "_grad"):
           # pylint: disable=protected-access
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index d240885c1a9..d5cad732b8a 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -497,8 +497,11 @@ def _preprocess_grad(grad, body_graph_output, while_op_output):
   # GradientTape initializes resource and variant grads as None instead of
   # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
   # returning None.
-  if (while_op_output.dtype in (dtypes.resource, dtypes.variant)
-      and grad is None):
+  # TODO(b/143286622): The supports_default_grad check is needed
+  # because While op emits non-differentiable resource tensors
+  # as outputs. Remove this check when that is not the case.
+  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
+      default_gradient.supports_default_grad(while_op_output) and grad is None):
     return _zeros_like(while_op_output)
 
   return grad
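
The user-visible behavior that the new testLookupTableInCondV2 exercises can be reproduced with public APIs: a cond branch that only reads a non-differentiable resource (a static hash table) no longer breaks the gradient of the surrounding computation. A minimal sketch, assuming a TensorFlow build that includes this patch; tf.lookup/tf.cond/tf.GradientTape are the public counterparts of the test's internal modules, not what the test itself imports.

import tensorflow as tf

# Public-API analogue of testLookupTableInCondV2: the table lookup is a
# non-differentiable resource read inside one cond branch.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant([2, 5], dtype=tf.int64),
        tf.constant([-10.0, 1.0], dtype=tf.float32)),
    default_value=-1.0)
beta = tf.Variable(1.0, trainable=True)

@tf.function
def loss_fn():
  looked_up = tf.cond(
      tf.constant(True),
      true_fn=lambda: table.lookup(tf.constant(2, dtype=tf.int64)),
      false_fn=lambda: tf.constant(0.0))
  return beta * looked_up

with tf.GradientTape() as tape:
  loss = loss_fn()
print(tape.gradient(loss, beta))  # tf.Tensor(-10.0, shape=(), dtype=float32)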
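The rule behind the new _create_zeros_for_none_grads helper can be stated without FuncGraphs: at each output position, a branch's missing gradient is filled with zeros only when some other branch produced a real gradient there; if every branch returned None, the None is kept so _IfGrad/_CaseGrad can report no gradient at all. A rough pure-Python sketch of that decision rule; the list-of-lists representation and the function name are illustrative, not the helper's actual signature.

from typing import List, Optional

def fill_zeros_for_none(branch_grads: List[List[Optional[float]]]):
  """branch_grads[b][i] is branch b's grad for input i; None means no grad."""
  num_outputs = len(branch_grads[0])
  assert all(len(g) == num_outputs for g in branch_grads)
  for i in range(num_outputs):
    column = [g[i] for g in branch_grads]
    if any(x is not None for x in column) and any(x is None for x in column):
      for g in branch_grads:
        if g[i] is None:
          g[i] = 0.0  # stand-in for default_gradient.zeros_like(...)
  return branch_grads

# Position 0: one branch has a real grad, so the other branch gets zeros.
# Position 1: every branch returned None, so the None is preserved.
print(fill_zeros_for_none([[2.5, None], [None, None]]))
# -> [[2.5, None], [0.0, None]]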
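Because the grad graphs now drop None entries from their flat outputs, the gradient If/Case op emits fewer tensors than structured_outputs has slots, and the new _pack_sequence_as splices the Nones back in before repacking. A standalone approximation with tf.nest; pack_with_nones and the toy inputs below are illustrative names, not part of the patch.

import tensorflow as tf

def pack_with_nones(structured_outputs, op_outputs):
  """Reinserts the Nones of `structured_outputs` into the flat `op_outputs`."""
  flat = tf.nest.flatten(structured_outputs, expand_composites=True)
  it = iter(op_outputs)
  with_nones = [None if t is None else next(it) for t in flat]
  return tf.nest.pack_sequence_as(
      structured_outputs, with_nones, expand_composites=True)

# Two structured grad slots, but the op only emits the single non-None grad.
structured = {"bias": None, "kernel": tf.constant(1.0)}
op_outputs = [tf.constant(3.0)]
print(pack_with_nones(structured, op_outputs))
# -> {'bias': None, 'kernel': <tf.Tensor: ... numpy=3.0>}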