diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index ec780a7c0a1..9cf3bba8dd5 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -118,7 +118,13 @@ py_test( name = "control_flow_test", srcs = ["control_flow_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", + tags = [ + "no_oss_py2", + "no_pip", + "no_windows", + "nopip", + ], deps = [ ":converters", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py index 44ab6dee926..65fb6765fcf 100644 --- a/tensorflow/python/autograph/converters/conditional_expressions.py +++ b/tensorflow/python/autograph/converters/conditional_expressions.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import templates @@ -26,19 +29,20 @@ class ConditionalExpressionTransformer(converter.Base): """Converts conditional expressions to functional form.""" def visit_IfExp(self, node): - return templates.replace_as_expression( - '''ag__.if_stmt( + template = ''' + ag__.if_exp( test, lambda: true_expr, lambda: false_expr, - lambda: (), - lambda _: None, - ('<internal expr>',), - ()) - ''', + expr_repr) + ''' + expr_repr = parser.unparse(node.test, include_encoding_marker=False).strip() + return templates.replace_as_expression( + template, test=node.test, true_expr=node.body, - false_expr=node.orelse) + false_expr=node.orelse, + expr_repr=gast.Constant(expr_repr, kind=None)) def transform(node, ctx): diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index a903c43bcfc..673781e47dd 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -23,7 +23,6 @@ import gast from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import cfg from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names @@ -57,114 +56,16 @@ class ControlFlowTransformer(converter.Base): fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) return self.generic_visit(node) - def _create_cond_branch(self, body_name, aliased_orig_names, - aliased_new_names, body, returns): - if len(returns) == 1: - template = """ - return retval - """ - return_stmt = templates.replace(template, retval=returns[0]) - else: - template = """ - return (retvals,) - """ - return_stmt = templates.replace(template, retvals=returns) - - if aliased_orig_names: - alias_declarations = [] - for new_name, old_name in zip(aliased_new_names, aliased_orig_names): - template = """ - try: - aliased_new_name = aliased_orig_name - except NameError: - aliased_new_name = ag__.Undefined(symbol_name) - """ - - alias_declarations.extend( - templates.replace( - template, - aliased_new_name=new_name, - aliased_orig_name=old_name, - symbol_name=gast.Constant(str(old_name), kind=None))) - - template = """ - def body_name(): - alias_declarations - body - return_stmt - """ - return templates.replace( - template, - alias_declarations=alias_declarations, - body_name=body_name, - body=body, - return_stmt=return_stmt) - else: - template = """ - def body_name(): - body - return_stmt - """ - return templates.replace( - template, body_name=body_name, body=body, return_stmt=return_stmt) - - def _create_cond_expr(self, results, test, body_name, orelse_name, - state_getter_name, state_setter_name, - basic_symbol_names, composite_symbol_names): - if results is not None: - template = """ - results = ag__.if_stmt(test, body_name, orelse_name, - state_getter_name, state_setter_name, - (basic_symbol_names,), - (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - results=results, - body_name=body_name, - orelse_name=orelse_name, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - else: - template = """ - ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name, - (basic_symbol_names,), (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - body_name=body_name, - orelse_name=orelse_name, - getter_name=state_getter_name, - setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - - def _fmt_symbols(self, symbol_set): - if not symbol_set: - return 'no variables' - return ', '.join(map(str, symbol_set)) - - def _determine_aliased_symbols(self, scope, node_defined_in): - modified_live = scope.modified & node_defined_in - # Composite symbols are handled elsewhere, see _create_state_functions - return { - s for s in modified_live - if not s.is_composite() and s not in self.state[_Function].scope.globals - } - - def _create_nonlocal_declarations(self, loop_vars): + def _create_nonlocal_declarations(self, vars_): + vars_ = set(vars_) results = [] global_vars = self.state[_Function].scope.globals if global_vars: - results.append(gast.Global([str(v) for v in global_vars])) + results.append(gast.Global([str(v) for v in vars_])) nonlocal_vars = [ - v for v in loop_vars if not v.is_composite() and v not in global_vars] + v for v in vars_ if not v.is_composite() and v not in global_vars] if nonlocal_vars: results.append(gast.Nonlocal([str(v) for v in nonlocal_vars])) @@ -176,9 +77,9 @@ class ControlFlowTransformer(converter.Base): template = """ def getter_name(): return state_vars, - def setter_name(loop_vars): + def setter_name(vars_): nonlocal_declarations - state_vars, = loop_vars + state_vars, = vars_ """ return templates.replace( template, @@ -222,166 +123,34 @@ class ControlFlowTransformer(converter.Base): symbol_name=gast.Constant(s.ssf(), kind=None)) return assignments - def visit_If(self, node): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) - defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) - live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - - # Note: this information needs to be extracted before the body conversion - # that happens in the call to generic_visit below, because the conversion - # generates nodes that lack static analysis annotations. - need_alias_in_body = self._determine_aliased_symbols( - body_scope, defined_in) - need_alias_in_orelse = self._determine_aliased_symbols( - orelse_scope, defined_in) - - node = self.generic_visit(node) - - modified_in_cond = body_scope.modified | orelse_scope.modified - returned_from_cond = set() - composites = set() - for s in modified_in_cond: - if s in live_out and not s.is_composite(): - returned_from_cond.add(s) - if s.is_composite(): - # Special treatment for compound objects, always return them. - # This allows special handling within the if_stmt itself. - # For example, in TensorFlow we need to restore the state of composite - # symbols to ensure that only effects from the executed branch are seen. - composites.add(s) - - created_in_body = body_scope.modified & returned_from_cond - defined_in - created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in - - basic_created_in_body = tuple( - s for s in created_in_body if not s.is_composite()) - basic_created_in_orelse = tuple( - s for s in created_in_orelse if not s.is_composite()) - - # These variables are defined only in a single branch. This is fine in - # Python so we pass them through. Another backend, e.g. Tensorflow, may need - # to handle these cases specially or throw an Error. - possibly_undefined = (set(basic_created_in_body) ^ - set(basic_created_in_orelse)) - - # Alias the closure variables inside the conditional functions, to allow - # the functions access to the respective variables. - # We will alias variables independently for body and orelse scope, - # because different branches might write different variables. - aliased_body_orig_names = tuple(need_alias_in_body) - aliased_orelse_orig_names = tuple(need_alias_in_orelse) - aliased_body_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) - for s in aliased_body_orig_names) - aliased_orelse_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) - for s in aliased_orelse_orig_names) - - alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) - alias_orelse_map = dict( - zip(aliased_orelse_orig_names, aliased_orelse_new_names)) - - node_body = ast_util.rename_symbols(node.body, alias_body_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) - - cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced) - body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) - all_referenced = body_scope.referenced | orelse_scope.referenced - state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced) - state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced) - - returned_from_cond = tuple(returned_from_cond) - composites = tuple(composites) - - if returned_from_cond: - if len(returned_from_cond) == 1: - cond_results = returned_from_cond[0] - else: - cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) - - returned_from_body = tuple( - alias_body_map[s] if s in need_alias_in_body else s - for s in returned_from_cond) - returned_from_orelse = tuple( - alias_orelse_map[s] if s in need_alias_in_orelse else s - for s in returned_from_cond) - - else: - # When the cond would return no value, we leave the cond called without - # results. That in turn should trigger the side effect guards. The - # branch functions will return a dummy value that ensures cond - # actually has some return value as well. - cond_results = None - # TODO(mdan): Replace with None once side_effect_guards is retired. - returned_from_body = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - returned_from_orelse = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - - cond_assign = self.create_assignment(cond_var_name, node.test) - body_def = self._create_cond_branch( - body_name, - aliased_orig_names=aliased_body_orig_names, - aliased_new_names=aliased_body_new_names, - body=node_body, - returns=returned_from_body) - orelse_def = self._create_cond_branch( - orelse_name, - aliased_orig_names=aliased_orelse_orig_names, - aliased_new_names=aliased_orelse_new_names, - body=node_orelse, - returns=returned_from_orelse) - undefined_assigns = self._create_undefined_assigns(possibly_undefined) - composite_defs = self._create_state_functions( - composites, [], state_getter_name, state_setter_name) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composites) - - cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name, - orelse_name, state_getter_name, - state_setter_name, basic_symbol_names, - composite_symbol_names) - - if_ast = ( - undefined_assigns + composite_defs + body_def + orelse_def + - cond_assign + cond_expr) - return if_ast - - def _get_basic_loop_vars(self, modified, live_in, live_out): - # The loop variables corresponding to simple symbols (e.g. `x`). - basic_loop_vars = [] + def _get_block_basic_vars(self, modified, live_in, live_out): + nonlocals = self.state[_Function].scope.nonlocals + basic_scope_vars = [] for s in modified: if s.is_composite(): - # TODO(mdan): Raise an error when this happens for a TF loop. + # TODO(mdan): Raise an error when this happens for a TF scope. continue - # Variables not live into or out of the loop are considered local to the - # loop. - if s not in live_in and s not in live_out: - continue - basic_loop_vars.append(s) - return frozenset(basic_loop_vars) + # Variables not live into or out of the scope are considered local to the + # scope. + if s in live_in or s in live_out or s in nonlocals: + basic_scope_vars.append(s) + continue + return frozenset(basic_scope_vars) - def _get_composite_loop_vars(self, modified, live_in): - # The loop variables corresponding to composite symbols (e.g. `self.x`). - composite_loop_vars = [] + def _get_block_composite_vars(self, modified, live_in): + # The scope variables corresponding to composite symbols (e.g. `self.x`). + composite_scope_vars = [] for s in modified: if not s.is_composite(): continue - # Mutations made to objects created inside the loop will appear as writes + # Mutations made to objects created inside the scope will appear as writes # to composite symbols. Because these mutations appear as modifications # made to composite symbols, we check whether the composite's parent is - # actually live into the loop. + # actually live into the scope. # Example: # while cond: # x = Foo() - # x.foo = 2 * x.foo # x.foo is live into the loop, but x is not. + # x.foo = 2 * x.foo # x.foo is live into the scope, but x is not. # # Note that some parents might not be symbols - for example, in x['foo'], # 'foo' is a parent, but it's a literal, not a symbol. We don't check the @@ -390,40 +159,106 @@ class ControlFlowTransformer(converter.Base): sss for sss in s.support_set if sss.is_symbol()) if not all(sss in live_in for sss in support_set_symbols): continue - composite_loop_vars.append(s) - return frozenset(composite_loop_vars) + composite_scope_vars.append(s) + return frozenset(composite_scope_vars) - def _get_loop_vars(self, node, modified): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + def _get_block_vars(self, node, modified): + """Determines the variables affected inside a control flow statement.""" defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - reserved_symbols = body_scope.referenced - basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out) - composite_loop_vars = self._get_composite_loop_vars(modified, live_in) - loop_vars = tuple(basic_loop_vars | composite_loop_vars) + basic_scope_vars = self._get_block_basic_vars( + modified, + live_in, + live_out) + composite_scope_vars = self._get_block_composite_vars(modified, live_in) + scope_vars = tuple(basic_scope_vars | composite_scope_vars) - # Variable that are used or defined inside the loop, but not defined - # before entering the loop. Only simple variables must be defined. The + # Variables that are modified inside the scope, but not defined + # before entering it. Only simple variables must be defined. The # composite ones will be implicitly checked at runtime. - undefined_lives = basic_loop_vars - defined_in + # This covers loop variables as well as variables that + undefined = tuple(v for v in modified - defined_in if not v.is_composite()) - return loop_vars, reserved_symbols, undefined_lives + # Variables that are modified inside the scope, and depend on values outside + # it. + input_only = basic_scope_vars & live_in - live_out + + # Place the outputs first. + scope_vars = sorted(scope_vars, key=lambda v: v in input_only) + nouts = len(scope_vars) - len(input_only) + + return scope_vars, undefined, nouts + + def visit_If(self, node): + node = self.generic_visit(node) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) + + cond_vars, undefined, nouts = self._get_block_vars( + node, body_scope.modified | orelse_scope.modified) + + undefined_assigns = self._create_undefined_assigns(undefined) + + nonlocal_declarations = self._create_nonlocal_declarations(cond_vars) + + reserved = body_scope.referenced | orelse_scope.referenced + state_getter_name = self.ctx.namer.new_symbol('get_state', reserved) + state_setter_name = self.ctx.namer.new_symbol('set_state', reserved) + state_functions = self._create_state_functions( + cond_vars, nonlocal_declarations, state_getter_name, state_setter_name) + + orelse_body = node.orelse + if not orelse_body: + orelse_body = [gast.Pass()] + + template = """ + state_functions + def body_name(): + nonlocal_declarations + body + def orelse_name(): + nonlocal_declarations + orelse + undefined_assigns + ag__.if_stmt( + test, + body_name, + orelse_name, + state_getter_name, + state_setter_name, + (symbol_names,), + nouts) + """ + return templates.replace( + template, + body=node.body, + body_name=self.ctx.namer.new_symbol('if_body', reserved), + orelse=orelse_body, + orelse_name=self.ctx.namer.new_symbol('else_body', reserved), + nonlocal_declarations=nonlocal_declarations, + nouts=gast.Constant(nouts, kind=None), + state_functions=state_functions, + state_getter_name=state_getter_name, + state_setter_name=state_setter_name, + symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars), + test=node.test, + undefined_assigns=undefined_assigns) def visit_While(self, node): node = self.generic_visit(node) body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars( - node, body_scope.modified) + loop_vars, undefined, _ = self._get_block_vars(node, body_scope.modified) - undefined_assigns = self._create_undefined_assigns(possibly_undefs) + undefined_assigns = self._create_undefined_assigns(undefined) nonlocal_declarations = self._create_nonlocal_declarations(loop_vars) - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) + reserved = body_scope.referenced + state_getter_name = self.ctx.namer.new_symbol('get_state', reserved) + state_setter_name = self.ctx.namer.new_symbol('set_state', reserved) state_functions = self._create_state_functions( loop_vars, nonlocal_declarations, state_getter_name, state_setter_name) @@ -448,7 +283,7 @@ class ControlFlowTransformer(converter.Base): return templates.replace( template, body=node.body, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), + body_name=self.ctx.namer.new_symbol('loop_body', reserved), nonlocal_declarations=nonlocal_declarations, opts=opts, state_functions=state_functions, @@ -456,7 +291,7 @@ class ControlFlowTransformer(converter.Base): state_setter_name=state_setter_name, symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars), test=node.test, - test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), + test_name=self.ctx.namer.new_symbol('loop_test', reserved), undefined_assigns=undefined_assigns) def visit_For(self, node): @@ -464,15 +299,16 @@ class ControlFlowTransformer(converter.Base): body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE) - loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars( + loop_vars, undefined, _ = self._get_block_vars( node, body_scope.modified | iter_scope.modified) - undefined_assigns = self._create_undefined_assigns(possibly_undefs) + undefined_assigns = self._create_undefined_assigns(undefined) nonlocal_declarations = self._create_nonlocal_declarations(loop_vars) - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) + reserved = body_scope.referenced | iter_scope.referenced + state_getter_name = self.ctx.namer.new_symbol('get_state', reserved) + state_setter_name = self.ctx.namer.new_symbol('set_state', reserved) state_functions = self._create_state_functions( loop_vars, nonlocal_declarations, state_getter_name, state_setter_name) @@ -484,7 +320,7 @@ class ControlFlowTransformer(converter.Base): if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) extra_test_name = self.ctx.namer.new_symbol( - 'extra_test', reserved_symbols) + 'extra_test', reserved) template = """ def extra_test_name(): nonlocal_declarations @@ -502,7 +338,7 @@ class ControlFlowTransformer(converter.Base): # iterate_arg_name holds a single arg with the iterates, which may be a # tuple. - iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols) + iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved) template = """ iterates = iterate_arg_name """ @@ -529,7 +365,7 @@ class ControlFlowTransformer(converter.Base): return templates.replace( template, body=node.body, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), + body_name=self.ctx.namer.new_symbol('loop_body', reserved), extra_test_function=extra_test_function, extra_test_name=extra_test_name, iterate_arg_name=iterate_arg_name, diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py index 32e86400da6..935e2cec4b8 100644 --- a/tensorflow/python/autograph/converters/control_flow_test.py +++ b/tensorflow/python/autograph/converters/control_flow_test.py @@ -1,3 +1,4 @@ +# Lint as: python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -453,6 +454,17 @@ class IfStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, constant_op.constant(1), 5) self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + def test_local_remains_local(self): + + def test_fn(n): + if n > 0: + b = 4 + n = b + 1 + return n + + self.assertTransformedResult(test_fn, constant_op.constant(1), 5) + self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + def test_no_outputs(self): def test_fn(n): @@ -465,6 +477,85 @@ class IfStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, constant_op.constant(1), 1) self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + def test_created_outputs(self): + + def test_fn(i): + if i == 0: + result = i - 1 + else: + result = i + 1 + return result + + self.assertTransformedResult(test_fn, 0, -1) + self.assertTransformedResult(test_fn, 1, 2) + + def test_created_loop_local_outputs(self): + + def test_fn(n, x): + for i in n: + if i == 0: + result = i - 1 + else: + result = i + 1 + if result > 0: + x += 1 + return x + + self.assertTransformedResult(test_fn, (range(5), 10), 14) + + def test_created_loop_variable(self): + + def test_fn(n, x): + for i in n: + if i == 0: + result = i - 1 + if i > 0: # Using the result from previous iteration. + if result < 0: + x += 1 + return x + + self.assertTransformedResult(test_fn, (range(5), 10), 14) + + def test_unaffected_global(self): + + def test_fn(i): + global g # pylint:disable=global-variable-undefined + if i == 0: + g = i - 1 + return g + + self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3}) + self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3}) + + def test_unaffected_nonlocal(self): + + def test_fn(i): + def inner_fn(): + nonlocal n + if i == 0: + n = i - 1 + + n = 3 + inner_fn() + return n + + self.assertTransformedResult(test_fn, 1, 3) + self.assertTransformedResult(test_fn, 0, -1) + + def test_output_defined_in_prior_except(self): + + def test_fn(i): + try: + raise ValueError() + except ValueError: + x = 1 + if i == 0: + x = i - 1 + return x + + self.assertTransformedResult(test_fn, 1, 1) + self.assertTransformedResult(test_fn, 0, -1) + def test_unbalanced_multiple_composites(self): class Foo(object): diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index 3851c7b44ba..5f644ea525d 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( name = "operators", srcs = [ "__init__.py", + "conditional_expressions.py", "control_flow.py", "control_flow_deprecated_py2.py", "data_structures.py", @@ -62,6 +63,20 @@ py_test( ], ) +py_test( + name = "conditional_expressions_test", + srcs = ["conditional_expressions_test.py"], + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss_py2", + ], + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "control_flow_test", srcs = ["control_flow_test.py"], diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index f7f9078107c..8ac4e1d8bb3 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -37,6 +37,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.autograph.operators.conditional_expressions import if_exp from tensorflow.python.autograph.operators.control_flow import for_stmt from tensorflow.python.autograph.operators.control_flow import if_stmt from tensorflow.python.autograph.operators.control_flow import while_stmt diff --git a/tensorflow/python/autograph/operators/conditional_expressions.py b/tensorflow/python/autograph/operators/conditional_expressions.py new file mode 100644 index 00000000000..7ea2b249935 --- /dev/null +++ b/tensorflow/python/autograph/operators/conditional_expressions.py @@ -0,0 +1,56 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conditional expressions (e.g. the ternary if statement).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.autograph.operators import control_flow +from tensorflow.python.autograph.utils import tensors +from tensorflow.python.ops import control_flow_ops + + +def if_exp(cond, if_true, if_false, expr_repr): + if tensors.is_dense_tensor(cond): + return _tf_if_exp(cond, if_true, if_false, expr_repr) + else: + return _py_if_exp(cond, if_true, if_false) + + +def _tf_if_exp(cond, if_true, if_false, expr_repr): + """Overload of if_exp that stages a TF cond.""" + # TODO(mdan): Use nonlocal once we no longer need to support py2. + true_val = [] + false_val = [] + + def true_fn(): + true_val.append(if_true()) + if true_val and false_val: + control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0]) + return true_val[0] + + def false_fn(): + false_val.append(if_false()) + if true_val and false_val: + control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0]) + return false_val[0] + + return control_flow_ops.cond(cond, true_fn, false_fn) + + +def _py_if_exp(cond, if_true, if_false): + return if_true() if cond else if_false() diff --git a/tensorflow/python/autograph/operators/conditional_expressions_test.py b/tensorflow/python/autograph/operators/conditional_expressions_test.py new file mode 100644 index 00000000000..3f126116023 --- /dev/null +++ b/tensorflow/python/autograph/operators/conditional_expressions_test.py @@ -0,0 +1,66 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for conditional_expressions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.operators import conditional_expressions +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +def _basic_expr(cond): + return conditional_expressions.if_exp( + cond, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2), + 'cond') + + +@test_util.run_all_in_graph_and_eager_modes +class IfExpTest(test.TestCase): + + def test_tensor(self): + self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(True))), 1) + self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(False))), 2) + + def test_tensor_mismatched_type(self): + # tf.function required because eager cond degenerates to Python if. + @def_function.function + def test_fn(): + conditional_expressions.if_exp( + constant_op.constant(True), lambda: 1.0, lambda: 2, 'expr_repr') + + with self.assertRaisesRegexp( + TypeError, + "'expr_repr' has dtype float32 in the main.*int32 in the else"): + test_fn() + + def test_python(self): + self.assertEqual(self.evaluate(_basic_expr(True)), 1) + self.assertEqual(self.evaluate(_basic_expr(False)), 2) + self.assertEqual( + conditional_expressions.if_exp(True, lambda: 1, lambda: 2, ''), 1) + self.assertEqual( + conditional_expressions.if_exp(False, lambda: 1, lambda: 2, ''), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 592281b0ce2..77db7579ece 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -102,7 +102,7 @@ def _verify_loop_init_vars(values, symbol_names): """Ensures that all values in the state are defined when entering a loop.""" for name, value in zip(symbol_names, values): if value is None: - raise ValueError('"{}" may not be None before the loop.'.format(name)) + raise ValueError("'{}' may not be None before the loop.".format(name)) if isinstance(value, variables.UndefinedReturnValue): # Assumption: the loop will only capture the variable which tracks the # return value if the loop contained a return statement. @@ -110,7 +110,7 @@ def _verify_loop_init_vars(values, symbol_names): raise ValueError( 'return statements are not supported within a TensorFlow loop.') if isinstance(value, variables.Undefined): - raise ValueError('"{}" must be defined before the loop.'.format(name)) + raise ValueError("'{}' must be defined before the loop.".format(name)) def _is_subshape(left, right): @@ -133,9 +133,9 @@ def _is_subshape(left, right): def _verify_single_loop_var( name, check_shape, init, entry, exit_, shape_invariant): """Verifies whether the initial, entry and exit values are consistent.""" - assert entry is not None, 'no TF op should set "{}" to None?'.format(name) + assert entry is not None, "no TF op should set '{}' to None?".format(name) if exit_ is None: - raise ValueError('"{}" is None at the end of the iteration.'.format(name)) + raise ValueError("'{}' is None at the end of the iteration.".format(name)) if isinstance(init, (bool, int, float, str, np.ndarray)): init = ops.convert_to_tensor_v2(init) @@ -158,9 +158,8 @@ def _verify_single_loop_var( if entry.dtype != exit_.dtype: raise TypeError( - '"{}" has dtype {} before the loop, but dtype {} after one' - ' iteration. TensorFlow control flow requires it stays the' - ' same.'.format( + "'{}' has dtype {} before the loop, but dtype {} after one" + ' iteration'.format( name, entry.dtype.name, exit_.dtype.name, @@ -171,19 +170,19 @@ def _verify_single_loop_var( entry_shape = entry.shape if not _is_subshape(exit_shape, entry_shape): raise ValueError( - '"{}" has shape {} before the loop, but shape {} after one' + "'{}' has shape {} before the loop, but shape {} after one" ' iteration. Use tf.autograph.experimental.set_loop_options to set' ' shape invariants.'.format(name, entry_shape, exit_shape)) else: init_shape = init.shape if not _is_subshape(init_shape, shape_invariant): raise ValueError( - '"{}" has shape {} before the loop, which does not conform with' + "'{}' has shape {} before the loop, which does not conform with" ' the shape invariant {}.'.format(name, init_shape, shape_invariant)) if not _is_subshape(exit_shape, shape_invariant): raise ValueError( - '"{}" has shape {} after one iteration, which does not conform with' + "'{}' has shape {} after one iteration, which does not conform with" ' the shape invariant {}.'.format( name, exit_shape, shape_invariant)) @@ -216,13 +215,13 @@ def _verify_tf_loop_vars(init_vars, nest.assert_same_structure(init, entry, expand_composites=True) nest.assert_same_structure(entry, exit_, expand_composites=True) except (ValueError, TypeError) as e: - raise TypeError('"{}" does not have the same nested structure after one' + raise TypeError("'{}' does not have the same nested structure after one" ' iteration.\n\n{}'.format(name, e)) if invariant is not None: try: nest.assert_same_structure(init, invariant, expand_composites=False) except (ValueError, TypeError) as e: - raise TypeError('"{}" does not have the same nested structure as its' + raise TypeError("'{}' does not have the same nested structure as its" ' corresponding shape invariant.\n\n{}'.format(name, e)) nest.map_structure( @@ -230,13 +229,13 @@ def _verify_tf_loop_vars(init_vars, entry, exit_, invariant) -def _verify_single_cond_var(name, body_var, orelse_var): +def verify_single_cond_var(name, body_var, orelse_var): """Verifies whether body_var and orelse_var are consistent.""" if body_var is None: - raise ValueError('"{}" is None at the end of the TRUE branch.'.format(name)) + raise ValueError("'{}' is None at the end of the main branch.".format(name)) if orelse_var is None: raise ValueError( - '"{}" is None at the end of the FALSE branch.'.format(name)) + "'{}' is None at the end of the else branch.".format(name)) if isinstance(body_var, (bool, int, float, str, np.ndarray)): body_var = ops.convert_to_tensor_v2(body_var) @@ -255,41 +254,37 @@ def _verify_single_cond_var(name, body_var, orelse_var): if body_var.dtype != orelse_var.dtype: raise TypeError( - '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE' - ' branch. TensorFlow control flow requires that they are the' - ' same.'.format(name, body_var.dtype.name, - orelse_var.dtype.name)) + "'{}' has dtype {} in the main branch, but dtype {} in the else" + ' branch'.format(name, body_var.dtype.name, + orelse_var.dtype.name)) + + +def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name): + """Verifies variables output by a conditional branch for consistency.""" + for name, var_ in zip(symbol_names, vars_): + if isinstance(var_, variables.Undefined): + raise ValueError( + "'{}' must also be initialized in the {} branch".format( + name, branch_name)) + if isinstance(var_, variables.UndefinedReturnValue): + raise ValueError( + 'the {} branch must also have a return statement.'.format( + branch_name)) def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names): """Verifies variables manipulated by a conditional for consistency.""" - basic_body_vars, composite_body_vars = body_vars - basic_orelse_vars, composite_orelse_vars = orelse_vars - assert isinstance(composite_body_vars, tuple) - assert isinstance(composite_orelse_vars, tuple) - - # TODO(kkb): Make this more consistent. - # The basic outputs should always be a tuple. - if not isinstance(basic_body_vars, tuple): - basic_body_vars = (basic_body_vars,) - if not isinstance(basic_orelse_vars, tuple): - basic_orelse_vars = (basic_orelse_vars,) - - body_vars = basic_body_vars + composite_body_vars - orelse_vars = basic_orelse_vars + composite_orelse_vars - named_vars = zip(symbol_names, body_vars, orelse_vars) + for name, body_var, orelse_var in named_vars: try: - nest.assert_same_structure( - body_var, orelse_var, expand_composites=True) + nest.assert_same_structure(body_var, orelse_var, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError( - '"{}" does not have the same nested structure in the TRUE and FALSE' - ' branches.\n\n{}'.format(name, str(e))) - + "'{}' must have the same nested structure in the main and else" + ' branches:\n\n{}'.format(name, str(e))) nest.map_structure( - functools.partial(_verify_single_cond_var, name), body_var, orelse_var) + functools.partial(verify_single_cond_var, name), body_var, orelse_var) def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts): @@ -314,12 +309,16 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts): `extra_test`, `body`, `get_state` and `set_state` functions must bind to the original `geo_mean` and `arith_mean` symbols, using `nonlocal`. + The inputs and outputs of the callables representing the loop blocks are not + explicit - instead, these functions must use nonlocal/global for side effects. + The inputs and outputs are instead controlled by the set_state/get_state + functions. + Args: iter_: The entity being iterated over. - extra_test: Callable with the state as arguments, and boolean return type. + extra_test: Callable with boolean return type. An additional loop condition. - body: Callable with the iterate and the state as arguments, and state as - return type. The actual loop body. + body: Callable representing the actual loop body. get_state: Additional callable which can capture additional state (such as the values of composite symbols). This is only useful when staging the loop. @@ -717,11 +716,14 @@ def while_stmt(test, body, get_state, set_state, symbol_names, opts): a tuple of entities that represent an actual state, or a list of arguments of the corresponding types. + The inputs and outputs of the callables representing the loop blocks are not + explicit - instead, these functions must use nonlocal/global for side effects. + The inputs and outputs are instead controlled by the set_state/get_state + functions. + Args: - test: Callable with the state as arguments, and boolean return type. The - loop condition. - body: Callable with the state as arguments, and state as return type. The - actual loop body. + test: Callable with boolean return type. The loop condition. + body: Callable representing the actual loop body. get_state: Additional callable which can capture additional state (such as the values of composite symbols). This is only useful when staging the loop. @@ -894,21 +896,32 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts): set_state(final_loop_vars) -def if_stmt(cond, - body, - orelse, - get_state, - set_state, - basic_symbol_names, - composite_symbol_names): +def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts): """Functional form of an if statement. + The conditional operates on a state, which includes all symbols whose values + are a function of the branch taken. + + For example, given the code below that calculates the abs function: + + ``` + x = 1 + if x > 0: + x = -x + ``` + + The state is represented by the variable `x`. The `body, `orelse` and + `set_state` functions must bind to the original `x` symbol, using `nonlocal`. + + The inputs and outputs of the callables representing the loop blocks are not + explicit - instead, these functions must use nonlocal/global for side effects. + The inputs and outputs are instead controlled by the set_state/get_state + functions. + Args: cond: Boolean. - body: Callable with no arguments, and outputs of the positive (if) branch as - return type. - orelse: Callable with no arguments, and outputs of the negative (else) - branch as return type. + body: Callable representing the main block of the conditional. + orelse: Callable representing the else block of the conditional. get_state: Function that returns a tuple containing the values of all composite symbols modified within the conditional. This allows access to state that branches may mutate through side effects. This function is not @@ -920,123 +933,63 @@ def if_stmt(cond, restore checkpointed values. The single argument a tuple containing values for each composite symbol that may be modified in a branch of the conditional. The is usually the result of a call to get_state. - basic_symbol_names: Tuple containing basic loop var names. - composite_symbol_names: Tuple containing composite loop var names. - - Returns: - Tuple containing the statement outputs. + symbol_names: Tuple containing basic loop var names. + nouts: Number of variables output by the statement. Vars which are + not outputs will not be passed through staged control flow such as + tf.cond. This includes variables that are defined before the conditional, + but are not used after it. """ # Note: tf.cond doesn't support SparseTensor. if tensors.is_dense_tensor(cond): - return tf_if_stmt(cond, body, orelse, get_state, set_state, - basic_symbol_names, composite_symbol_names) + _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts) else: - return _py_if_stmt(cond, body, orelse) + _py_if_stmt(cond, body, orelse) -def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names, - composite_symbol_names): +def _tf_if_stmt( + cond, body, orelse, get_state, set_state, symbol_names, nouts): """Overload of if_stmt that stages a TF cond.""" - body = _wrap_disallow_undefs_from_cond(body, branch_name='if') - orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else') - body = _isolate_state(body, get_state, set_state) - orelse = _isolate_state(orelse, get_state, set_state) + if not nouts: + prev_get_state, prev_set_state = get_state, set_state + # Control flow V1 wants at least one output. + get_state = lambda: (0,) + prev_get_state() + set_state = lambda v: prev_set_state(v[1:]) + symbol_names += ('<unused dummy>',) + nouts = 1 - # `state` currently includes the values of any composite symbols (e.g. `a.b`) - # composites modified by the loop. `final_vars` includes the values of basic - # symbols (e.g. `a`) which cannot be passed by reference and must be returned. - # See _isolate_state. - # TODO(mdan): We should minimize calls to get/set_state. + init_vars = get_state() - body_branch = 0 - orelse_branch = 1 - result = [None, None] + # TODO(mdan): Use nonlocal once we no longer need to support py2. + new_body_vars_ = [None] + new_orelse_vars_ = [None] - def error_checking_body(): - result[body_branch] = body() - if result[orelse_branch] is not None: - _verify_tf_cond_vars(result[body_branch], result[orelse_branch], - basic_symbol_names + composite_symbol_names) - return result[body_branch] + def aug_body(): + set_state(init_vars) + body() + new_body_vars = get_state() + new_body_vars = new_body_vars[:nouts] + new_body_vars_[0] = new_body_vars + _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main') + if new_orelse_vars_[0] is not None: + _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names) + return new_body_vars - def error_checking_orelse(): - result[orelse_branch] = orelse() - if result[body_branch] is not None: - _verify_tf_cond_vars(result[body_branch], result[orelse_branch], - basic_symbol_names + composite_symbol_names) - return result[orelse_branch] + def aug_orelse(): + set_state(init_vars) + orelse() + new_orelse_vars = get_state() + new_orelse_vars = new_orelse_vars[:nouts] + new_orelse_vars_[0] = new_orelse_vars + _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else') + if new_body_vars_[0] is not None: + _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names) + return new_orelse_vars - final_vars, final_state = control_flow_ops.cond(cond, error_checking_body, - error_checking_orelse) + final_cond_vars = control_flow_ops.cond( + cond, aug_body, aug_orelse, strict=True) + final_cond_vars = final_cond_vars + init_vars[nouts:] - set_state(final_state) - - return final_vars - - -def _isolate_state(func, get_state, set_state): - """Wraps func to (best-effort) isolate state mutations that func may do. - - The simplest example of state mutation is mutation of variables (via e.g. - attributes), or modification of globals. - - This allows us to more safely execute this function without worrying about - side effects when the function wasn't normally expected to execute. For - example, staging requires that the function is executed ahead of time, and - we need to ensure its effects are not observed during normal execution. - - Args: - func: () -> Any - get_state: () -> Any, returns the current state - set_state: (Any) -> None, resets the state to the specified values. - Typically the result of an earlier call to `get_state`. - - Returns: - Tuple[Any, Any], where the first element is the return value of `func`, - and the second is the final state values. - """ - - def wrapper(): - init_state = get_state() - new_vars = func() - # TODO(mdan): These should be copies, lest set_state might affect them. - new_state = get_state() - set_state(init_state) - return new_vars, new_state - - return wrapper - - -def _wrap_disallow_undefs_from_cond(func, branch_name): - """Wraps conditional branch to disallow returning undefined symbols.""" - - def wrapper(): - """Calls function and raises an error if undefined symbols are returned.""" - results = func() - - if isinstance(results, tuple): - results_tuple = results - else: - results_tuple = results, - - for result in results_tuple: - if isinstance(result, variables.UndefinedReturnValue): - raise ValueError( - 'A value must also be returned from the {} branch. If a value is ' - 'returned from one branch of a conditional a value must be ' - 'returned from all branches.'.format(branch_name)) - - undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)] - if undefined: - raise ValueError( - 'The following symbols must also be initialized in the {} branch: {}.' - ' Alternatively, you may initialize them before the if' - ' statement.'.format(branch_name, - tuple(s.symbol_name for s in undefined))) - - return results - - return wrapper + set_state(final_cond_vars) def _py_if_stmt(cond, body, orelse): diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 1c4407904b2..57288be9a9f 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -543,21 +543,21 @@ class ForLoopTest(test.TestCase): return s def test_tensor_illegal_input(self): - with self.assertRaisesRegex(ValueError, '"s" may not be None'): + with self.assertRaisesRegex(ValueError, '\'s\' may not be None'): self._basic_loop(None, lambda i, s: s) - with self.assertRaisesRegex(ValueError, '"s" must be defined'): + with self.assertRaisesRegex(ValueError, '\'s\' must be defined'): self._basic_loop(variable_operators.Undefined(''), lambda i, s: s) def test_tensor_none_output(self): - with self.assertRaisesRegex(ValueError, '"s" is None at the end'): + with self.assertRaisesRegex(ValueError, '\'s\' is None at the end'): self._basic_loop(0, lambda i, s: None) def test_tensor_dtype_change(self): - with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'): + with self.assertRaisesRegex(TypeError, '\'s\'.* dtype float32 after'): self._basic_loop(0, lambda i, s: 1.0) def test_tensor_shape_change(self): - with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'): + with self.assertRaisesRegex(ValueError, r'\'s\'.* shape \(1,\) after'): self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) @@ -782,21 +782,21 @@ class WhileLoopTest(test.TestCase): return s def test_tensor_illegal_input(self): - with self.assertRaisesRegex(ValueError, '"s" may not be None'): + with self.assertRaisesRegex(ValueError, "'s' may not be None"): self._basic_loop(None, lambda i, s: s) - with self.assertRaisesRegex(ValueError, '"s" must be defined'): + with self.assertRaisesRegex(ValueError, "'s' must be defined"): self._basic_loop(variable_operators.Undefined(''), lambda i, s: s) def test_tensor_none_output(self): - with self.assertRaisesRegex(ValueError, '"s" is None at the end'): + with self.assertRaisesRegex(ValueError, "'s' is None at the end"): self._basic_loop(0, lambda i, s: None) def test_tensor_dtype_change(self): - with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'): + with self.assertRaisesRegex(TypeError, "'s'.* dtype float32 after"): self._basic_loop(0, lambda i, s: 1.0) def test_tensor_shape_change(self): - with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'): + with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"): self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) @@ -806,29 +806,88 @@ class IfStmtTest(test.TestCase): def test_tensor(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i + i = constant_op.constant(1) + + def orelse(): + nonlocal i + i = constant_op.constant(-1) + + def set_state(cond_vars): + nonlocal i + i, = cond_vars + + i = None + control_flow.if_stmt( cond=cond, - body=lambda: constant_op.constant(1), - orelse=lambda: constant_op.constant(-1), - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=lambda: (i,), + set_state=set_state, + symbol_names=('i',), + nouts=1) + return i self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False)))) + def test_tensor_no_outputs(self): + + def test_fn(cond): + def body(): + nonlocal i + i = constant_op.constant(1) + + def orelse(): + nonlocal i + i = constant_op.constant(-1.0) + + def set_state(cond_vars): + nonlocal i + i, = cond_vars + + i = None + control_flow.if_stmt( + cond=cond, + body=body, + orelse=orelse, + get_state=lambda: (i,), + set_state=set_state, + symbol_names=('i',), + nouts=0) + return i + + self.assertEqual(None, test_fn(constant_op.constant(True))) + self.assertEqual(None, test_fn(constant_op.constant(False))) + def test_tensor_multiple_returns(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i, j + i = constant_op.constant(1) + j = constant_op.constant(2) + + def orelse(): + nonlocal i, j + i = constant_op.constant(-1) + j = constant_op.constant(-2) + + def set_state(cond_vars): + nonlocal i, j + i, j = cond_vars + + i, j = None, None + control_flow.if_stmt( cond=cond, - body=lambda: (constant_op.constant(1), constant_op.constant(2)), - orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)), - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=lambda: (i, j), + set_state=set_state, + symbol_names=('i', 'j'), + nouts=2) + return i, j self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual((-1, -2), @@ -837,14 +896,24 @@ class IfStmtTest(test.TestCase): def test_python(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i + i = 1 + + def orelse(): + nonlocal i + i = -1 + + i = None + control_flow.if_stmt( cond=cond, - body=lambda: 1, - orelse=lambda: -1, - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=None, + set_state=None, + symbol_names=('i',), + nouts=1) + return i self.assertEqual(1, test_fn(True)) self.assertEqual(-1, test_fn(False)) @@ -852,48 +921,75 @@ class IfStmtTest(test.TestCase): def test_python_multiple_returns(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i, j + i = 1 + j = 2 + + def orelse(): + nonlocal i, j + i = -1 + j = -2 + + i, j = None, None + control_flow.if_stmt( cond=cond, - body=lambda: (1, 2), - orelse=lambda: (-1, -2), - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=None, + set_state=None, + symbol_names=('i', 'j'), + nouts=2) + return i, j self.assertEqual((1, 2), test_fn(True)) self.assertEqual((-1, -2), test_fn(False)) - def _basic_cond(self, true_value, false_value): + def _basic_cond(self, body_fn, else_fn): + def body(): + nonlocal x + x = body_fn() + + def orelse(): + nonlocal x + x = else_fn() + + def set_state(cond_vars): + nonlocal x + x, = cond_vars + + x = 0 # Eager cond had different semantics, we don't test those here. with func_graph.FuncGraph('tmp').as_default(): - return control_flow.if_stmt( + control_flow.if_stmt( cond=constant_op.constant(True), - body=true_value, - orelse=false_value, - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('s',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=lambda: (x,), + set_state=set_state, + symbol_names=('x',), + nouts=1) + return x def test_tensor_none_output(self): with self.assertRaisesRegex( - ValueError, '"s" is None at the end of the TRUE branch'): + ValueError, "'x' is None at the end of the main branch"): self._basic_cond(lambda: None, lambda: 1) with self.assertRaisesRegex( - ValueError, '"s" is None at the end of the FALSE branch'): + ValueError, "'x' is None at the end of the else branch"): self._basic_cond(lambda: 1, lambda: None) def test_tensor_undefined_output(self): with self.assertRaisesRegex( - ValueError, "must also be initialized in the if.*'s'"): - self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1) + ValueError, "'x' must also be initialized in the main branch"): + self._basic_cond(lambda: variable_operators.Undefined('x'), lambda: 1) with self.assertRaisesRegex( - ValueError, "must also be initialized in the else.*'s'"): + ValueError, "'x' must also be initialized in the else branch"): self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s')) def test_tensor_dtype_change(self): - with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'): + with self.assertRaisesRegex( + TypeError, "'x' has dtype int32.*but.*float32"): self._basic_cond(lambda: 1, lambda: 1.0) diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index ca68bc9911c..0e19da87451 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -70,6 +70,9 @@ class Scope(object): globals: Set[qual_names.QN], names that are explicitly marked as global in this scope. Note that this doesn't include free read-only vars bound to global symbols. + nonlocals: Set[qual_names.QN], names that are explicitly marked as nonlocal + in this scope. Note that this doesn't include free read-only vars bound to + global symbols. free_vars: Set[qual_names.QN], the free variables in this scope. See https://docs.python.org/3/reference/executionmodel.html for a precise definition. @@ -111,6 +114,7 @@ class Scope(object): self.bound = set() self.globals = set() + self.nonlocals = set() self.annotations = set() self.params = weakref.WeakValueDictionary() @@ -186,6 +190,7 @@ class Scope(object): self.parent.modified.update(self.modified - self.isolated_names) self.parent.bound.update(self.bound - self.isolated_names) self.parent.globals.update(self.globals) + self.parent.nonlocals.update(self.nonlocals) self.parent.annotations.update(self.annotations) else: # TODO(mdan): This is not accurate. @@ -363,6 +368,7 @@ class ActivityAnalyzer(transformer.Base): qn = qual_names.QN(name) self.scope.read.add(qn) self.scope.bound.add(qn) + self.scope.nonlocals.add(qn) self._exit_and_record_scope(node) return node diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index 64b00fcbeba..ac91b662a47 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -404,6 +404,46 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase): self.assertHasDefinedIn(fn_body[1], ('a',)) + def test_definitions_in_except_block(self): + + def test_fn(): + try: + pass + except ValueError: + a = None + if a: # pylint:disable=using-constant-test + a = None + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body + + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a',)) + + def test_definitions_in_except_block_of_raising_try(self): + + def test_fn(): + try: + raise ValueError() + except ValueError: + a = None + if a: # pylint:disable=using-constant-test + a = None + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body + + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a',)) + def test_global(self): def test_fn():