diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index ec780a7c0a1..9cf3bba8dd5 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -118,7 +118,13 @@ py_test(
     name = "control_flow_test",
     srcs = ["control_flow_test.py"],
     python_version = "PY3",
-    srcs_version = "PY2AND3",
+    srcs_version = "PY3",
+    tags = [
+        "no_oss_py2",
+        "no_pip",
+        "no_windows",
+        "nopip",
+    ],
     deps = [
         ":converters",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
index 44ab6dee926..65fb6765fcf 100644
--- a/tensorflow/python/autograph/converters/conditional_expressions.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -18,7 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gast
+
 from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import templates
 
 
@@ -26,19 +29,20 @@ class ConditionalExpressionTransformer(converter.Base):
   """Converts conditional expressions to functional form."""
 
   def visit_IfExp(self, node):
-    return templates.replace_as_expression(
-        '''ag__.if_stmt(
+    template = '''
+        ag__.if_exp(
             test,
             lambda: true_expr,
             lambda: false_expr,
-            lambda: (),
-            lambda _: None,
-            ('<internal expr>',),
-            ())
-        ''',
+            expr_repr)
+    '''
+    expr_repr = parser.unparse(node.test, include_encoding_marker=False).strip()
+    return templates.replace_as_expression(
+        template,
         test=node.test,
         true_expr=node.body,
-        false_expr=node.orelse)
+        false_expr=node.orelse,
+        expr_repr=gast.Constant(expr_repr, kind=None))
 
 
 def transform(node, ctx):
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index a903c43bcfc..673781e47dd 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -23,7 +23,6 @@ import gast
 from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.lang import directives
 from tensorflow.python.autograph.pyct import anno
-from tensorflow.python.autograph.pyct import ast_util
 from tensorflow.python.autograph.pyct import cfg
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import qual_names
@@ -57,114 +56,16 @@ class ControlFlowTransformer(converter.Base):
       fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
       return self.generic_visit(node)
 
-  def _create_cond_branch(self, body_name, aliased_orig_names,
-                          aliased_new_names, body, returns):
-    if len(returns) == 1:
-      template = """
-        return retval
-      """
-      return_stmt = templates.replace(template, retval=returns[0])
-    else:
-      template = """
-        return (retvals,)
-      """
-      return_stmt = templates.replace(template, retvals=returns)
-
-    if aliased_orig_names:
-      alias_declarations = []
-      for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
-        template = """
-          try:
-            aliased_new_name = aliased_orig_name
-          except NameError:
-            aliased_new_name = ag__.Undefined(symbol_name)
-        """
-
-        alias_declarations.extend(
-            templates.replace(
-                template,
-                aliased_new_name=new_name,
-                aliased_orig_name=old_name,
-                symbol_name=gast.Constant(str(old_name), kind=None)))
-
-      template = """
-        def body_name():
-          alias_declarations
-          body
-          return_stmt
-      """
-      return templates.replace(
-          template,
-          alias_declarations=alias_declarations,
-          body_name=body_name,
-          body=body,
-          return_stmt=return_stmt)
-    else:
-      template = """
-        def body_name():
-          body
-          return_stmt
-      """
-      return templates.replace(
-          template, body_name=body_name, body=body, return_stmt=return_stmt)
-
-  def _create_cond_expr(self, results, test, body_name, orelse_name,
-                        state_getter_name, state_setter_name,
-                        basic_symbol_names, composite_symbol_names):
-    if results is not None:
-      template = """
-        results = ag__.if_stmt(test, body_name, orelse_name,
-                               state_getter_name, state_setter_name,
-                               (basic_symbol_names,),
-                               (composite_symbol_names,))
-      """
-      return templates.replace(
-          template,
-          test=test,
-          results=results,
-          body_name=body_name,
-          orelse_name=orelse_name,
-          state_getter_name=state_getter_name,
-          state_setter_name=state_setter_name,
-          basic_symbol_names=basic_symbol_names,
-          composite_symbol_names=composite_symbol_names)
-    else:
-      template = """
-        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
-                     (basic_symbol_names,), (composite_symbol_names,))
-      """
-      return templates.replace(
-          template,
-          test=test,
-          body_name=body_name,
-          orelse_name=orelse_name,
-          getter_name=state_getter_name,
-          setter_name=state_setter_name,
-          basic_symbol_names=basic_symbol_names,
-          composite_symbol_names=composite_symbol_names)
-
-  def _fmt_symbols(self, symbol_set):
-    if not symbol_set:
-      return 'no variables'
-    return ', '.join(map(str, symbol_set))
-
-  def _determine_aliased_symbols(self, scope, node_defined_in):
-    modified_live = scope.modified & node_defined_in
-    # Composite symbols are handled elsewhere, see _create_state_functions
-    return {
-        s for s in modified_live
-        if not s.is_composite() and s not in self.state[_Function].scope.globals
-    }
-
-  def _create_nonlocal_declarations(self, loop_vars):
+  def _create_nonlocal_declarations(self, vars_):
+    vars_ = set(vars_)
     results = []
     global_vars = self.state[_Function].scope.globals
 
     if global_vars:
-      results.append(gast.Global([str(v) for v in global_vars]))
+      results.append(gast.Global([str(v) for v in vars_]))
 
     nonlocal_vars = [
-        v for v in loop_vars if not v.is_composite() and v not in global_vars]
+        v for v in vars_ if not v.is_composite() and v not in global_vars]
     if nonlocal_vars:
       results.append(gast.Nonlocal([str(v) for v in nonlocal_vars]))
 
@@ -176,9 +77,9 @@ class ControlFlowTransformer(converter.Base):
       template = """
         def getter_name():
           return state_vars,
-        def setter_name(loop_vars):
+        def setter_name(vars_):
           nonlocal_declarations
-          state_vars, = loop_vars
+          state_vars, = vars_
       """
       return templates.replace(
           template,
@@ -222,166 +123,34 @@ class ControlFlowTransformer(converter.Base):
           symbol_name=gast.Constant(s.ssf(), kind=None))
     return assignments
 
-  def visit_If(self, node):
-    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
-    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
-    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
-    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
-
-    # Note: this information needs to be extracted before the body conversion
-    # that happens in the call to generic_visit below, because the conversion
-    # generates nodes that lack static analysis annotations.
-    need_alias_in_body = self._determine_aliased_symbols(
-        body_scope, defined_in)
-    need_alias_in_orelse = self._determine_aliased_symbols(
-        orelse_scope, defined_in)
-
-    node = self.generic_visit(node)
-
-    modified_in_cond = body_scope.modified | orelse_scope.modified
-    returned_from_cond = set()
-    composites = set()
-    for s in modified_in_cond:
-      if s in live_out and not s.is_composite():
-        returned_from_cond.add(s)
-      if s.is_composite():
-        # Special treatment for compound objects, always return them.
-        # This allows special handling within the if_stmt itself.
-        # For example, in TensorFlow we need to restore the state of composite
-        # symbols to ensure that only effects from the executed branch are seen.
-        composites.add(s)
-
-    created_in_body = body_scope.modified & returned_from_cond - defined_in
-    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in
-
-    basic_created_in_body = tuple(
-        s for s in created_in_body if not s.is_composite())
-    basic_created_in_orelse = tuple(
-        s for s in created_in_orelse if not s.is_composite())
-
-    # These variables are defined only in a single branch. This is fine in
-    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
-    # to handle these cases specially or throw an Error.
-    possibly_undefined = (set(basic_created_in_body) ^
-                          set(basic_created_in_orelse))
-
-    # Alias the closure variables inside the conditional functions, to allow
-    # the functions access to the respective variables.
-    # We will alias variables independently for body and orelse scope,
-    # because different branches might write different variables.
-    aliased_body_orig_names = tuple(need_alias_in_body)
-    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
-    aliased_body_new_names = tuple(
-        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
-        for s in aliased_body_orig_names)
-    aliased_orelse_new_names = tuple(
-        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
-        for s in aliased_orelse_orig_names)
-
-    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
-    alias_orelse_map = dict(
-        zip(aliased_orelse_orig_names, aliased_orelse_new_names))
-
-    node_body = ast_util.rename_symbols(node.body, alias_body_map)
-    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
-
-    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
-    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
-    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
-    all_referenced = body_scope.referenced | orelse_scope.referenced
-    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
-    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)
-
-    returned_from_cond = tuple(returned_from_cond)
-    composites = tuple(composites)
-
-    if returned_from_cond:
-      if len(returned_from_cond) == 1:
-        cond_results = returned_from_cond[0]
-      else:
-        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
-
-      returned_from_body = tuple(
-          alias_body_map[s] if s in need_alias_in_body else s
-          for s in returned_from_cond)
-      returned_from_orelse = tuple(
-          alias_orelse_map[s] if s in need_alias_in_orelse else s
-          for s in returned_from_cond)
-
-    else:
-      # When the cond would return no value, we leave the cond called without
-      # results. That in turn should trigger the side effect guards. The
-      # branch functions will return a dummy value that ensures cond
-      # actually has some return value as well.
-      cond_results = None
-      # TODO(mdan): Replace with None once side_effect_guards is retired.
-      returned_from_body = (templates.replace_as_expression(
-          'ag__.match_staging_level(1, cond_var_name)',
-          cond_var_name=cond_var_name),)
-      returned_from_orelse = (templates.replace_as_expression(
-          'ag__.match_staging_level(1, cond_var_name)',
-          cond_var_name=cond_var_name),)
-
-    cond_assign = self.create_assignment(cond_var_name, node.test)
-    body_def = self._create_cond_branch(
-        body_name,
-        aliased_orig_names=aliased_body_orig_names,
-        aliased_new_names=aliased_body_new_names,
-        body=node_body,
-        returns=returned_from_body)
-    orelse_def = self._create_cond_branch(
-        orelse_name,
-        aliased_orig_names=aliased_orelse_orig_names,
-        aliased_new_names=aliased_orelse_new_names,
-        body=node_orelse,
-        returns=returned_from_orelse)
-    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
-    composite_defs = self._create_state_functions(
-        composites, [], state_getter_name, state_setter_name)
-
-    basic_symbol_names = tuple(
-        gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
-    composite_symbol_names = tuple(
-        gast.Constant(str(symbol), kind=None) for symbol in composites)
-
-    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
-                                       orelse_name, state_getter_name,
-                                       state_setter_name, basic_symbol_names,
-                                       composite_symbol_names)
-
-    if_ast = (
-        undefined_assigns + composite_defs + body_def + orelse_def +
-        cond_assign + cond_expr)
-    return if_ast
-
-  def _get_basic_loop_vars(self, modified, live_in, live_out):
-    # The loop variables corresponding to simple symbols (e.g. `x`).
-    basic_loop_vars = []
+  def _get_block_basic_vars(self, modified, live_in, live_out):
+    nonlocals = self.state[_Function].scope.nonlocals
+    basic_scope_vars = []
     for s in modified:
       if s.is_composite():
-        # TODO(mdan): Raise an error when this happens for a TF loop.
+        # TODO(mdan): Raise an error when this happens for a TF scope.
         continue
-      # Variables not live into or out of the loop are considered local to the
-      # loop.
-      if s not in live_in and s not in live_out:
-        continue
-      basic_loop_vars.append(s)
-    return frozenset(basic_loop_vars)
+      # Variables not live into or out of the scope are considered local to the
+      # scope.
+      if s in live_in or s in live_out or s in nonlocals:
+        basic_scope_vars.append(s)
+      continue
+    return frozenset(basic_scope_vars)
 
-  def _get_composite_loop_vars(self, modified, live_in):
-    # The loop variables corresponding to composite symbols (e.g. `self.x`).
-    composite_loop_vars = []
+  def _get_block_composite_vars(self, modified, live_in):
+    # The scope variables corresponding to composite symbols (e.g. `self.x`).
+    composite_scope_vars = []
     for s in modified:
       if not s.is_composite():
         continue
-      # Mutations made to objects created inside the loop will appear as writes
+      # Mutations made to objects created inside the scope will appear as writes
       # to composite symbols. Because these mutations appear as modifications
       # made to composite symbols, we check whether the composite's parent is
-      # actually live into the loop.
+      # actually live into the scope.
       # Example:
       #   while cond:
       #     x = Foo()
-      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
+      #     x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.
       #
       # Note that some parents might not be symbols - for example, in x['foo'],
       # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
@@ -390,40 +159,106 @@ class ControlFlowTransformer(converter.Base):
           sss for sss in s.support_set if sss.is_symbol())
       if not all(sss in live_in for sss in support_set_symbols):
         continue
-      composite_loop_vars.append(s)
-    return frozenset(composite_loop_vars)
+      composite_scope_vars.append(s)
+    return frozenset(composite_scope_vars)
 
-  def _get_loop_vars(self, node, modified):
-    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+  def _get_block_vars(self, node, modified):
+    """Determines the variables affected inside a control flow statement."""
     defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
     live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
     live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
-    reserved_symbols = body_scope.referenced
 
-    basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out)
-    composite_loop_vars = self._get_composite_loop_vars(modified, live_in)
-    loop_vars = tuple(basic_loop_vars | composite_loop_vars)
+    basic_scope_vars = self._get_block_basic_vars(
+        modified,
+        live_in,
+        live_out)
+    composite_scope_vars = self._get_block_composite_vars(modified, live_in)
+    scope_vars = tuple(basic_scope_vars | composite_scope_vars)
 
-    # Variable that are used or defined inside the loop, but not defined
-    # before entering the loop. Only simple variables must be defined. The
+    # Variables that are modified inside the scope, but not defined
+    # before entering it. Only simple variables must be defined. The
     # composite ones will be implicitly checked at runtime.
-    undefined_lives = basic_loop_vars - defined_in
+    # This covers loop variables as well as variables that
+    undefined = tuple(v for v in modified - defined_in if not v.is_composite())
 
-    return loop_vars, reserved_symbols, undefined_lives
+    # Variables that are modified inside the scope, and depend on values outside
+    # it.
+    input_only = basic_scope_vars & live_in - live_out
+
+    # Place the outputs first.
+    scope_vars = sorted(scope_vars, key=lambda v: v in input_only)
+    nouts = len(scope_vars) - len(input_only)
+
+    return scope_vars, undefined, nouts
+
+  def visit_If(self, node):
+    node = self.generic_visit(node)
+    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
+
+    cond_vars, undefined, nouts = self._get_block_vars(
+        node, body_scope.modified | orelse_scope.modified)
+
+    undefined_assigns = self._create_undefined_assigns(undefined)
+
+    nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)
+
+    reserved = body_scope.referenced | orelse_scope.referenced
+    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
+    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
+    state_functions = self._create_state_functions(
+        cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)
+
+    orelse_body = node.orelse
+    if not orelse_body:
+      orelse_body = [gast.Pass()]
+
+    template = """
+      state_functions
+      def body_name():
+        nonlocal_declarations
+        body
+      def orelse_name():
+        nonlocal_declarations
+        orelse
+      undefined_assigns
+      ag__.if_stmt(
+        test,
+        body_name,
+        orelse_name,
+        state_getter_name,
+        state_setter_name,
+        (symbol_names,),
+        nouts)
+    """
+    return templates.replace(
+        template,
+        body=node.body,
+        body_name=self.ctx.namer.new_symbol('if_body', reserved),
+        orelse=orelse_body,
+        orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
+        nonlocal_declarations=nonlocal_declarations,
+        nouts=gast.Constant(nouts, kind=None),
+        state_functions=state_functions,
+        state_getter_name=state_getter_name,
+        state_setter_name=state_setter_name,
+        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
+        test=node.test,
+        undefined_assigns=undefined_assigns)
 
   def visit_While(self, node):
     node = self.generic_visit(node)
     body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
 
-    loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
-        node, body_scope.modified)
+    loop_vars, undefined, _ = self._get_block_vars(node, body_scope.modified)
 
-    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
+    undefined_assigns = self._create_undefined_assigns(undefined)
 
     nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
 
-    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
-    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
+    reserved = body_scope.referenced
+    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
+    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
     state_functions = self._create_state_functions(
         loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
 
@@ -448,7 +283,7 @@ class ControlFlowTransformer(converter.Base):
     return templates.replace(
         template,
         body=node.body,
-        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
+        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
         nonlocal_declarations=nonlocal_declarations,
         opts=opts,
         state_functions=state_functions,
@@ -456,7 +291,7 @@ class ControlFlowTransformer(converter.Base):
         state_setter_name=state_setter_name,
         symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
         test=node.test,
-        test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
+        test_name=self.ctx.namer.new_symbol('loop_test', reserved),
         undefined_assigns=undefined_assigns)
 
   def visit_For(self, node):
@@ -464,15 +299,16 @@ class ControlFlowTransformer(converter.Base):
     body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
     iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)
 
-    loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
+    loop_vars, undefined, _ = self._get_block_vars(
         node, body_scope.modified | iter_scope.modified)
 
-    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
+    undefined_assigns = self._create_undefined_assigns(undefined)
 
     nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
 
-    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
-    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
+    reserved = body_scope.referenced | iter_scope.referenced
+    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
+    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
     state_functions = self._create_state_functions(
         loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
 
@@ -484,7 +320,7 @@ class ControlFlowTransformer(converter.Base):
     if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
       extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
       extra_test_name = self.ctx.namer.new_symbol(
-          'extra_test', reserved_symbols)
+          'extra_test', reserved)
       template = """
         def extra_test_name():
           nonlocal_declarations
@@ -502,7 +338,7 @@ class ControlFlowTransformer(converter.Base):
 
     # iterate_arg_name holds a single arg with the iterates, which may be a
     # tuple.
-    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols)
+    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
     template = """
       iterates = iterate_arg_name
     """
@@ -529,7 +365,7 @@ class ControlFlowTransformer(converter.Base):
     return templates.replace(
         template,
         body=node.body,
-        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
+        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
         extra_test_function=extra_test_function,
         extra_test_name=extra_test_name,
         iterate_arg_name=iterate_arg_name,
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 32e86400da6..935e2cec4b8 100644
--- a/tensorflow/python/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -453,6 +454,17 @@ class IfStatementTest(ControlFlowTestBase):
     self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
     self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
 
+  def test_local_remains_local(self):
+
+    def test_fn(n):
+      if n > 0:
+        b = 4
+        n = b + 1
+      return n
+
+    self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
+    self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
+
   def test_no_outputs(self):
 
     def test_fn(n):
@@ -465,6 +477,85 @@ class IfStatementTest(ControlFlowTestBase):
     self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
     self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
 
+  def test_created_outputs(self):
+
+    def test_fn(i):
+      if i == 0:
+        result = i - 1
+      else:
+        result = i + 1
+      return result
+
+    self.assertTransformedResult(test_fn, 0, -1)
+    self.assertTransformedResult(test_fn, 1, 2)
+
+  def test_created_loop_local_outputs(self):
+
+    def test_fn(n, x):
+      for i in n:
+        if i == 0:
+          result = i - 1
+        else:
+          result = i + 1
+        if result > 0:
+          x += 1
+      return x
+
+    self.assertTransformedResult(test_fn, (range(5), 10), 14)
+
+  def test_created_loop_variable(self):
+
+    def test_fn(n, x):
+      for i in n:
+        if i == 0:
+          result = i - 1
+        if i > 0:  # Using the result from previous iteration.
+          if result < 0:
+            x += 1
+      return x
+
+    self.assertTransformedResult(test_fn, (range(5), 10), 14)
+
+  def test_unaffected_global(self):
+
+    def test_fn(i):
+      global g  # pylint:disable=global-variable-undefined
+      if i == 0:
+        g = i - 1
+      return g
+
+    self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3})
+    self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3})
+
+  def test_unaffected_nonlocal(self):
+
+    def test_fn(i):
+      def inner_fn():
+        nonlocal n
+        if i == 0:
+          n = i - 1
+
+      n = 3
+      inner_fn()
+      return n
+
+    self.assertTransformedResult(test_fn, 1, 3)
+    self.assertTransformedResult(test_fn, 0, -1)
+
+  def test_output_defined_in_prior_except(self):
+
+    def test_fn(i):
+      try:
+        raise ValueError()
+      except ValueError:
+        x = 1
+      if i == 0:
+        x = i - 1
+      return x
+
+    self.assertTransformedResult(test_fn, 1, 1)
+    self.assertTransformedResult(test_fn, 0, -1)
+
   def test_unbalanced_multiple_composites(self):
 
     class Foo(object):
diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index 3851c7b44ba..5f644ea525d 100644
--- a/tensorflow/python/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -22,6 +22,7 @@ py_library(
     name = "operators",
     srcs = [
         "__init__.py",
+        "conditional_expressions.py",
         "control_flow.py",
         "control_flow_deprecated_py2.py",
         "data_structures.py",
@@ -62,6 +63,20 @@ py_test(
     ],
 )
 
+py_test(
+    name = "conditional_expressions_test",
+    srcs = ["conditional_expressions_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY3",
+    tags = [
+        "no_oss_py2",
+    ],
+    deps = [
+        ":operators",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
 py_test(
     name = "control_flow_test",
     srcs = ["control_flow_test.py"],
diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index f7f9078107c..8ac4e1d8bb3 100644
--- a/tensorflow/python/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -37,6 +37,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.autograph.operators.conditional_expressions import if_exp
 from tensorflow.python.autograph.operators.control_flow import for_stmt
 from tensorflow.python.autograph.operators.control_flow import if_stmt
 from tensorflow.python.autograph.operators.control_flow import while_stmt
diff --git a/tensorflow/python/autograph/operators/conditional_expressions.py b/tensorflow/python/autograph/operators/conditional_expressions.py
new file mode 100644
index 00000000000..7ea2b249935
--- /dev/null
+++ b/tensorflow/python/autograph/operators/conditional_expressions.py
@@ -0,0 +1,56 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Conditional expressions (e.g. the ternary if statement)."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.autograph.operators import control_flow
+from tensorflow.python.autograph.utils import tensors
+from tensorflow.python.ops import control_flow_ops
+
+
+def if_exp(cond, if_true, if_false, expr_repr):
+  if tensors.is_dense_tensor(cond):
+    return _tf_if_exp(cond, if_true, if_false, expr_repr)
+  else:
+    return _py_if_exp(cond, if_true, if_false)
+
+
+def _tf_if_exp(cond, if_true, if_false, expr_repr):
+  """Overload of if_exp that stages a TF cond."""
+  # TODO(mdan): Use nonlocal once we no longer need to support py2.
+  true_val = []
+  false_val = []
+
+  def true_fn():
+    true_val.append(if_true())
+    if true_val and false_val:
+      control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0])
+    return true_val[0]
+
+  def false_fn():
+    false_val.append(if_false())
+    if true_val and false_val:
+      control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0])
+    return false_val[0]
+
+  return control_flow_ops.cond(cond, true_fn, false_fn)
+
+
+def _py_if_exp(cond, if_true, if_false):
+  return if_true() if cond else if_false()
diff --git a/tensorflow/python/autograph/operators/conditional_expressions_test.py b/tensorflow/python/autograph/operators/conditional_expressions_test.py
new file mode 100644
index 00000000000..3f126116023
--- /dev/null
+++ b/tensorflow/python/autograph/operators/conditional_expressions_test.py
@@ -0,0 +1,66 @@
+# Lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for conditional_expressions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import conditional_expressions
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+def _basic_expr(cond):
+  return conditional_expressions.if_exp(
+      cond,
+      lambda: constant_op.constant(1),
+      lambda: constant_op.constant(2),
+      'cond')
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class IfExpTest(test.TestCase):
+
+  def test_tensor(self):
+    self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(True))), 1)
+    self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(False))), 2)
+
+  def test_tensor_mismatched_type(self):
+    # tf.function required because eager cond degenerates to Python if.
+    @def_function.function
+    def test_fn():
+      conditional_expressions.if_exp(
+          constant_op.constant(True), lambda: 1.0, lambda: 2, 'expr_repr')
+
+    with self.assertRaisesRegexp(
+        TypeError,
+        "'expr_repr' has dtype float32 in the main.*int32 in the else"):
+      test_fn()
+
+  def test_python(self):
+    self.assertEqual(self.evaluate(_basic_expr(True)), 1)
+    self.assertEqual(self.evaluate(_basic_expr(False)), 2)
+    self.assertEqual(
+        conditional_expressions.if_exp(True, lambda: 1, lambda: 2, ''), 1)
+    self.assertEqual(
+        conditional_expressions.if_exp(False, lambda: 1, lambda: 2, ''), 2)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index 592281b0ce2..77db7579ece 100644
--- a/tensorflow/python/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -102,7 +102,7 @@ def _verify_loop_init_vars(values, symbol_names):
   """Ensures that all values in the state are defined when entering a loop."""
   for name, value in zip(symbol_names, values):
     if value is None:
-      raise ValueError('"{}" may not be None before the loop.'.format(name))
+      raise ValueError("'{}' may not be None before the loop.".format(name))
     if isinstance(value, variables.UndefinedReturnValue):
       # Assumption: the loop will only capture the variable which tracks the
       # return value if the loop contained a return statement.
@@ -110,7 +110,7 @@ def _verify_loop_init_vars(values, symbol_names):
       raise ValueError(
           'return statements are not supported within a TensorFlow loop.')
     if isinstance(value, variables.Undefined):
-      raise ValueError('"{}" must be defined before the loop.'.format(name))
+      raise ValueError("'{}' must be defined before the loop.".format(name))
 
 
 def _is_subshape(left, right):
@@ -133,9 +133,9 @@ def _is_subshape(left, right):
 def _verify_single_loop_var(
     name, check_shape, init, entry, exit_, shape_invariant):
   """Verifies whether the initial, entry and exit values are consistent."""
-  assert entry is not None, 'no TF op should set "{}" to None?'.format(name)
+  assert entry is not None, "no TF op should set '{}' to None?".format(name)
   if exit_ is None:
-    raise ValueError('"{}" is None at the end of the iteration.'.format(name))
+    raise ValueError("'{}' is None at the end of the iteration.".format(name))
 
   if isinstance(init, (bool, int, float, str, np.ndarray)):
     init = ops.convert_to_tensor_v2(init)
@@ -158,9 +158,8 @@ def _verify_single_loop_var(
 
   if entry.dtype != exit_.dtype:
     raise TypeError(
-        '"{}" has dtype {} before the loop, but dtype {} after one'
-        ' iteration. TensorFlow control flow requires it stays the'
-        ' same.'.format(
+        "'{}' has dtype {} before the loop, but dtype {} after one"
+        ' iteration'.format(
             name,
             entry.dtype.name,
             exit_.dtype.name,
@@ -171,19 +170,19 @@ def _verify_single_loop_var(
       entry_shape = entry.shape
       if not _is_subshape(exit_shape, entry_shape):
         raise ValueError(
-            '"{}" has shape {} before the loop, but shape {} after one'
+            "'{}' has shape {} before the loop, but shape {} after one"
             ' iteration. Use tf.autograph.experimental.set_loop_options to set'
             ' shape invariants.'.format(name, entry_shape, exit_shape))
     else:
       init_shape = init.shape
       if not _is_subshape(init_shape, shape_invariant):
         raise ValueError(
-            '"{}" has shape {} before the loop, which does not conform with'
+            "'{}' has shape {} before the loop, which does not conform with"
             ' the shape invariant {}.'.format(name, init_shape,
                                               shape_invariant))
       if not _is_subshape(exit_shape, shape_invariant):
         raise ValueError(
-            '"{}" has shape {} after one iteration, which does not conform with'
+            "'{}' has shape {} after one iteration, which does not conform with"
             ' the shape invariant {}.'.format(
                 name, exit_shape, shape_invariant))
 
@@ -216,13 +215,13 @@ def _verify_tf_loop_vars(init_vars,
       nest.assert_same_structure(init, entry, expand_composites=True)
       nest.assert_same_structure(entry, exit_, expand_composites=True)
     except (ValueError, TypeError) as e:
-      raise TypeError('"{}" does not have the same nested structure after one'
+      raise TypeError("'{}' does not have the same nested structure after one"
                       ' iteration.\n\n{}'.format(name, e))
     if invariant is not None:
       try:
         nest.assert_same_structure(init, invariant, expand_composites=False)
       except (ValueError, TypeError) as e:
-        raise TypeError('"{}" does not have the same nested structure as its'
+        raise TypeError("'{}' does not have the same nested structure as its"
                         ' corresponding shape invariant.\n\n{}'.format(name, e))
 
     nest.map_structure(
@@ -230,13 +229,13 @@ def _verify_tf_loop_vars(init_vars,
         entry, exit_, invariant)
 
 
-def _verify_single_cond_var(name, body_var, orelse_var):
+def verify_single_cond_var(name, body_var, orelse_var):
   """Verifies whether body_var and orelse_var are consistent."""
   if body_var is None:
-    raise ValueError('"{}" is None at the end of the TRUE branch.'.format(name))
+    raise ValueError("'{}' is None at the end of the main branch.".format(name))
   if orelse_var is None:
     raise ValueError(
-        '"{}" is None at the end of the FALSE branch.'.format(name))
+        "'{}' is None at the end of the else branch.".format(name))
 
   if isinstance(body_var, (bool, int, float, str, np.ndarray)):
     body_var = ops.convert_to_tensor_v2(body_var)
@@ -255,41 +254,37 @@ def _verify_single_cond_var(name, body_var, orelse_var):
 
   if body_var.dtype != orelse_var.dtype:
     raise TypeError(
-        '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
-        ' branch. TensorFlow control flow requires that they are the'
-        ' same.'.format(name, body_var.dtype.name,
-                        orelse_var.dtype.name))
+        "'{}' has dtype {} in the main branch, but dtype {} in the else"
+        ' branch'.format(name, body_var.dtype.name,
+                         orelse_var.dtype.name))
+
+
+def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
+  """Verifies variables output by a conditional branch for consistency."""
+  for name, var_ in zip(symbol_names, vars_):
+    if isinstance(var_, variables.Undefined):
+      raise ValueError(
+          "'{}' must also be initialized in the {} branch".format(
+              name, branch_name))
+    if isinstance(var_, variables.UndefinedReturnValue):
+      raise ValueError(
+          'the {} branch must also have a return statement.'.format(
+              branch_name))
 
 
 def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
   """Verifies variables manipulated by a conditional for consistency."""
-  basic_body_vars, composite_body_vars = body_vars
-  basic_orelse_vars, composite_orelse_vars = orelse_vars
-  assert isinstance(composite_body_vars, tuple)
-  assert isinstance(composite_orelse_vars, tuple)
-
-  # TODO(kkb): Make this more consistent.
-  # The basic outputs should always be a tuple.
-  if not isinstance(basic_body_vars, tuple):
-    basic_body_vars = (basic_body_vars,)
-  if not isinstance(basic_orelse_vars, tuple):
-    basic_orelse_vars = (basic_orelse_vars,)
-
-  body_vars = basic_body_vars + composite_body_vars
-  orelse_vars = basic_orelse_vars + composite_orelse_vars
-
   named_vars = zip(symbol_names, body_vars, orelse_vars)
+
   for name, body_var, orelse_var in named_vars:
     try:
-      nest.assert_same_structure(
-          body_var, orelse_var, expand_composites=True)
+      nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
     except (ValueError, TypeError) as e:
       raise TypeError(
-          '"{}" does not have the same nested structure in the TRUE and FALSE'
-          ' branches.\n\n{}'.format(name, str(e)))
-
+          "'{}' must have the same nested structure in the main and else"
+          ' branches:\n\n{}'.format(name, str(e)))
     nest.map_structure(
-        functools.partial(_verify_single_cond_var, name), body_var, orelse_var)
+        functools.partial(verify_single_cond_var, name), body_var, orelse_var)
 
 
 def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
@@ -314,12 +309,16 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
   `extra_test`, `body`, `get_state` and `set_state` functions must bind to the
   original `geo_mean` and `arith_mean` symbols, using `nonlocal`.
 
+  The inputs and outputs of the callables representing the loop blocks are not
+  explicit - instead, these functions must use nonlocal/global for side effects.
+  The inputs and outputs are instead controlled by the set_state/get_state
+  functions.
+
   Args:
     iter_: The entity being iterated over.
-    extra_test: Callable with the state as arguments, and boolean return type.
+    extra_test: Callable with boolean return type.
       An additional loop condition.
-    body: Callable with the iterate and the state as arguments, and state as
-      return type. The actual loop body.
+    body: Callable representing the actual loop body.
     get_state: Additional callable which can capture additional state (such as
       the values of composite symbols). This is only useful when staging the
       loop.
@@ -717,11 +716,14 @@ def while_stmt(test, body, get_state, set_state, symbol_names, opts):
   a tuple of entities that represent an actual state, or a list of arguments
   of the corresponding types.
 
+  The inputs and outputs of the callables representing the loop blocks are not
+  explicit - instead, these functions must use nonlocal/global for side effects.
+  The inputs and outputs are instead controlled by the set_state/get_state
+  functions.
+
   Args:
-    test: Callable with the state as arguments, and boolean return type. The
-      loop condition.
-    body: Callable with the state as arguments, and state as return type. The
-      actual loop body.
+    test: Callable with boolean return type. The loop condition.
+    body: Callable representing the actual loop body.
     get_state: Additional callable which can capture additional state (such as
       the values of composite symbols). This is only useful when staging the
       loop.
@@ -894,21 +896,32 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
   set_state(final_loop_vars)
 
 
-def if_stmt(cond,
-            body,
-            orelse,
-            get_state,
-            set_state,
-            basic_symbol_names,
-            composite_symbol_names):
+def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
   """Functional form of an if statement.
 
+  The conditional operates on a state, which includes all symbols whose values
+  are a function of the branch taken.
+
+  For example, given the code below that calculates the abs function:
+
+  ```
+    x = 1
+    if x > 0:
+      x = -x
+  ```
+
+  The state is represented by the variable `x`. The `body, `orelse` and
+  `set_state` functions must bind to the original `x` symbol, using `nonlocal`.
+
+  The inputs and outputs of the callables representing the loop blocks are not
+  explicit - instead, these functions must use nonlocal/global for side effects.
+  The inputs and outputs are instead controlled by the set_state/get_state
+  functions.
+
   Args:
     cond: Boolean.
-    body: Callable with no arguments, and outputs of the positive (if) branch as
-      return type.
-    orelse: Callable with no arguments, and outputs of the negative (else)
-      branch as return type.
+    body: Callable representing the main block of the conditional.
+    orelse: Callable representing the else block of the conditional.
     get_state: Function that returns a tuple containing the values of all
       composite symbols modified within the conditional. This allows access to
       state that branches may mutate through side effects. This function is not
@@ -920,123 +933,63 @@ def if_stmt(cond,
       restore checkpointed values. The single argument a tuple containing values
       for each composite symbol that may be modified in a branch of the
       conditional. The is usually the result of a call to get_state.
-    basic_symbol_names: Tuple containing basic loop var names.
-    composite_symbol_names: Tuple containing composite loop var names.
-
-  Returns:
-    Tuple containing the statement outputs.
+    symbol_names: Tuple containing basic loop var names.
+    nouts: Number of variables output by the statement. Vars which are
+      not outputs will not be passed through staged control flow such as
+      tf.cond. This includes variables that are defined before the conditional,
+      but are not used after it.
   """
   # Note: tf.cond doesn't support SparseTensor.
   if tensors.is_dense_tensor(cond):
-    return tf_if_stmt(cond, body, orelse, get_state, set_state,
-                      basic_symbol_names, composite_symbol_names)
+    _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
   else:
-    return _py_if_stmt(cond, body, orelse)
+    _py_if_stmt(cond, body, orelse)
 
 
-def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
-               composite_symbol_names):
+def _tf_if_stmt(
+    cond, body, orelse, get_state, set_state, symbol_names, nouts):
   """Overload of if_stmt that stages a TF cond."""
-  body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
-  orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
-  body = _isolate_state(body, get_state, set_state)
-  orelse = _isolate_state(orelse, get_state, set_state)
+  if not nouts:
+    prev_get_state, prev_set_state = get_state, set_state
+    # Control flow V1 wants at least one output.
+    get_state = lambda: (0,) + prev_get_state()
+    set_state = lambda v: prev_set_state(v[1:])
+    symbol_names += ('<unused dummy>',)
+    nouts = 1
 
-  # `state` currently includes the values of any composite symbols (e.g. `a.b`)
-  # composites modified by the loop. `final_vars` includes the values of basic
-  # symbols (e.g. `a`) which cannot be passed by reference and must be returned.
-  # See _isolate_state.
-  # TODO(mdan): We should minimize calls to get/set_state.
+  init_vars = get_state()
 
-  body_branch = 0
-  orelse_branch = 1
-  result = [None, None]
+  # TODO(mdan): Use nonlocal once we no longer need to support py2.
+  new_body_vars_ = [None]
+  new_orelse_vars_ = [None]
 
-  def error_checking_body():
-    result[body_branch] = body()
-    if result[orelse_branch] is not None:
-      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
-                           basic_symbol_names + composite_symbol_names)
-    return result[body_branch]
+  def aug_body():
+    set_state(init_vars)
+    body()
+    new_body_vars = get_state()
+    new_body_vars = new_body_vars[:nouts]
+    new_body_vars_[0] = new_body_vars
+    _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
+    if new_orelse_vars_[0] is not None:
+      _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names)
+    return new_body_vars
 
-  def error_checking_orelse():
-    result[orelse_branch] = orelse()
-    if result[body_branch] is not None:
-      _verify_tf_cond_vars(result[body_branch], result[orelse_branch],
-                           basic_symbol_names + composite_symbol_names)
-    return result[orelse_branch]
+  def aug_orelse():
+    set_state(init_vars)
+    orelse()
+    new_orelse_vars = get_state()
+    new_orelse_vars = new_orelse_vars[:nouts]
+    new_orelse_vars_[0] = new_orelse_vars
+    _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
+    if new_body_vars_[0] is not None:
+      _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names)
+    return new_orelse_vars
 
-  final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
-                                                  error_checking_orelse)
+  final_cond_vars = control_flow_ops.cond(
+      cond, aug_body, aug_orelse, strict=True)
+  final_cond_vars = final_cond_vars + init_vars[nouts:]
 
-  set_state(final_state)
-
-  return final_vars
-
-
-def _isolate_state(func, get_state, set_state):
-  """Wraps func to (best-effort) isolate state mutations that func may do.
-
-  The simplest example of state mutation is mutation of variables (via e.g.
-  attributes), or modification of globals.
-
-  This allows us to more safely execute this function without worrying about
-  side effects when the function wasn't normally expected to execute. For
-  example, staging requires that the function is executed ahead of time, and
-  we need to ensure its effects are not observed during normal execution.
-
-  Args:
-    func: () -> Any
-    get_state: () -> Any, returns the current state
-    set_state: (Any) -> None, resets the state to the specified values.
-      Typically the result of an earlier call to `get_state`.
-
-  Returns:
-    Tuple[Any, Any], where the first element is the return value of `func`,
-    and the second is the final state values.
-  """
-
-  def wrapper():
-    init_state = get_state()
-    new_vars = func()
-    # TODO(mdan): These should be copies, lest set_state might affect them.
-    new_state = get_state()
-    set_state(init_state)
-    return new_vars, new_state
-
-  return wrapper
-
-
-def _wrap_disallow_undefs_from_cond(func, branch_name):
-  """Wraps conditional branch to disallow returning undefined symbols."""
-
-  def wrapper():
-    """Calls function and raises an error if undefined symbols are returned."""
-    results = func()
-
-    if isinstance(results, tuple):
-      results_tuple = results
-    else:
-      results_tuple = results,
-
-    for result in results_tuple:
-      if isinstance(result, variables.UndefinedReturnValue):
-        raise ValueError(
-            'A value must also be returned from the {} branch. If a value is '
-            'returned from one branch of a conditional a value must be '
-            'returned from all branches.'.format(branch_name))
-
-    undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
-    if undefined:
-      raise ValueError(
-          'The following symbols must also be initialized in the {} branch: {}.'
-          ' Alternatively, you may initialize them before the if'
-          ' statement.'.format(branch_name,
-                               tuple(s.symbol_name for s in undefined)))
-
-    return results
-
-  return wrapper
+  set_state(final_cond_vars)
 
 
 def _py_if_stmt(cond, body, orelse):
diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
index 1c4407904b2..57288be9a9f 100644
--- a/tensorflow/python/autograph/operators/control_flow_test.py
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -543,21 +543,21 @@ class ForLoopTest(test.TestCase):
     return s
 
   def test_tensor_illegal_input(self):
-    with self.assertRaisesRegex(ValueError, '"s" may not be None'):
+    with self.assertRaisesRegex(ValueError, '\'s\' may not be None'):
       self._basic_loop(None, lambda i, s: s)
-    with self.assertRaisesRegex(ValueError, '"s" must be defined'):
+    with self.assertRaisesRegex(ValueError, '\'s\' must be defined'):
       self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
 
   def test_tensor_none_output(self):
-    with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
+    with self.assertRaisesRegex(ValueError, '\'s\' is None at the end'):
       self._basic_loop(0, lambda i, s: None)
 
   def test_tensor_dtype_change(self):
-    with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'):
+    with self.assertRaisesRegex(TypeError, '\'s\'.* dtype float32 after'):
       self._basic_loop(0, lambda i, s: 1.0)
 
   def test_tensor_shape_change(self):
-    with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'):
+    with self.assertRaisesRegex(ValueError, r'\'s\'.* shape \(1,\) after'):
       self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
 
 
@@ -782,21 +782,21 @@ class WhileLoopTest(test.TestCase):
     return s
 
   def test_tensor_illegal_input(self):
-    with self.assertRaisesRegex(ValueError, '"s" may not be None'):
+    with self.assertRaisesRegex(ValueError, "'s' may not be None"):
       self._basic_loop(None, lambda i, s: s)
-    with self.assertRaisesRegex(ValueError, '"s" must be defined'):
+    with self.assertRaisesRegex(ValueError, "'s' must be defined"):
       self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
 
   def test_tensor_none_output(self):
-    with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
+    with self.assertRaisesRegex(ValueError, "'s' is None at the end"):
       self._basic_loop(0, lambda i, s: None)
 
   def test_tensor_dtype_change(self):
-    with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'):
+    with self.assertRaisesRegex(TypeError, "'s'.* dtype float32 after"):
       self._basic_loop(0, lambda i, s: 1.0)
 
   def test_tensor_shape_change(self):
-    with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'):
+    with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"):
       self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
 
 
@@ -806,29 +806,88 @@ class IfStmtTest(test.TestCase):
   def test_tensor(self):
 
     def test_fn(cond):
-      return control_flow.if_stmt(
+      def body():
+        nonlocal i
+        i = constant_op.constant(1)
+
+      def orelse():
+        nonlocal i
+        i = constant_op.constant(-1)
+
+      def set_state(cond_vars):
+        nonlocal i
+        i, = cond_vars
+
+      i = None
+      control_flow.if_stmt(
           cond=cond,
-          body=lambda: constant_op.constant(1),
-          orelse=lambda: constant_op.constant(-1),
-          get_state=lambda: (),
-          set_state=lambda _: None,
-          basic_symbol_names=('_',),
-          composite_symbol_names=())
+          body=body,
+          orelse=orelse,
+          get_state=lambda: (i,),
+          set_state=set_state,
+          symbol_names=('i',),
+          nouts=1)
+      return i
 
     self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
     self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
 
+  def test_tensor_no_outputs(self):
+
+    def test_fn(cond):
+      def body():
+        nonlocal i
+        i = constant_op.constant(1)
+
+      def orelse():
+        nonlocal i
+        i = constant_op.constant(-1.0)
+
+      def set_state(cond_vars):
+        nonlocal i
+        i, = cond_vars
+
+      i = None
+      control_flow.if_stmt(
+          cond=cond,
+          body=body,
+          orelse=orelse,
+          get_state=lambda: (i,),
+          set_state=set_state,
+          symbol_names=('i',),
+          nouts=0)
+      return i
+
+    self.assertEqual(None, test_fn(constant_op.constant(True)))
+    self.assertEqual(None, test_fn(constant_op.constant(False)))
+
   def test_tensor_multiple_returns(self):
 
     def test_fn(cond):
-      return control_flow.if_stmt(
+      def body():
+        nonlocal i, j
+        i = constant_op.constant(1)
+        j = constant_op.constant(2)
+
+      def orelse():
+        nonlocal i, j
+        i = constant_op.constant(-1)
+        j = constant_op.constant(-2)
+
+      def set_state(cond_vars):
+        nonlocal i, j
+        i, j = cond_vars
+
+      i, j = None, None
+      control_flow.if_stmt(
           cond=cond,
-          body=lambda: (constant_op.constant(1), constant_op.constant(2)),
-          orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
-          get_state=lambda: (),
-          set_state=lambda _: None,
-          basic_symbol_names=('_',),
-          composite_symbol_names=())
+          body=body,
+          orelse=orelse,
+          get_state=lambda: (i, j),
+          set_state=set_state,
+          symbol_names=('i', 'j'),
+          nouts=2)
+      return i, j
 
     self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
     self.assertEqual((-1, -2),
@@ -837,14 +896,24 @@ class IfStmtTest(test.TestCase):
   def test_python(self):
 
     def test_fn(cond):
-      return control_flow.if_stmt(
+      def body():
+        nonlocal i
+        i = 1
+
+      def orelse():
+        nonlocal i
+        i = -1
+
+      i = None
+      control_flow.if_stmt(
           cond=cond,
-          body=lambda: 1,
-          orelse=lambda: -1,
-          get_state=lambda: (),
-          set_state=lambda _: None,
-          basic_symbol_names=('_',),
-          composite_symbol_names=())
+          body=body,
+          orelse=orelse,
+          get_state=None,
+          set_state=None,
+          symbol_names=('i',),
+          nouts=1)
+      return i
 
     self.assertEqual(1, test_fn(True))
     self.assertEqual(-1, test_fn(False))
@@ -852,48 +921,75 @@ class IfStmtTest(test.TestCase):
   def test_python_multiple_returns(self):
 
     def test_fn(cond):
-      return control_flow.if_stmt(
+      def body():
+        nonlocal i, j
+        i = 1
+        j = 2
+
+      def orelse():
+        nonlocal i, j
+        i = -1
+        j = -2
+
+      i, j = None, None
+      control_flow.if_stmt(
           cond=cond,
-          body=lambda: (1, 2),
-          orelse=lambda: (-1, -2),
-          get_state=lambda: (),
-          set_state=lambda _: None,
-          basic_symbol_names=('_',),
-          composite_symbol_names=())
+          body=body,
+          orelse=orelse,
+          get_state=None,
+          set_state=None,
+          symbol_names=('i', 'j'),
+          nouts=2)
+      return i, j
 
     self.assertEqual((1, 2), test_fn(True))
     self.assertEqual((-1, -2), test_fn(False))
 
-  def _basic_cond(self, true_value, false_value):
+  def _basic_cond(self, body_fn, else_fn):
+    def body():
+      nonlocal x
+      x = body_fn()
+
+    def orelse():
+      nonlocal x
+      x = else_fn()
+
+    def set_state(cond_vars):
+      nonlocal x
+      x, = cond_vars
+
+    x = 0
     # Eager cond had different semantics, we don't test those here.
     with func_graph.FuncGraph('tmp').as_default():
-      return control_flow.if_stmt(
+      control_flow.if_stmt(
           cond=constant_op.constant(True),
-          body=true_value,
-          orelse=false_value,
-          get_state=lambda: (),
-          set_state=lambda _: None,
-          basic_symbol_names=('s',),
-          composite_symbol_names=())
+          body=body,
+          orelse=orelse,
+          get_state=lambda: (x,),
+          set_state=set_state,
+          symbol_names=('x',),
+          nouts=1)
+    return x
 
   def test_tensor_none_output(self):
     with self.assertRaisesRegex(
-        ValueError, '"s" is None at the end of the TRUE branch'):
+        ValueError, "'x' is None at the end of the main branch"):
       self._basic_cond(lambda: None, lambda: 1)
     with self.assertRaisesRegex(
-        ValueError, '"s" is None at the end of the FALSE branch'):
+        ValueError, "'x' is None at the end of the else branch"):
       self._basic_cond(lambda: 1, lambda: None)
 
   def test_tensor_undefined_output(self):
     with self.assertRaisesRegex(
-        ValueError, "must also be initialized in the if.*'s'"):
-      self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1)
+        ValueError, "'x' must also be initialized in the main branch"):
+      self._basic_cond(lambda: variable_operators.Undefined('x'), lambda: 1)
     with self.assertRaisesRegex(
-        ValueError, "must also be initialized in the else.*'s'"):
+        ValueError, "'x' must also be initialized in the else branch"):
       self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s'))
 
   def test_tensor_dtype_change(self):
-    with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'):
+    with self.assertRaisesRegex(
+        TypeError, "'x' has dtype int32.*but.*float32"):
       self._basic_cond(lambda: 1, lambda: 1.0)
 
 
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index ca68bc9911c..0e19da87451 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -70,6 +70,9 @@ class Scope(object):
     globals: Set[qual_names.QN], names that are explicitly marked as global in
       this scope. Note that this doesn't include free read-only vars bound to
       global symbols.
+    nonlocals: Set[qual_names.QN], names that are explicitly marked as nonlocal
+      in this scope. Note that this doesn't include free read-only vars bound to
+      global symbols.
     free_vars: Set[qual_names.QN], the free variables in this scope. See
       https://docs.python.org/3/reference/executionmodel.html for a precise
       definition.
@@ -111,6 +114,7 @@ class Scope(object):
 
     self.bound = set()
     self.globals = set()
+    self.nonlocals = set()
     self.annotations = set()
 
     self.params = weakref.WeakValueDictionary()
@@ -186,6 +190,7 @@ class Scope(object):
         self.parent.modified.update(self.modified - self.isolated_names)
         self.parent.bound.update(self.bound - self.isolated_names)
         self.parent.globals.update(self.globals)
+        self.parent.nonlocals.update(self.nonlocals)
         self.parent.annotations.update(self.annotations)
       else:
         # TODO(mdan): This is not accurate.
@@ -363,6 +368,7 @@ class ActivityAnalyzer(transformer.Base):
       qn = qual_names.QN(name)
       self.scope.read.add(qn)
       self.scope.bound.add(qn)
+      self.scope.nonlocals.add(qn)
     self._exit_and_record_scope(node)
     return node
 
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index 64b00fcbeba..ac91b662a47 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -404,6 +404,46 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
 
     self.assertHasDefinedIn(fn_body[1], ('a',))
 
+  def test_definitions_in_except_block(self):
+
+    def test_fn():
+      try:
+        pass
+      except ValueError:
+        a = None
+      if a:  # pylint:disable=using-constant-test
+        a = None
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body
+
+    self.assertHasDefs(fn_body[1].test, 1)
+    self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[2].value, 2)
+
+    self.assertHasDefinedIn(fn_body[1], ('a',))
+
+  def test_definitions_in_except_block_of_raising_try(self):
+
+    def test_fn():
+      try:
+        raise ValueError()
+      except ValueError:
+        a = None
+      if a:  # pylint:disable=using-constant-test
+        a = None
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body
+
+    self.assertHasDefs(fn_body[1].test, 1)
+    self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[2].value, 2)
+
+    self.assertHasDefinedIn(fn_body[1], ('a',))
+
   def test_global(self):
 
     def test_fn():