Use the nonlocal mechanism for if statements. This is the same mechanism used by for and while loops and it allows reusing much of the code.
This required the ternary if operator to be split in a separate implementation, but that better accounts for its different nature. This should also allow more consistent verification and error messages throughout. PiperOrigin-RevId: 312360755 Change-Id: I57989c6cd40a16653521e18ccf21f2b0e994bd96
This commit is contained in:
parent
baa3e80ca5
commit
53215ab702
@ -118,7 +118,13 @@ py_test(
|
||||
name = "control_flow_test",
|
||||
srcs = ["control_flow_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
srcs_version = "PY3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"nopip",
|
||||
],
|
||||
deps = [
|
||||
":converters",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -18,7 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
|
||||
|
||||
@ -26,19 +29,20 @@ class ConditionalExpressionTransformer(converter.Base):
|
||||
"""Converts conditional expressions to functional form."""
|
||||
|
||||
def visit_IfExp(self, node):
|
||||
return templates.replace_as_expression(
|
||||
'''ag__.if_stmt(
|
||||
template = '''
|
||||
ag__.if_exp(
|
||||
test,
|
||||
lambda: true_expr,
|
||||
lambda: false_expr,
|
||||
lambda: (),
|
||||
lambda _: None,
|
||||
('<internal expr>',),
|
||||
())
|
||||
''',
|
||||
expr_repr)
|
||||
'''
|
||||
expr_repr = parser.unparse(node.test, include_encoding_marker=False).strip()
|
||||
return templates.replace_as_expression(
|
||||
template,
|
||||
test=node.test,
|
||||
true_expr=node.body,
|
||||
false_expr=node.orelse)
|
||||
false_expr=node.orelse,
|
||||
expr_repr=gast.Constant(expr_repr, kind=None))
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
|
@ -23,7 +23,6 @@ import gast
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
@ -57,114 +56,16 @@ class ControlFlowTransformer(converter.Base):
|
||||
fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
return self.generic_visit(node)
|
||||
|
||||
def _create_cond_branch(self, body_name, aliased_orig_names,
|
||||
aliased_new_names, body, returns):
|
||||
if len(returns) == 1:
|
||||
template = """
|
||||
return retval
|
||||
"""
|
||||
return_stmt = templates.replace(template, retval=returns[0])
|
||||
else:
|
||||
template = """
|
||||
return (retvals,)
|
||||
"""
|
||||
return_stmt = templates.replace(template, retvals=returns)
|
||||
|
||||
if aliased_orig_names:
|
||||
alias_declarations = []
|
||||
for new_name, old_name in zip(aliased_new_names, aliased_orig_names):
|
||||
template = """
|
||||
try:
|
||||
aliased_new_name = aliased_orig_name
|
||||
except NameError:
|
||||
aliased_new_name = ag__.Undefined(symbol_name)
|
||||
"""
|
||||
|
||||
alias_declarations.extend(
|
||||
templates.replace(
|
||||
template,
|
||||
aliased_new_name=new_name,
|
||||
aliased_orig_name=old_name,
|
||||
symbol_name=gast.Constant(str(old_name), kind=None)))
|
||||
|
||||
template = """
|
||||
def body_name():
|
||||
alias_declarations
|
||||
body
|
||||
return_stmt
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
alias_declarations=alias_declarations,
|
||||
body_name=body_name,
|
||||
body=body,
|
||||
return_stmt=return_stmt)
|
||||
else:
|
||||
template = """
|
||||
def body_name():
|
||||
body
|
||||
return_stmt
|
||||
"""
|
||||
return templates.replace(
|
||||
template, body_name=body_name, body=body, return_stmt=return_stmt)
|
||||
|
||||
def _create_cond_expr(self, results, test, body_name, orelse_name,
|
||||
state_getter_name, state_setter_name,
|
||||
basic_symbol_names, composite_symbol_names):
|
||||
if results is not None:
|
||||
template = """
|
||||
results = ag__.if_stmt(test, body_name, orelse_name,
|
||||
state_getter_name, state_setter_name,
|
||||
(basic_symbol_names,),
|
||||
(composite_symbol_names,))
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
test=test,
|
||||
results=results,
|
||||
body_name=body_name,
|
||||
orelse_name=orelse_name,
|
||||
state_getter_name=state_getter_name,
|
||||
state_setter_name=state_setter_name,
|
||||
basic_symbol_names=basic_symbol_names,
|
||||
composite_symbol_names=composite_symbol_names)
|
||||
else:
|
||||
template = """
|
||||
ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
|
||||
(basic_symbol_names,), (composite_symbol_names,))
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
test=test,
|
||||
body_name=body_name,
|
||||
orelse_name=orelse_name,
|
||||
getter_name=state_getter_name,
|
||||
setter_name=state_setter_name,
|
||||
basic_symbol_names=basic_symbol_names,
|
||||
composite_symbol_names=composite_symbol_names)
|
||||
|
||||
def _fmt_symbols(self, symbol_set):
|
||||
if not symbol_set:
|
||||
return 'no variables'
|
||||
return ', '.join(map(str, symbol_set))
|
||||
|
||||
def _determine_aliased_symbols(self, scope, node_defined_in):
|
||||
modified_live = scope.modified & node_defined_in
|
||||
# Composite symbols are handled elsewhere, see _create_state_functions
|
||||
return {
|
||||
s for s in modified_live
|
||||
if not s.is_composite() and s not in self.state[_Function].scope.globals
|
||||
}
|
||||
|
||||
def _create_nonlocal_declarations(self, loop_vars):
|
||||
def _create_nonlocal_declarations(self, vars_):
|
||||
vars_ = set(vars_)
|
||||
results = []
|
||||
global_vars = self.state[_Function].scope.globals
|
||||
|
||||
if global_vars:
|
||||
results.append(gast.Global([str(v) for v in global_vars]))
|
||||
results.append(gast.Global([str(v) for v in vars_]))
|
||||
|
||||
nonlocal_vars = [
|
||||
v for v in loop_vars if not v.is_composite() and v not in global_vars]
|
||||
v for v in vars_ if not v.is_composite() and v not in global_vars]
|
||||
if nonlocal_vars:
|
||||
results.append(gast.Nonlocal([str(v) for v in nonlocal_vars]))
|
||||
|
||||
@ -176,9 +77,9 @@ class ControlFlowTransformer(converter.Base):
|
||||
template = """
|
||||
def getter_name():
|
||||
return state_vars,
|
||||
def setter_name(loop_vars):
|
||||
def setter_name(vars_):
|
||||
nonlocal_declarations
|
||||
state_vars, = loop_vars
|
||||
state_vars, = vars_
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
@ -222,166 +123,34 @@ class ControlFlowTransformer(converter.Base):
|
||||
symbol_name=gast.Constant(s.ssf(), kind=None))
|
||||
return assignments
|
||||
|
||||
def visit_If(self, node):
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
|
||||
defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
|
||||
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
|
||||
|
||||
# Note: this information needs to be extracted before the body conversion
|
||||
# that happens in the call to generic_visit below, because the conversion
|
||||
# generates nodes that lack static analysis annotations.
|
||||
need_alias_in_body = self._determine_aliased_symbols(
|
||||
body_scope, defined_in)
|
||||
need_alias_in_orelse = self._determine_aliased_symbols(
|
||||
orelse_scope, defined_in)
|
||||
|
||||
node = self.generic_visit(node)
|
||||
|
||||
modified_in_cond = body_scope.modified | orelse_scope.modified
|
||||
returned_from_cond = set()
|
||||
composites = set()
|
||||
for s in modified_in_cond:
|
||||
if s in live_out and not s.is_composite():
|
||||
returned_from_cond.add(s)
|
||||
if s.is_composite():
|
||||
# Special treatment for compound objects, always return them.
|
||||
# This allows special handling within the if_stmt itself.
|
||||
# For example, in TensorFlow we need to restore the state of composite
|
||||
# symbols to ensure that only effects from the executed branch are seen.
|
||||
composites.add(s)
|
||||
|
||||
created_in_body = body_scope.modified & returned_from_cond - defined_in
|
||||
created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in
|
||||
|
||||
basic_created_in_body = tuple(
|
||||
s for s in created_in_body if not s.is_composite())
|
||||
basic_created_in_orelse = tuple(
|
||||
s for s in created_in_orelse if not s.is_composite())
|
||||
|
||||
# These variables are defined only in a single branch. This is fine in
|
||||
# Python so we pass them through. Another backend, e.g. Tensorflow, may need
|
||||
# to handle these cases specially or throw an Error.
|
||||
possibly_undefined = (set(basic_created_in_body) ^
|
||||
set(basic_created_in_orelse))
|
||||
|
||||
# Alias the closure variables inside the conditional functions, to allow
|
||||
# the functions access to the respective variables.
|
||||
# We will alias variables independently for body and orelse scope,
|
||||
# because different branches might write different variables.
|
||||
aliased_body_orig_names = tuple(need_alias_in_body)
|
||||
aliased_orelse_orig_names = tuple(need_alias_in_orelse)
|
||||
aliased_body_new_names = tuple(
|
||||
self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
|
||||
for s in aliased_body_orig_names)
|
||||
aliased_orelse_new_names = tuple(
|
||||
self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
|
||||
for s in aliased_orelse_orig_names)
|
||||
|
||||
alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
|
||||
alias_orelse_map = dict(
|
||||
zip(aliased_orelse_orig_names, aliased_orelse_new_names))
|
||||
|
||||
node_body = ast_util.rename_symbols(node.body, alias_body_map)
|
||||
node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
|
||||
|
||||
cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
|
||||
body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
|
||||
orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
|
||||
all_referenced = body_scope.referenced | orelse_scope.referenced
|
||||
state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
|
||||
state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)
|
||||
|
||||
returned_from_cond = tuple(returned_from_cond)
|
||||
composites = tuple(composites)
|
||||
|
||||
if returned_from_cond:
|
||||
if len(returned_from_cond) == 1:
|
||||
cond_results = returned_from_cond[0]
|
||||
else:
|
||||
cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
|
||||
|
||||
returned_from_body = tuple(
|
||||
alias_body_map[s] if s in need_alias_in_body else s
|
||||
for s in returned_from_cond)
|
||||
returned_from_orelse = tuple(
|
||||
alias_orelse_map[s] if s in need_alias_in_orelse else s
|
||||
for s in returned_from_cond)
|
||||
|
||||
else:
|
||||
# When the cond would return no value, we leave the cond called without
|
||||
# results. That in turn should trigger the side effect guards. The
|
||||
# branch functions will return a dummy value that ensures cond
|
||||
# actually has some return value as well.
|
||||
cond_results = None
|
||||
# TODO(mdan): Replace with None once side_effect_guards is retired.
|
||||
returned_from_body = (templates.replace_as_expression(
|
||||
'ag__.match_staging_level(1, cond_var_name)',
|
||||
cond_var_name=cond_var_name),)
|
||||
returned_from_orelse = (templates.replace_as_expression(
|
||||
'ag__.match_staging_level(1, cond_var_name)',
|
||||
cond_var_name=cond_var_name),)
|
||||
|
||||
cond_assign = self.create_assignment(cond_var_name, node.test)
|
||||
body_def = self._create_cond_branch(
|
||||
body_name,
|
||||
aliased_orig_names=aliased_body_orig_names,
|
||||
aliased_new_names=aliased_body_new_names,
|
||||
body=node_body,
|
||||
returns=returned_from_body)
|
||||
orelse_def = self._create_cond_branch(
|
||||
orelse_name,
|
||||
aliased_orig_names=aliased_orelse_orig_names,
|
||||
aliased_new_names=aliased_orelse_new_names,
|
||||
body=node_orelse,
|
||||
returns=returned_from_orelse)
|
||||
undefined_assigns = self._create_undefined_assigns(possibly_undefined)
|
||||
composite_defs = self._create_state_functions(
|
||||
composites, [], state_getter_name, state_setter_name)
|
||||
|
||||
basic_symbol_names = tuple(
|
||||
gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
|
||||
composite_symbol_names = tuple(
|
||||
gast.Constant(str(symbol), kind=None) for symbol in composites)
|
||||
|
||||
cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
|
||||
orelse_name, state_getter_name,
|
||||
state_setter_name, basic_symbol_names,
|
||||
composite_symbol_names)
|
||||
|
||||
if_ast = (
|
||||
undefined_assigns + composite_defs + body_def + orelse_def +
|
||||
cond_assign + cond_expr)
|
||||
return if_ast
|
||||
|
||||
def _get_basic_loop_vars(self, modified, live_in, live_out):
|
||||
# The loop variables corresponding to simple symbols (e.g. `x`).
|
||||
basic_loop_vars = []
|
||||
def _get_block_basic_vars(self, modified, live_in, live_out):
|
||||
nonlocals = self.state[_Function].scope.nonlocals
|
||||
basic_scope_vars = []
|
||||
for s in modified:
|
||||
if s.is_composite():
|
||||
# TODO(mdan): Raise an error when this happens for a TF loop.
|
||||
# TODO(mdan): Raise an error when this happens for a TF scope.
|
||||
continue
|
||||
# Variables not live into or out of the loop are considered local to the
|
||||
# loop.
|
||||
if s not in live_in and s not in live_out:
|
||||
continue
|
||||
basic_loop_vars.append(s)
|
||||
return frozenset(basic_loop_vars)
|
||||
# Variables not live into or out of the scope are considered local to the
|
||||
# scope.
|
||||
if s in live_in or s in live_out or s in nonlocals:
|
||||
basic_scope_vars.append(s)
|
||||
continue
|
||||
return frozenset(basic_scope_vars)
|
||||
|
||||
def _get_composite_loop_vars(self, modified, live_in):
|
||||
# The loop variables corresponding to composite symbols (e.g. `self.x`).
|
||||
composite_loop_vars = []
|
||||
def _get_block_composite_vars(self, modified, live_in):
|
||||
# The scope variables corresponding to composite symbols (e.g. `self.x`).
|
||||
composite_scope_vars = []
|
||||
for s in modified:
|
||||
if not s.is_composite():
|
||||
continue
|
||||
# Mutations made to objects created inside the loop will appear as writes
|
||||
# Mutations made to objects created inside the scope will appear as writes
|
||||
# to composite symbols. Because these mutations appear as modifications
|
||||
# made to composite symbols, we check whether the composite's parent is
|
||||
# actually live into the loop.
|
||||
# actually live into the scope.
|
||||
# Example:
|
||||
# while cond:
|
||||
# x = Foo()
|
||||
# x.foo = 2 * x.foo # x.foo is live into the loop, but x is not.
|
||||
# x.foo = 2 * x.foo # x.foo is live into the scope, but x is not.
|
||||
#
|
||||
# Note that some parents might not be symbols - for example, in x['foo'],
|
||||
# 'foo' is a parent, but it's a literal, not a symbol. We don't check the
|
||||
@ -390,40 +159,106 @@ class ControlFlowTransformer(converter.Base):
|
||||
sss for sss in s.support_set if sss.is_symbol())
|
||||
if not all(sss in live_in for sss in support_set_symbols):
|
||||
continue
|
||||
composite_loop_vars.append(s)
|
||||
return frozenset(composite_loop_vars)
|
||||
composite_scope_vars.append(s)
|
||||
return frozenset(composite_scope_vars)
|
||||
|
||||
def _get_loop_vars(self, node, modified):
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
def _get_block_vars(self, node, modified):
|
||||
"""Determines the variables affected inside a control flow statement."""
|
||||
defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
|
||||
live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
|
||||
live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
|
||||
reserved_symbols = body_scope.referenced
|
||||
|
||||
basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out)
|
||||
composite_loop_vars = self._get_composite_loop_vars(modified, live_in)
|
||||
loop_vars = tuple(basic_loop_vars | composite_loop_vars)
|
||||
basic_scope_vars = self._get_block_basic_vars(
|
||||
modified,
|
||||
live_in,
|
||||
live_out)
|
||||
composite_scope_vars = self._get_block_composite_vars(modified, live_in)
|
||||
scope_vars = tuple(basic_scope_vars | composite_scope_vars)
|
||||
|
||||
# Variable that are used or defined inside the loop, but not defined
|
||||
# before entering the loop. Only simple variables must be defined. The
|
||||
# Variables that are modified inside the scope, but not defined
|
||||
# before entering it. Only simple variables must be defined. The
|
||||
# composite ones will be implicitly checked at runtime.
|
||||
undefined_lives = basic_loop_vars - defined_in
|
||||
# This covers loop variables as well as variables that
|
||||
undefined = tuple(v for v in modified - defined_in if not v.is_composite())
|
||||
|
||||
return loop_vars, reserved_symbols, undefined_lives
|
||||
# Variables that are modified inside the scope, and depend on values outside
|
||||
# it.
|
||||
input_only = basic_scope_vars & live_in - live_out
|
||||
|
||||
# Place the outputs first.
|
||||
scope_vars = sorted(scope_vars, key=lambda v: v in input_only)
|
||||
nouts = len(scope_vars) - len(input_only)
|
||||
|
||||
return scope_vars, undefined, nouts
|
||||
|
||||
def visit_If(self, node):
|
||||
node = self.generic_visit(node)
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
|
||||
|
||||
cond_vars, undefined, nouts = self._get_block_vars(
|
||||
node, body_scope.modified | orelse_scope.modified)
|
||||
|
||||
undefined_assigns = self._create_undefined_assigns(undefined)
|
||||
|
||||
nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)
|
||||
|
||||
reserved = body_scope.referenced | orelse_scope.referenced
|
||||
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
|
||||
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
|
||||
state_functions = self._create_state_functions(
|
||||
cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||||
|
||||
orelse_body = node.orelse
|
||||
if not orelse_body:
|
||||
orelse_body = [gast.Pass()]
|
||||
|
||||
template = """
|
||||
state_functions
|
||||
def body_name():
|
||||
nonlocal_declarations
|
||||
body
|
||||
def orelse_name():
|
||||
nonlocal_declarations
|
||||
orelse
|
||||
undefined_assigns
|
||||
ag__.if_stmt(
|
||||
test,
|
||||
body_name,
|
||||
orelse_name,
|
||||
state_getter_name,
|
||||
state_setter_name,
|
||||
(symbol_names,),
|
||||
nouts)
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
body_name=self.ctx.namer.new_symbol('if_body', reserved),
|
||||
orelse=orelse_body,
|
||||
orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
|
||||
nonlocal_declarations=nonlocal_declarations,
|
||||
nouts=gast.Constant(nouts, kind=None),
|
||||
state_functions=state_functions,
|
||||
state_getter_name=state_getter_name,
|
||||
state_setter_name=state_setter_name,
|
||||
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
|
||||
test=node.test,
|
||||
undefined_assigns=undefined_assigns)
|
||||
|
||||
def visit_While(self, node):
|
||||
node = self.generic_visit(node)
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
|
||||
loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
|
||||
node, body_scope.modified)
|
||||
loop_vars, undefined, _ = self._get_block_vars(node, body_scope.modified)
|
||||
|
||||
undefined_assigns = self._create_undefined_assigns(possibly_undefs)
|
||||
undefined_assigns = self._create_undefined_assigns(undefined)
|
||||
|
||||
nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
|
||||
|
||||
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
|
||||
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
|
||||
reserved = body_scope.referenced
|
||||
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
|
||||
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
|
||||
state_functions = self._create_state_functions(
|
||||
loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||||
|
||||
@ -448,7 +283,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
return templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
|
||||
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
|
||||
nonlocal_declarations=nonlocal_declarations,
|
||||
opts=opts,
|
||||
state_functions=state_functions,
|
||||
@ -456,7 +291,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
state_setter_name=state_setter_name,
|
||||
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
|
||||
test=node.test,
|
||||
test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
|
||||
test_name=self.ctx.namer.new_symbol('loop_test', reserved),
|
||||
undefined_assigns=undefined_assigns)
|
||||
|
||||
def visit_For(self, node):
|
||||
@ -464,15 +299,16 @@ class ControlFlowTransformer(converter.Base):
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)
|
||||
|
||||
loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
|
||||
loop_vars, undefined, _ = self._get_block_vars(
|
||||
node, body_scope.modified | iter_scope.modified)
|
||||
|
||||
undefined_assigns = self._create_undefined_assigns(possibly_undefs)
|
||||
undefined_assigns = self._create_undefined_assigns(undefined)
|
||||
|
||||
nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
|
||||
|
||||
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
|
||||
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
|
||||
reserved = body_scope.referenced | iter_scope.referenced
|
||||
state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
|
||||
state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
|
||||
state_functions = self._create_state_functions(
|
||||
loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||||
|
||||
@ -484,7 +320,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||||
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
|
||||
extra_test_name = self.ctx.namer.new_symbol(
|
||||
'extra_test', reserved_symbols)
|
||||
'extra_test', reserved)
|
||||
template = """
|
||||
def extra_test_name():
|
||||
nonlocal_declarations
|
||||
@ -502,7 +338,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
|
||||
# iterate_arg_name holds a single arg with the iterates, which may be a
|
||||
# tuple.
|
||||
iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols)
|
||||
iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
|
||||
template = """
|
||||
iterates = iterate_arg_name
|
||||
"""
|
||||
@ -529,7 +365,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
return templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
|
||||
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
|
||||
extra_test_function=extra_test_function,
|
||||
extra_test_name=extra_test_name,
|
||||
iterate_arg_name=iterate_arg_name,
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -453,6 +454,17 @@ class IfStatementTest(ControlFlowTestBase):
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
||||
|
||||
def test_local_remains_local(self):
|
||||
|
||||
def test_fn(n):
|
||||
if n > 0:
|
||||
b = 4
|
||||
n = b + 1
|
||||
return n
|
||||
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
||||
|
||||
def test_no_outputs(self):
|
||||
|
||||
def test_fn(n):
|
||||
@ -465,6 +477,85 @@ class IfStatementTest(ControlFlowTestBase):
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
||||
|
||||
def test_created_outputs(self):
|
||||
|
||||
def test_fn(i):
|
||||
if i == 0:
|
||||
result = i - 1
|
||||
else:
|
||||
result = i + 1
|
||||
return result
|
||||
|
||||
self.assertTransformedResult(test_fn, 0, -1)
|
||||
self.assertTransformedResult(test_fn, 1, 2)
|
||||
|
||||
def test_created_loop_local_outputs(self):
|
||||
|
||||
def test_fn(n, x):
|
||||
for i in n:
|
||||
if i == 0:
|
||||
result = i - 1
|
||||
else:
|
||||
result = i + 1
|
||||
if result > 0:
|
||||
x += 1
|
||||
return x
|
||||
|
||||
self.assertTransformedResult(test_fn, (range(5), 10), 14)
|
||||
|
||||
def test_created_loop_variable(self):
|
||||
|
||||
def test_fn(n, x):
|
||||
for i in n:
|
||||
if i == 0:
|
||||
result = i - 1
|
||||
if i > 0: # Using the result from previous iteration.
|
||||
if result < 0:
|
||||
x += 1
|
||||
return x
|
||||
|
||||
self.assertTransformedResult(test_fn, (range(5), 10), 14)
|
||||
|
||||
def test_unaffected_global(self):
|
||||
|
||||
def test_fn(i):
|
||||
global g # pylint:disable=global-variable-undefined
|
||||
if i == 0:
|
||||
g = i - 1
|
||||
return g
|
||||
|
||||
self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3})
|
||||
self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3})
|
||||
|
||||
def test_unaffected_nonlocal(self):
|
||||
|
||||
def test_fn(i):
|
||||
def inner_fn():
|
||||
nonlocal n
|
||||
if i == 0:
|
||||
n = i - 1
|
||||
|
||||
n = 3
|
||||
inner_fn()
|
||||
return n
|
||||
|
||||
self.assertTransformedResult(test_fn, 1, 3)
|
||||
self.assertTransformedResult(test_fn, 0, -1)
|
||||
|
||||
def test_output_defined_in_prior_except(self):
|
||||
|
||||
def test_fn(i):
|
||||
try:
|
||||
raise ValueError()
|
||||
except ValueError:
|
||||
x = 1
|
||||
if i == 0:
|
||||
x = i - 1
|
||||
return x
|
||||
|
||||
self.assertTransformedResult(test_fn, 1, 1)
|
||||
self.assertTransformedResult(test_fn, 0, -1)
|
||||
|
||||
def test_unbalanced_multiple_composites(self):
|
||||
|
||||
class Foo(object):
|
||||
|
@ -22,6 +22,7 @@ py_library(
|
||||
name = "operators",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"conditional_expressions.py",
|
||||
"control_flow.py",
|
||||
"control_flow_deprecated_py2.py",
|
||||
"data_structures.py",
|
||||
@ -62,6 +63,20 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "conditional_expressions_test",
|
||||
srcs = ["conditional_expressions_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
],
|
||||
deps = [
|
||||
":operators",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "control_flow_test",
|
||||
srcs = ["control_flow_test.py"],
|
||||
|
@ -37,6 +37,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators.conditional_expressions import if_exp
|
||||
from tensorflow.python.autograph.operators.control_flow import for_stmt
|
||||
from tensorflow.python.autograph.operators.control_flow import if_stmt
|
||||
from tensorflow.python.autograph.operators.control_flow import while_stmt
|
||||
|
@ -0,0 +1,56 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Conditional expressions (e.g. the ternary if statement)."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.python.autograph.operators import control_flow
|
||||
from tensorflow.python.autograph.utils import tensors
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
|
||||
|
||||
def if_exp(cond, if_true, if_false, expr_repr):
|
||||
if tensors.is_dense_tensor(cond):
|
||||
return _tf_if_exp(cond, if_true, if_false, expr_repr)
|
||||
else:
|
||||
return _py_if_exp(cond, if_true, if_false)
|
||||
|
||||
|
||||
def _tf_if_exp(cond, if_true, if_false, expr_repr):
|
||||
"""Overload of if_exp that stages a TF cond."""
|
||||
# TODO(mdan): Use nonlocal once we no longer need to support py2.
|
||||
true_val = []
|
||||
false_val = []
|
||||
|
||||
def true_fn():
|
||||
true_val.append(if_true())
|
||||
if true_val and false_val:
|
||||
control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0])
|
||||
return true_val[0]
|
||||
|
||||
def false_fn():
|
||||
false_val.append(if_false())
|
||||
if true_val and false_val:
|
||||
control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0])
|
||||
return false_val[0]
|
||||
|
||||
return control_flow_ops.cond(cond, true_fn, false_fn)
|
||||
|
||||
|
||||
def _py_if_exp(cond, if_true, if_false):
|
||||
return if_true() if cond else if_false()
|
@ -0,0 +1,66 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for conditional_expressions module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.operators import conditional_expressions
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _basic_expr(cond):
|
||||
return conditional_expressions.if_exp(
|
||||
cond,
|
||||
lambda: constant_op.constant(1),
|
||||
lambda: constant_op.constant(2),
|
||||
'cond')
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class IfExpTest(test.TestCase):
|
||||
|
||||
def test_tensor(self):
|
||||
self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(True))), 1)
|
||||
self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(False))), 2)
|
||||
|
||||
def test_tensor_mismatched_type(self):
|
||||
# tf.function required because eager cond degenerates to Python if.
|
||||
@def_function.function
|
||||
def test_fn():
|
||||
conditional_expressions.if_exp(
|
||||
constant_op.constant(True), lambda: 1.0, lambda: 2, 'expr_repr')
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
"'expr_repr' has dtype float32 in the main.*int32 in the else"):
|
||||
test_fn()
|
||||
|
||||
def test_python(self):
|
||||
self.assertEqual(self.evaluate(_basic_expr(True)), 1)
|
||||
self.assertEqual(self.evaluate(_basic_expr(False)), 2)
|
||||
self.assertEqual(
|
||||
conditional_expressions.if_exp(True, lambda: 1, lambda: 2, ''), 1)
|
||||
self.assertEqual(
|
||||
conditional_expressions.if_exp(False, lambda: 1, lambda: 2, ''), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -102,7 +102,7 @@ def _verify_loop_init_vars(values, symbol_names):
|
||||
"""Ensures that all values in the state are defined when entering a loop."""
|
||||
for name, value in zip(symbol_names, values):
|
||||
if value is None:
|
||||
raise ValueError('"{}" may not be None before the loop.'.format(name))
|
||||
raise ValueError("'{}' may not be None before the loop.".format(name))
|
||||
if isinstance(value, variables.UndefinedReturnValue):
|
||||
# Assumption: the loop will only capture the variable which tracks the
|
||||
# return value if the loop contained a return statement.
|
||||
@ -110,7 +110,7 @@ def _verify_loop_init_vars(values, symbol_names):
|
||||
raise ValueError(
|
||||
'return statements are not supported within a TensorFlow loop.')
|
||||
if isinstance(value, variables.Undefined):
|
||||
raise ValueError('"{}" must be defined before the loop.'.format(name))
|
||||
raise ValueError("'{}' must be defined before the loop.".format(name))
|
||||
|
||||
|
||||
def _is_subshape(left, right):
|
||||
@ -133,9 +133,9 @@ def _is_subshape(left, right):
|
||||
def _verify_single_loop_var(
|
||||
name, check_shape, init, entry, exit_, shape_invariant):
|
||||
"""Verifies whether the initial, entry and exit values are consistent."""
|
||||
assert entry is not None, 'no TF op should set "{}" to None?'.format(name)
|
||||
assert entry is not None, "no TF op should set '{}' to None?".format(name)
|
||||
if exit_ is None:
|
||||
raise ValueError('"{}" is None at the end of the iteration.'.format(name))
|
||||
raise ValueError("'{}' is None at the end of the iteration.".format(name))
|
||||
|
||||
if isinstance(init, (bool, int, float, str, np.ndarray)):
|
||||
init = ops.convert_to_tensor_v2(init)
|
||||
@ -158,9 +158,8 @@ def _verify_single_loop_var(
|
||||
|
||||
if entry.dtype != exit_.dtype:
|
||||
raise TypeError(
|
||||
'"{}" has dtype {} before the loop, but dtype {} after one'
|
||||
' iteration. TensorFlow control flow requires it stays the'
|
||||
' same.'.format(
|
||||
"'{}' has dtype {} before the loop, but dtype {} after one"
|
||||
' iteration'.format(
|
||||
name,
|
||||
entry.dtype.name,
|
||||
exit_.dtype.name,
|
||||
@ -171,19 +170,19 @@ def _verify_single_loop_var(
|
||||
entry_shape = entry.shape
|
||||
if not _is_subshape(exit_shape, entry_shape):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} before the loop, but shape {} after one'
|
||||
"'{}' has shape {} before the loop, but shape {} after one"
|
||||
' iteration. Use tf.autograph.experimental.set_loop_options to set'
|
||||
' shape invariants.'.format(name, entry_shape, exit_shape))
|
||||
else:
|
||||
init_shape = init.shape
|
||||
if not _is_subshape(init_shape, shape_invariant):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} before the loop, which does not conform with'
|
||||
"'{}' has shape {} before the loop, which does not conform with"
|
||||
' the shape invariant {}.'.format(name, init_shape,
|
||||
shape_invariant))
|
||||
if not _is_subshape(exit_shape, shape_invariant):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} after one iteration, which does not conform with'
|
||||
"'{}' has shape {} after one iteration, which does not conform with"
|
||||
' the shape invariant {}.'.format(
|
||||
name, exit_shape, shape_invariant))
|
||||
|
||||
@ -216,13 +215,13 @@ def _verify_tf_loop_vars(init_vars,
|
||||
nest.assert_same_structure(init, entry, expand_composites=True)
|
||||
nest.assert_same_structure(entry, exit_, expand_composites=True)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError('"{}" does not have the same nested structure after one'
|
||||
raise TypeError("'{}' does not have the same nested structure after one"
|
||||
' iteration.\n\n{}'.format(name, e))
|
||||
if invariant is not None:
|
||||
try:
|
||||
nest.assert_same_structure(init, invariant, expand_composites=False)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError('"{}" does not have the same nested structure as its'
|
||||
raise TypeError("'{}' does not have the same nested structure as its"
|
||||
' corresponding shape invariant.\n\n{}'.format(name, e))
|
||||
|
||||
nest.map_structure(
|
||||
@ -230,13 +229,13 @@ def _verify_tf_loop_vars(init_vars,
|
||||
entry, exit_, invariant)
|
||||
|
||||
|
||||
def _verify_single_cond_var(name, body_var, orelse_var):
|
||||
def verify_single_cond_var(name, body_var, orelse_var):
|
||||
"""Verifies whether body_var and orelse_var are consistent."""
|
||||
if body_var is None:
|
||||
raise ValueError('"{}" is None at the end of the TRUE branch.'.format(name))
|
||||
raise ValueError("'{}' is None at the end of the main branch.".format(name))
|
||||
if orelse_var is None:
|
||||
raise ValueError(
|
||||
'"{}" is None at the end of the FALSE branch.'.format(name))
|
||||
"'{}' is None at the end of the else branch.".format(name))
|
||||
|
||||
if isinstance(body_var, (bool, int, float, str, np.ndarray)):
|
||||
body_var = ops.convert_to_tensor_v2(body_var)
|
||||
@ -255,41 +254,37 @@ def _verify_single_cond_var(name, body_var, orelse_var):
|
||||
|
||||
if body_var.dtype != orelse_var.dtype:
|
||||
raise TypeError(
|
||||
'"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
|
||||
' branch. TensorFlow control flow requires that they are the'
|
||||
' same.'.format(name, body_var.dtype.name,
|
||||
orelse_var.dtype.name))
|
||||
"'{}' has dtype {} in the main branch, but dtype {} in the else"
|
||||
' branch'.format(name, body_var.dtype.name,
|
||||
orelse_var.dtype.name))
|
||||
|
||||
|
||||
def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name):
|
||||
"""Verifies variables output by a conditional branch for consistency."""
|
||||
for name, var_ in zip(symbol_names, vars_):
|
||||
if isinstance(var_, variables.Undefined):
|
||||
raise ValueError(
|
||||
"'{}' must also be initialized in the {} branch".format(
|
||||
name, branch_name))
|
||||
if isinstance(var_, variables.UndefinedReturnValue):
|
||||
raise ValueError(
|
||||
'the {} branch must also have a return statement.'.format(
|
||||
branch_name))
|
||||
|
||||
|
||||
def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
|
||||
"""Verifies variables manipulated by a conditional for consistency."""
|
||||
basic_body_vars, composite_body_vars = body_vars
|
||||
basic_orelse_vars, composite_orelse_vars = orelse_vars
|
||||
assert isinstance(composite_body_vars, tuple)
|
||||
assert isinstance(composite_orelse_vars, tuple)
|
||||
|
||||
# TODO(kkb): Make this more consistent.
|
||||
# The basic outputs should always be a tuple.
|
||||
if not isinstance(basic_body_vars, tuple):
|
||||
basic_body_vars = (basic_body_vars,)
|
||||
if not isinstance(basic_orelse_vars, tuple):
|
||||
basic_orelse_vars = (basic_orelse_vars,)
|
||||
|
||||
body_vars = basic_body_vars + composite_body_vars
|
||||
orelse_vars = basic_orelse_vars + composite_orelse_vars
|
||||
|
||||
named_vars = zip(symbol_names, body_vars, orelse_vars)
|
||||
|
||||
for name, body_var, orelse_var in named_vars:
|
||||
try:
|
||||
nest.assert_same_structure(
|
||||
body_var, orelse_var, expand_composites=True)
|
||||
nest.assert_same_structure(body_var, orelse_var, expand_composites=True)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(
|
||||
'"{}" does not have the same nested structure in the TRUE and FALSE'
|
||||
' branches.\n\n{}'.format(name, str(e)))
|
||||
|
||||
"'{}' must have the same nested structure in the main and else"
|
||||
' branches:\n\n{}'.format(name, str(e)))
|
||||
nest.map_structure(
|
||||
functools.partial(_verify_single_cond_var, name), body_var, orelse_var)
|
||||
functools.partial(verify_single_cond_var, name), body_var, orelse_var)
|
||||
|
||||
|
||||
def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
|
||||
@ -314,12 +309,16 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
|
||||
`extra_test`, `body`, `get_state` and `set_state` functions must bind to the
|
||||
original `geo_mean` and `arith_mean` symbols, using `nonlocal`.
|
||||
|
||||
The inputs and outputs of the callables representing the loop blocks are not
|
||||
explicit - instead, these functions must use nonlocal/global for side effects.
|
||||
The inputs and outputs are instead controlled by the set_state/get_state
|
||||
functions.
|
||||
|
||||
Args:
|
||||
iter_: The entity being iterated over.
|
||||
extra_test: Callable with the state as arguments, and boolean return type.
|
||||
extra_test: Callable with boolean return type.
|
||||
An additional loop condition.
|
||||
body: Callable with the iterate and the state as arguments, and state as
|
||||
return type. The actual loop body.
|
||||
body: Callable representing the actual loop body.
|
||||
get_state: Additional callable which can capture additional state (such as
|
||||
the values of composite symbols). This is only useful when staging the
|
||||
loop.
|
||||
@ -717,11 +716,14 @@ def while_stmt(test, body, get_state, set_state, symbol_names, opts):
|
||||
a tuple of entities that represent an actual state, or a list of arguments
|
||||
of the corresponding types.
|
||||
|
||||
The inputs and outputs of the callables representing the loop blocks are not
|
||||
explicit - instead, these functions must use nonlocal/global for side effects.
|
||||
The inputs and outputs are instead controlled by the set_state/get_state
|
||||
functions.
|
||||
|
||||
Args:
|
||||
test: Callable with the state as arguments, and boolean return type. The
|
||||
loop condition.
|
||||
body: Callable with the state as arguments, and state as return type. The
|
||||
actual loop body.
|
||||
test: Callable with boolean return type. The loop condition.
|
||||
body: Callable representing the actual loop body.
|
||||
get_state: Additional callable which can capture additional state (such as
|
||||
the values of composite symbols). This is only useful when staging the
|
||||
loop.
|
||||
@ -894,21 +896,32 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
|
||||
set_state(final_loop_vars)
|
||||
|
||||
|
||||
def if_stmt(cond,
|
||||
body,
|
||||
orelse,
|
||||
get_state,
|
||||
set_state,
|
||||
basic_symbol_names,
|
||||
composite_symbol_names):
|
||||
def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts):
|
||||
"""Functional form of an if statement.
|
||||
|
||||
The conditional operates on a state, which includes all symbols whose values
|
||||
are a function of the branch taken.
|
||||
|
||||
For example, given the code below that calculates the abs function:
|
||||
|
||||
```
|
||||
x = 1
|
||||
if x > 0:
|
||||
x = -x
|
||||
```
|
||||
|
||||
The state is represented by the variable `x`. The `body, `orelse` and
|
||||
`set_state` functions must bind to the original `x` symbol, using `nonlocal`.
|
||||
|
||||
The inputs and outputs of the callables representing the loop blocks are not
|
||||
explicit - instead, these functions must use nonlocal/global for side effects.
|
||||
The inputs and outputs are instead controlled by the set_state/get_state
|
||||
functions.
|
||||
|
||||
Args:
|
||||
cond: Boolean.
|
||||
body: Callable with no arguments, and outputs of the positive (if) branch as
|
||||
return type.
|
||||
orelse: Callable with no arguments, and outputs of the negative (else)
|
||||
branch as return type.
|
||||
body: Callable representing the main block of the conditional.
|
||||
orelse: Callable representing the else block of the conditional.
|
||||
get_state: Function that returns a tuple containing the values of all
|
||||
composite symbols modified within the conditional. This allows access to
|
||||
state that branches may mutate through side effects. This function is not
|
||||
@ -920,123 +933,63 @@ def if_stmt(cond,
|
||||
restore checkpointed values. The single argument a tuple containing values
|
||||
for each composite symbol that may be modified in a branch of the
|
||||
conditional. The is usually the result of a call to get_state.
|
||||
basic_symbol_names: Tuple containing basic loop var names.
|
||||
composite_symbol_names: Tuple containing composite loop var names.
|
||||
|
||||
Returns:
|
||||
Tuple containing the statement outputs.
|
||||
symbol_names: Tuple containing basic loop var names.
|
||||
nouts: Number of variables output by the statement. Vars which are
|
||||
not outputs will not be passed through staged control flow such as
|
||||
tf.cond. This includes variables that are defined before the conditional,
|
||||
but are not used after it.
|
||||
"""
|
||||
# Note: tf.cond doesn't support SparseTensor.
|
||||
if tensors.is_dense_tensor(cond):
|
||||
return tf_if_stmt(cond, body, orelse, get_state, set_state,
|
||||
basic_symbol_names, composite_symbol_names)
|
||||
_tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts)
|
||||
else:
|
||||
return _py_if_stmt(cond, body, orelse)
|
||||
_py_if_stmt(cond, body, orelse)
|
||||
|
||||
|
||||
def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
|
||||
composite_symbol_names):
|
||||
def _tf_if_stmt(
|
||||
cond, body, orelse, get_state, set_state, symbol_names, nouts):
|
||||
"""Overload of if_stmt that stages a TF cond."""
|
||||
body = _wrap_disallow_undefs_from_cond(body, branch_name='if')
|
||||
orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else')
|
||||
body = _isolate_state(body, get_state, set_state)
|
||||
orelse = _isolate_state(orelse, get_state, set_state)
|
||||
if not nouts:
|
||||
prev_get_state, prev_set_state = get_state, set_state
|
||||
# Control flow V1 wants at least one output.
|
||||
get_state = lambda: (0,) + prev_get_state()
|
||||
set_state = lambda v: prev_set_state(v[1:])
|
||||
symbol_names += ('<unused dummy>',)
|
||||
nouts = 1
|
||||
|
||||
# `state` currently includes the values of any composite symbols (e.g. `a.b`)
|
||||
# composites modified by the loop. `final_vars` includes the values of basic
|
||||
# symbols (e.g. `a`) which cannot be passed by reference and must be returned.
|
||||
# See _isolate_state.
|
||||
# TODO(mdan): We should minimize calls to get/set_state.
|
||||
init_vars = get_state()
|
||||
|
||||
body_branch = 0
|
||||
orelse_branch = 1
|
||||
result = [None, None]
|
||||
# TODO(mdan): Use nonlocal once we no longer need to support py2.
|
||||
new_body_vars_ = [None]
|
||||
new_orelse_vars_ = [None]
|
||||
|
||||
def error_checking_body():
|
||||
result[body_branch] = body()
|
||||
if result[orelse_branch] is not None:
|
||||
_verify_tf_cond_vars(result[body_branch], result[orelse_branch],
|
||||
basic_symbol_names + composite_symbol_names)
|
||||
return result[body_branch]
|
||||
def aug_body():
|
||||
set_state(init_vars)
|
||||
body()
|
||||
new_body_vars = get_state()
|
||||
new_body_vars = new_body_vars[:nouts]
|
||||
new_body_vars_[0] = new_body_vars
|
||||
_verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main')
|
||||
if new_orelse_vars_[0] is not None:
|
||||
_verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names)
|
||||
return new_body_vars
|
||||
|
||||
def error_checking_orelse():
|
||||
result[orelse_branch] = orelse()
|
||||
if result[body_branch] is not None:
|
||||
_verify_tf_cond_vars(result[body_branch], result[orelse_branch],
|
||||
basic_symbol_names + composite_symbol_names)
|
||||
return result[orelse_branch]
|
||||
def aug_orelse():
|
||||
set_state(init_vars)
|
||||
orelse()
|
||||
new_orelse_vars = get_state()
|
||||
new_orelse_vars = new_orelse_vars[:nouts]
|
||||
new_orelse_vars_[0] = new_orelse_vars
|
||||
_verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else')
|
||||
if new_body_vars_[0] is not None:
|
||||
_verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names)
|
||||
return new_orelse_vars
|
||||
|
||||
final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
|
||||
error_checking_orelse)
|
||||
final_cond_vars = control_flow_ops.cond(
|
||||
cond, aug_body, aug_orelse, strict=True)
|
||||
final_cond_vars = final_cond_vars + init_vars[nouts:]
|
||||
|
||||
set_state(final_state)
|
||||
|
||||
return final_vars
|
||||
|
||||
|
||||
def _isolate_state(func, get_state, set_state):
|
||||
"""Wraps func to (best-effort) isolate state mutations that func may do.
|
||||
|
||||
The simplest example of state mutation is mutation of variables (via e.g.
|
||||
attributes), or modification of globals.
|
||||
|
||||
This allows us to more safely execute this function without worrying about
|
||||
side effects when the function wasn't normally expected to execute. For
|
||||
example, staging requires that the function is executed ahead of time, and
|
||||
we need to ensure its effects are not observed during normal execution.
|
||||
|
||||
Args:
|
||||
func: () -> Any
|
||||
get_state: () -> Any, returns the current state
|
||||
set_state: (Any) -> None, resets the state to the specified values.
|
||||
Typically the result of an earlier call to `get_state`.
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Any], where the first element is the return value of `func`,
|
||||
and the second is the final state values.
|
||||
"""
|
||||
|
||||
def wrapper():
|
||||
init_state = get_state()
|
||||
new_vars = func()
|
||||
# TODO(mdan): These should be copies, lest set_state might affect them.
|
||||
new_state = get_state()
|
||||
set_state(init_state)
|
||||
return new_vars, new_state
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _wrap_disallow_undefs_from_cond(func, branch_name):
|
||||
"""Wraps conditional branch to disallow returning undefined symbols."""
|
||||
|
||||
def wrapper():
|
||||
"""Calls function and raises an error if undefined symbols are returned."""
|
||||
results = func()
|
||||
|
||||
if isinstance(results, tuple):
|
||||
results_tuple = results
|
||||
else:
|
||||
results_tuple = results,
|
||||
|
||||
for result in results_tuple:
|
||||
if isinstance(result, variables.UndefinedReturnValue):
|
||||
raise ValueError(
|
||||
'A value must also be returned from the {} branch. If a value is '
|
||||
'returned from one branch of a conditional a value must be '
|
||||
'returned from all branches.'.format(branch_name))
|
||||
|
||||
undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)]
|
||||
if undefined:
|
||||
raise ValueError(
|
||||
'The following symbols must also be initialized in the {} branch: {}.'
|
||||
' Alternatively, you may initialize them before the if'
|
||||
' statement.'.format(branch_name,
|
||||
tuple(s.symbol_name for s in undefined)))
|
||||
|
||||
return results
|
||||
|
||||
return wrapper
|
||||
set_state(final_cond_vars)
|
||||
|
||||
|
||||
def _py_if_stmt(cond, body, orelse):
|
||||
|
@ -543,21 +543,21 @@ class ForLoopTest(test.TestCase):
|
||||
return s
|
||||
|
||||
def test_tensor_illegal_input(self):
|
||||
with self.assertRaisesRegex(ValueError, '"s" may not be None'):
|
||||
with self.assertRaisesRegex(ValueError, '\'s\' may not be None'):
|
||||
self._basic_loop(None, lambda i, s: s)
|
||||
with self.assertRaisesRegex(ValueError, '"s" must be defined'):
|
||||
with self.assertRaisesRegex(ValueError, '\'s\' must be defined'):
|
||||
self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
|
||||
|
||||
def test_tensor_none_output(self):
|
||||
with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
|
||||
with self.assertRaisesRegex(ValueError, '\'s\' is None at the end'):
|
||||
self._basic_loop(0, lambda i, s: None)
|
||||
|
||||
def test_tensor_dtype_change(self):
|
||||
with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'):
|
||||
with self.assertRaisesRegex(TypeError, '\'s\'.* dtype float32 after'):
|
||||
self._basic_loop(0, lambda i, s: 1.0)
|
||||
|
||||
def test_tensor_shape_change(self):
|
||||
with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'):
|
||||
with self.assertRaisesRegex(ValueError, r'\'s\'.* shape \(1,\) after'):
|
||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||
|
||||
|
||||
@ -782,21 +782,21 @@ class WhileLoopTest(test.TestCase):
|
||||
return s
|
||||
|
||||
def test_tensor_illegal_input(self):
|
||||
with self.assertRaisesRegex(ValueError, '"s" may not be None'):
|
||||
with self.assertRaisesRegex(ValueError, "'s' may not be None"):
|
||||
self._basic_loop(None, lambda i, s: s)
|
||||
with self.assertRaisesRegex(ValueError, '"s" must be defined'):
|
||||
with self.assertRaisesRegex(ValueError, "'s' must be defined"):
|
||||
self._basic_loop(variable_operators.Undefined(''), lambda i, s: s)
|
||||
|
||||
def test_tensor_none_output(self):
|
||||
with self.assertRaisesRegex(ValueError, '"s" is None at the end'):
|
||||
with self.assertRaisesRegex(ValueError, "'s' is None at the end"):
|
||||
self._basic_loop(0, lambda i, s: None)
|
||||
|
||||
def test_tensor_dtype_change(self):
|
||||
with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'):
|
||||
with self.assertRaisesRegex(TypeError, "'s'.* dtype float32 after"):
|
||||
self._basic_loop(0, lambda i, s: 1.0)
|
||||
|
||||
def test_tensor_shape_change(self):
|
||||
with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'):
|
||||
with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"):
|
||||
self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32))
|
||||
|
||||
|
||||
@ -806,29 +806,88 @@ class IfStmtTest(test.TestCase):
|
||||
def test_tensor(self):
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
def body():
|
||||
nonlocal i
|
||||
i = constant_op.constant(1)
|
||||
|
||||
def orelse():
|
||||
nonlocal i
|
||||
i = constant_op.constant(-1)
|
||||
|
||||
def set_state(cond_vars):
|
||||
nonlocal i
|
||||
i, = cond_vars
|
||||
|
||||
i = None
|
||||
control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: constant_op.constant(1),
|
||||
orelse=lambda: constant_op.constant(-1),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
basic_symbol_names=('_',),
|
||||
composite_symbol_names=())
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (i,),
|
||||
set_state=set_state,
|
||||
symbol_names=('i',),
|
||||
nouts=1)
|
||||
return i
|
||||
|
||||
self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True))))
|
||||
self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False))))
|
||||
|
||||
def test_tensor_no_outputs(self):
|
||||
|
||||
def test_fn(cond):
|
||||
def body():
|
||||
nonlocal i
|
||||
i = constant_op.constant(1)
|
||||
|
||||
def orelse():
|
||||
nonlocal i
|
||||
i = constant_op.constant(-1.0)
|
||||
|
||||
def set_state(cond_vars):
|
||||
nonlocal i
|
||||
i, = cond_vars
|
||||
|
||||
i = None
|
||||
control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (i,),
|
||||
set_state=set_state,
|
||||
symbol_names=('i',),
|
||||
nouts=0)
|
||||
return i
|
||||
|
||||
self.assertEqual(None, test_fn(constant_op.constant(True)))
|
||||
self.assertEqual(None, test_fn(constant_op.constant(False)))
|
||||
|
||||
def test_tensor_multiple_returns(self):
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
def body():
|
||||
nonlocal i, j
|
||||
i = constant_op.constant(1)
|
||||
j = constant_op.constant(2)
|
||||
|
||||
def orelse():
|
||||
nonlocal i, j
|
||||
i = constant_op.constant(-1)
|
||||
j = constant_op.constant(-2)
|
||||
|
||||
def set_state(cond_vars):
|
||||
nonlocal i, j
|
||||
i, j = cond_vars
|
||||
|
||||
i, j = None, None
|
||||
control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
basic_symbol_names=('_',),
|
||||
composite_symbol_names=())
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (i, j),
|
||||
set_state=set_state,
|
||||
symbol_names=('i', 'j'),
|
||||
nouts=2)
|
||||
return i, j
|
||||
|
||||
self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True))))
|
||||
self.assertEqual((-1, -2),
|
||||
@ -837,14 +896,24 @@ class IfStmtTest(test.TestCase):
|
||||
def test_python(self):
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
def body():
|
||||
nonlocal i
|
||||
i = 1
|
||||
|
||||
def orelse():
|
||||
nonlocal i
|
||||
i = -1
|
||||
|
||||
i = None
|
||||
control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: 1,
|
||||
orelse=lambda: -1,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
basic_symbol_names=('_',),
|
||||
composite_symbol_names=())
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=None,
|
||||
set_state=None,
|
||||
symbol_names=('i',),
|
||||
nouts=1)
|
||||
return i
|
||||
|
||||
self.assertEqual(1, test_fn(True))
|
||||
self.assertEqual(-1, test_fn(False))
|
||||
@ -852,48 +921,75 @@ class IfStmtTest(test.TestCase):
|
||||
def test_python_multiple_returns(self):
|
||||
|
||||
def test_fn(cond):
|
||||
return control_flow.if_stmt(
|
||||
def body():
|
||||
nonlocal i, j
|
||||
i = 1
|
||||
j = 2
|
||||
|
||||
def orelse():
|
||||
nonlocal i, j
|
||||
i = -1
|
||||
j = -2
|
||||
|
||||
i, j = None, None
|
||||
control_flow.if_stmt(
|
||||
cond=cond,
|
||||
body=lambda: (1, 2),
|
||||
orelse=lambda: (-1, -2),
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
basic_symbol_names=('_',),
|
||||
composite_symbol_names=())
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=None,
|
||||
set_state=None,
|
||||
symbol_names=('i', 'j'),
|
||||
nouts=2)
|
||||
return i, j
|
||||
|
||||
self.assertEqual((1, 2), test_fn(True))
|
||||
self.assertEqual((-1, -2), test_fn(False))
|
||||
|
||||
def _basic_cond(self, true_value, false_value):
|
||||
def _basic_cond(self, body_fn, else_fn):
|
||||
def body():
|
||||
nonlocal x
|
||||
x = body_fn()
|
||||
|
||||
def orelse():
|
||||
nonlocal x
|
||||
x = else_fn()
|
||||
|
||||
def set_state(cond_vars):
|
||||
nonlocal x
|
||||
x, = cond_vars
|
||||
|
||||
x = 0
|
||||
# Eager cond had different semantics, we don't test those here.
|
||||
with func_graph.FuncGraph('tmp').as_default():
|
||||
return control_flow.if_stmt(
|
||||
control_flow.if_stmt(
|
||||
cond=constant_op.constant(True),
|
||||
body=true_value,
|
||||
orelse=false_value,
|
||||
get_state=lambda: (),
|
||||
set_state=lambda _: None,
|
||||
basic_symbol_names=('s',),
|
||||
composite_symbol_names=())
|
||||
body=body,
|
||||
orelse=orelse,
|
||||
get_state=lambda: (x,),
|
||||
set_state=set_state,
|
||||
symbol_names=('x',),
|
||||
nouts=1)
|
||||
return x
|
||||
|
||||
def test_tensor_none_output(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '"s" is None at the end of the TRUE branch'):
|
||||
ValueError, "'x' is None at the end of the main branch"):
|
||||
self._basic_cond(lambda: None, lambda: 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, '"s" is None at the end of the FALSE branch'):
|
||||
ValueError, "'x' is None at the end of the else branch"):
|
||||
self._basic_cond(lambda: 1, lambda: None)
|
||||
|
||||
def test_tensor_undefined_output(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "must also be initialized in the if.*'s'"):
|
||||
self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1)
|
||||
ValueError, "'x' must also be initialized in the main branch"):
|
||||
self._basic_cond(lambda: variable_operators.Undefined('x'), lambda: 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "must also be initialized in the else.*'s'"):
|
||||
ValueError, "'x' must also be initialized in the else branch"):
|
||||
self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s'))
|
||||
|
||||
def test_tensor_dtype_change(self):
|
||||
with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "'x' has dtype int32.*but.*float32"):
|
||||
self._basic_cond(lambda: 1, lambda: 1.0)
|
||||
|
||||
|
||||
|
@ -70,6 +70,9 @@ class Scope(object):
|
||||
globals: Set[qual_names.QN], names that are explicitly marked as global in
|
||||
this scope. Note that this doesn't include free read-only vars bound to
|
||||
global symbols.
|
||||
nonlocals: Set[qual_names.QN], names that are explicitly marked as nonlocal
|
||||
in this scope. Note that this doesn't include free read-only vars bound to
|
||||
global symbols.
|
||||
free_vars: Set[qual_names.QN], the free variables in this scope. See
|
||||
https://docs.python.org/3/reference/executionmodel.html for a precise
|
||||
definition.
|
||||
@ -111,6 +114,7 @@ class Scope(object):
|
||||
|
||||
self.bound = set()
|
||||
self.globals = set()
|
||||
self.nonlocals = set()
|
||||
self.annotations = set()
|
||||
|
||||
self.params = weakref.WeakValueDictionary()
|
||||
@ -186,6 +190,7 @@ class Scope(object):
|
||||
self.parent.modified.update(self.modified - self.isolated_names)
|
||||
self.parent.bound.update(self.bound - self.isolated_names)
|
||||
self.parent.globals.update(self.globals)
|
||||
self.parent.nonlocals.update(self.nonlocals)
|
||||
self.parent.annotations.update(self.annotations)
|
||||
else:
|
||||
# TODO(mdan): This is not accurate.
|
||||
@ -363,6 +368,7 @@ class ActivityAnalyzer(transformer.Base):
|
||||
qn = qual_names.QN(name)
|
||||
self.scope.read.add(qn)
|
||||
self.scope.bound.add(qn)
|
||||
self.scope.nonlocals.add(qn)
|
||||
self._exit_and_record_scope(node)
|
||||
return node
|
||||
|
||||
|
@ -404,6 +404,46 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
|
||||
|
||||
self.assertHasDefinedIn(fn_body[1], ('a',))
|
||||
|
||||
def test_definitions_in_except_block(self):
|
||||
|
||||
def test_fn():
|
||||
try:
|
||||
pass
|
||||
except ValueError:
|
||||
a = None
|
||||
if a: # pylint:disable=using-constant-test
|
||||
a = None
|
||||
return a
|
||||
|
||||
node = self._parse_and_analyze(test_fn)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertHasDefs(fn_body[1].test, 1)
|
||||
self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
|
||||
self.assertHasDefs(fn_body[2].value, 2)
|
||||
|
||||
self.assertHasDefinedIn(fn_body[1], ('a',))
|
||||
|
||||
def test_definitions_in_except_block_of_raising_try(self):
|
||||
|
||||
def test_fn():
|
||||
try:
|
||||
raise ValueError()
|
||||
except ValueError:
|
||||
a = None
|
||||
if a: # pylint:disable=using-constant-test
|
||||
a = None
|
||||
return a
|
||||
|
||||
node = self._parse_and_analyze(test_fn)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertHasDefs(fn_body[1].test, 1)
|
||||
self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
|
||||
self.assertHasDefs(fn_body[2].value, 2)
|
||||
|
||||
self.assertHasDefinedIn(fn_body[1], ('a',))
|
||||
|
||||
def test_global(self):
|
||||
|
||||
def test_fn():
|
||||
|
Loading…
Reference in New Issue
Block a user