Fix liveness analysis for variables closed over by functions. Previously, these variables were only live at the closing function's point of definition. After this change, these variables are live at every statement that may be reached by this function definition. This is more consistent with Python's late binding mechanism.
For reasons of backward compatibility, lambda functions continue to be treated as before - any variables they close over are only live in the statement that contains the lambda. This is consistent with the typical usage of lambdas as arguments to higher-order functions like map, reduce, etc.. PiperOrigin-RevId: 310334670 Change-Id: I04da533e0017178156851d43050cd81b05a49704
This commit is contained in:
parent
a8b9d64276
commit
0103d8ee27
tensorflow/python
autograph
converters
g3doc/reference
pyct
eager
@ -32,6 +32,7 @@ from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||||
from tensorflow.python.autograph.utils import compat_util
|
||||
|
||||
|
||||
@ -554,7 +555,8 @@ def transform(node, ctx):
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
|
||||
node = ControlFlowTransformer(ctx).visit(node)
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||||
|
||||
|
||||
# TODO(mdan): Refactor functions to make them smaller.
|
||||
@ -630,7 +631,8 @@ def transform(node, ctx):
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
|
||||
node = ControlFlowTransformer(ctx).visit(node)
|
||||
|
@ -16,6 +16,88 @@ should not be confused with TensorFlow variables.
|
||||
Key Term: A TensorFlow loop variable (or loop variable for short) refers to a
|
||||
value (typically a `tf.Tensor`) modified by a loop. See `tf.while_loop`.
|
||||
|
||||
### Undefined and None values in TensorFlow
|
||||
|
||||
TensorFlow does not support undefined or `None` values. All tensors must have
|
||||
a value.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
x = tf.cond(
|
||||
tf.random.uniform(()) > 0.5,
|
||||
lambda: tf.constant(1),
|
||||
lambda: None) # Error -- a Tensor cannot be None
|
||||
```
|
||||
|
||||
The same restriction carries over in AutoGraph. If a variable is created inside
|
||||
control flow, and used after, then it must be defined before the control flow
|
||||
statement:
|
||||
|
||||
```
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
x = None
|
||||
tf.print(x) # Error -- x may be None here
|
||||
```
|
||||
|
||||
For this reason, AutoGraph forbids variables to be defined in only one branch
|
||||
of a TensorFlow conditional, if the variable is used afterwards:
|
||||
|
||||
```
|
||||
del x
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
pass
|
||||
tf.print(x) # Error -- x may be undefined here
|
||||
```
|
||||
|
||||
Note that if the variable is not used after the control flow statement, then it
|
||||
is considered local to the control flow block, and is not subject to these
|
||||
restrictions.
|
||||
|
||||
```
|
||||
del x
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1) # Okay -- x does not need to be returned from the TF cond
|
||||
else:
|
||||
pass
|
||||
```
|
||||
|
||||
Similarly, variables may not be defined inside a TensorFlow loop, unless they
|
||||
are local to the loop. A variable is local to the loop if (1) it's not used
|
||||
after the loop and (2) the value from a previour iteration is not used in the
|
||||
next iteration:
|
||||
|
||||
```
|
||||
del x
|
||||
while tf.random.uniform(()) > 0.5: # Error -- x must be defined before the loop
|
||||
x = tf.constant(1)
|
||||
tf.print(x)
|
||||
```
|
||||
|
||||
```
|
||||
del x
|
||||
while tf.random.uniform(()) > 0.5: # Okay -- x is local to the loop
|
||||
x = tf.constant(1)
|
||||
```
|
||||
|
||||
Avoid these limitations by defining a default value before the control flow
|
||||
statement:
|
||||
|
||||
```
|
||||
x = tf.constant()
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
tf.print(x) # Okay -- x is either 0 or 1
|
||||
```
|
||||
|
||||
Note: `None` values and undefined symbols are allowed in Eager control flow,
|
||||
because Eager execution uses Python control flow, rather than TensorFlow
|
||||
control flow ops.
|
||||
|
||||
### Indirect modifications and hidden side effects in TensorFlow control flow
|
||||
|
||||
Key Point: We recommend using a functional programming style, immutable Python
|
||||
@ -187,6 +269,62 @@ objects, but it does support basic collection objects such as `list`, `dict`,
|
||||
`tuple`, `namedtuple` and their subclasses. Design your objects as subclasses
|
||||
of [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple).
|
||||
|
||||
#### Variables closed over by lambda functions
|
||||
|
||||
AutoGraph assumes that variables that local functions close over may be used
|
||||
anywhere in the parent function, because in general it is possible to hide a
|
||||
function call in almost any Python statement). For this reason, these variables
|
||||
are accounted within TensorFlow loops.
|
||||
|
||||
For example, the following code correctly captures `a` in the TensorFlow loop
|
||||
variables:
|
||||
|
||||
```
|
||||
a = 0
|
||||
def f():
|
||||
tf.print(a)
|
||||
for i in tf.range(3):
|
||||
a = i
|
||||
f() # Prints 2
|
||||
```
|
||||
|
||||
An consequence is that these variables must be defined before the loop (see
|
||||
Undefined and None values above). So the following code will raise an error,
|
||||
even if the variable is never used after the loop:
|
||||
|
||||
```
|
||||
def f():
|
||||
tf.print(a)
|
||||
for i in tf.range(3): # Error -- `a` must be defined before the loop.
|
||||
a = i
|
||||
```
|
||||
|
||||
However, lambda functions are handled differently, for reasons of backward
|
||||
compatibility. Lambda functions are assumed to be used in the statement where
|
||||
they are used, or at least in the same block.
|
||||
|
||||
```
|
||||
a = 0
|
||||
foo(lambda: a) # This lambda is not expected to be called anywhere else.
|
||||
for i in tf.range(3): # Okay -- `a` is local to the loop.
|
||||
a = i
|
||||
```
|
||||
|
||||
Due to that reason, the following code will not work as expected for TensorFlow
|
||||
loops.
|
||||
|
||||
```
|
||||
a = 0
|
||||
l = lambda: tf.print(a)
|
||||
for i in tf.range(3):
|
||||
a = i # `a` is considered local to the loop
|
||||
l() # Prints 0!
|
||||
```
|
||||
|
||||
Note that none of these restrictions only apply to TensorFlow loops; Python
|
||||
loops correctly correctly handle closures in all cases.
|
||||
|
||||
|
||||
### Python collections in TensorFlow control flow
|
||||
|
||||
Key Point: Use TensorFlow collection classes instead of Python collections.
|
||||
@ -489,69 +627,6 @@ while tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant((1, 2, 3)) # Error -- inconsistent shapes: (), (3,)
|
||||
```
|
||||
|
||||
### Undefined and None values in TensorFlow
|
||||
|
||||
TensorFlow does not support undefined and `None` values. All tensors must have
|
||||
a value.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
x = tf.cond(
|
||||
tf.random.uniform(()) > 0.5,
|
||||
lambda: tf.constant(1),
|
||||
lambda: None) # Error -- a Tensor cannot be None
|
||||
```
|
||||
|
||||
The same restriction carries over in AutoGraph, but only if the symbol is used
|
||||
after the conditional (otherwise AutoGraph avoids making it a return value
|
||||
of the `tf.cond`):
|
||||
|
||||
```
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
x = None
|
||||
tf.print(x) # Error -- x may be None here
|
||||
```
|
||||
|
||||
A related but less obvious restriction in AutoGraph forbids symbols to be
|
||||
defined in only one branch of TensorFlow control flow, if the symbol is
|
||||
used afterwards:
|
||||
|
||||
```
|
||||
del x
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
else:
|
||||
pass
|
||||
tf.print(x) # Error -- x may be undefined here
|
||||
```
|
||||
|
||||
Similarly, variables defined in a loop may not be used outside the loop, again
|
||||
if the symbol is used afterwards:
|
||||
|
||||
```
|
||||
del x
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
tf.print(x) # Error -- x may be undefined here
|
||||
```
|
||||
|
||||
Avoid these limitations by defining a default value before the control flow
|
||||
statement:
|
||||
|
||||
```
|
||||
x = tf.constant()
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
x = tf.constant(1)
|
||||
tf.print(x) # Okay -- x is either 0 or 1
|
||||
```
|
||||
|
||||
Note: `None` values and undefined symbols are allowed in Eager control flow,
|
||||
because Eager execution uses Python control flow, rather than TensorFlow
|
||||
control flow ops.
|
||||
|
||||
### Access to source code
|
||||
|
||||
Key point: AutoGraph can only handle functions whose source code can be
|
||||
|
@ -93,6 +93,9 @@ class Static(NoValue):
|
||||
ORIG_DEFINITIONS = (
|
||||
'The value of DEFINITIONS that applied to the original code before any'
|
||||
' conversion.')
|
||||
DEFINED_FNS_IN = (
|
||||
'Local function definitions that may exist when exiting the node. See'
|
||||
' reaching_fndefs.py')
|
||||
DEFINED_VARS_IN = (
|
||||
'Symbols defined when entering the node. See reaching_definitions.py.')
|
||||
LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
|
||||
|
@ -23,6 +23,7 @@ py_library(
|
||||
"annos.py",
|
||||
"liveness.py",
|
||||
"reaching_definitions.py",
|
||||
"reaching_fndefs.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -617,9 +617,23 @@ class ActivityAnalyzer(transformer.Base):
|
||||
# TODO(mdan): Do remove it, it's confusing.
|
||||
self._enter_scope(False)
|
||||
node.body = self.visit(node.body)
|
||||
|
||||
# The lambda body can contain nodes of types normally not found as
|
||||
# statements, and may not have the SCOPE annotation needed by the CFG.
|
||||
# So we attach one if necessary.
|
||||
if not anno.hasanno(node.body, anno.Static.SCOPE):
|
||||
anno.setanno(node.body, anno.Static.SCOPE, self.scope)
|
||||
|
||||
self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)
|
||||
|
||||
lambda_scope = self.scope
|
||||
self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
|
||||
# Exception: lambdas are assumed to be used in the place where
|
||||
# they are defined. Therefore, their activity is passed on to the
|
||||
# calling statement.
|
||||
self.scope.read.update(lambda_scope.read - lambda_scope.bound)
|
||||
|
||||
return node
|
||||
|
||||
def visit_With(self, node):
|
||||
|
@ -393,6 +393,31 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
||||
self.assertScopeIs(scope, ('x', 'y'), ('y',))
|
||||
self.assertSymbolSetsAre(('x', 'y'), scope.bound, 'BOUND')
|
||||
|
||||
def test_nested_lambda(self):
|
||||
|
||||
def test_fn(a):
|
||||
return lambda x: (x * a)
|
||||
|
||||
node, _ = self._parse_and_analyze(test_fn)
|
||||
|
||||
fn_node = node
|
||||
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a',), ())
|
||||
|
||||
return_node = node.body[0]
|
||||
|
||||
scope = anno.getanno(return_node, anno.Static.SCOPE)
|
||||
self.assertScopeIs(scope, ('a',), ())
|
||||
|
||||
lam_def_node = return_node.value
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'x'), ())
|
||||
|
||||
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
self.assertScopeIs(scope, ('a', 'x'), ())
|
||||
self.assertSymbolSetsAre(('x',), scope.bound, 'BOUND')
|
||||
|
||||
def test_nested_function_arg_defaults(self):
|
||||
|
||||
def test_fn(a):
|
||||
|
@ -42,9 +42,6 @@ class Analyzer(cfg.GraphVisitor):
|
||||
|
||||
def __init__(self, graph, include_annotations):
|
||||
super(Analyzer, self).__init__(graph)
|
||||
# This allows communicating that nodes generate extra symbols,
|
||||
# e.g. those that a function definition closes over.
|
||||
self.extra_gen = {}
|
||||
self.include_annotations = include_annotations
|
||||
|
||||
def init_state(self, _):
|
||||
@ -56,7 +53,7 @@ class Analyzer(cfg.GraphVisitor):
|
||||
if anno.hasanno(node.ast_node, anno.Static.SCOPE):
|
||||
node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
|
||||
|
||||
gen = node_scope.read | self.extra_gen.get(node.ast_node, frozenset())
|
||||
gen = node_scope.read
|
||||
if not self.include_annotations:
|
||||
gen -= node_scope.annotations
|
||||
# TODO(mdan): verify whether composites' parents need to be added.
|
||||
@ -69,6 +66,18 @@ class Analyzer(cfg.GraphVisitor):
|
||||
live_out |= self.in_[n]
|
||||
live_in = gen | (live_out - kill)
|
||||
|
||||
reaching_functions = anno.getanno(
|
||||
node.ast_node, anno.Static.DEFINED_FNS_IN)
|
||||
for fn_ast_node in reaching_functions:
|
||||
if isinstance(fn_ast_node, gast.Lambda):
|
||||
# Exception: lambda functions are assumed to be used only in the
|
||||
# place where they are defined, and not later.
|
||||
continue
|
||||
fn_scope = anno.getanno(fn_ast_node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
|
||||
# Any closure of a reaching function definition is conservatively
|
||||
# considered live.
|
||||
live_in |= (fn_scope.read - fn_scope.bound)
|
||||
|
||||
else:
|
||||
assert self.can_ignore(node), (node.ast_node, node)
|
||||
|
||||
@ -84,7 +93,7 @@ class Analyzer(cfg.GraphVisitor):
|
||||
return prev_live_in != live_in
|
||||
|
||||
|
||||
class WholeTreeAnalyzer(transformer.Base):
|
||||
class TreeAnnotator(transformer.Base):
|
||||
"""Runs liveness analysis on each of the functions defined in the AST.
|
||||
|
||||
If a function defined other local functions, those will have separate CFGs.
|
||||
@ -94,7 +103,7 @@ class WholeTreeAnalyzer(transformer.Base):
|
||||
subfunction. For example:
|
||||
|
||||
def foo():
|
||||
# baz is live here
|
||||
# baz is live from here on
|
||||
def bar():
|
||||
print(baz)
|
||||
|
||||
@ -103,63 +112,14 @@ class WholeTreeAnalyzer(transformer.Base):
|
||||
"""
|
||||
|
||||
def __init__(self, source_info, graphs, include_annotations):
|
||||
super(WholeTreeAnalyzer, self).__init__(source_info)
|
||||
super(TreeAnnotator, self).__init__(source_info)
|
||||
self.include_annotations = include_annotations
|
||||
self.allow_skips = False
|
||||
self.graphs = graphs
|
||||
self.current_analyzer = None
|
||||
self.analyzers = {}
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
parent_analyzer = self.current_analyzer
|
||||
subgraph = self.graphs[node]
|
||||
|
||||
# Postorder tree processing makes this a bit complicated:
|
||||
# 1. construct an analyzer object and put it on stack
|
||||
# 2. recursively walk the subtree; this will initialize the analyzer's
|
||||
# in_ state properly (done in a block below)
|
||||
# 3. run the final analysis
|
||||
analyzer = Analyzer(subgraph, self.include_annotations)
|
||||
self.current_analyzer = analyzer
|
||||
node = self.generic_visit(node)
|
||||
analyzer.visit_reverse()
|
||||
|
||||
if parent_analyzer is not None:
|
||||
# Wire the state between the two subgraphs' analyzers.
|
||||
child_in_state = analyzer.in_[subgraph.entry]
|
||||
# Exception: symbols modified in the child function are local to it
|
||||
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
|
||||
for qn in body_scope.modified:
|
||||
# Note: a function modifying the symbol doesn't make that symbol
|
||||
# live at the function's entry. In fact when that happens it is
|
||||
# probably a case of undefined assignment, like this:
|
||||
#
|
||||
# bar = 0
|
||||
# def foo():
|
||||
# print(bar) # bar is undefined here!
|
||||
# bar = 1
|
||||
#
|
||||
# Hence we use discard and not remove below.
|
||||
child_in_state.discard(qn)
|
||||
parent_analyzer.extra_gen[node] = frozenset(child_in_state,)
|
||||
|
||||
self.analyzers[node] = analyzer
|
||||
self.current_analyzer = parent_analyzer
|
||||
return node
|
||||
|
||||
|
||||
class Annotator(transformer.Base):
|
||||
"""AST visitor that annotates each control flow block with live symbols."""
|
||||
|
||||
# Note: additional nodes may be added as needed.
|
||||
|
||||
def __init__(self, source_info, cross_function_analyzer):
|
||||
super(Annotator, self).__init__(source_info)
|
||||
self.cross_function_analyzer = cross_function_analyzer
|
||||
self.current_analyzer = None
|
||||
|
||||
def visit(self, node):
|
||||
node = super(Annotator, self).visit(node)
|
||||
node = super(TreeAnnotator, self).visit(node)
|
||||
if (self.current_analyzer is not None and
|
||||
isinstance(node, gast.stmt) and
|
||||
node in self.current_analyzer.graph.index):
|
||||
@ -168,14 +128,23 @@ class Annotator(transformer.Base):
|
||||
frozenset(self.current_analyzer.in_[cfg_node]))
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
def _analyze_function(self, node, is_lambda):
|
||||
parent_analyzer = self.current_analyzer
|
||||
self.current_analyzer = self.cross_function_analyzer.analyzers[node]
|
||||
|
||||
analyzer = Analyzer(self.graphs[node], self.include_annotations)
|
||||
analyzer.visit_reverse()
|
||||
self.current_analyzer = analyzer
|
||||
node = self.generic_visit(node)
|
||||
|
||||
self.current_analyzer = parent_analyzer
|
||||
return node
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
return self._analyze_function(node, is_lambda=True)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
return self._analyze_function(node, is_lambda=False)
|
||||
|
||||
def _block_statement_live_out(self, node):
|
||||
successors = self.current_analyzer.graph.stmt_next[node]
|
||||
stmt_live_out = set()
|
||||
@ -246,9 +215,5 @@ def resolve(node, source_info, graphs, include_annotations=True):
|
||||
Returns:
|
||||
ast.AST
|
||||
"""
|
||||
cross_function_analyzer = WholeTreeAnalyzer(
|
||||
source_info, graphs, include_annotations)
|
||||
node = cross_function_analyzer.visit(node)
|
||||
visitor = Annotator(source_info, cross_function_analyzer)
|
||||
node = visitor.visit(node)
|
||||
node = TreeAnnotator(source_info, graphs, include_annotations).visit(node)
|
||||
return node
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -49,7 +50,8 @@ class LivenessAnalyzerTestBase(test.TestCase):
|
||||
ctx = transformer.Context(entity_info, namer, None)
|
||||
node = activity.resolve(node, ctx)
|
||||
graphs = cfg.build(node)
|
||||
liveness.resolve(node, ctx, graphs)
|
||||
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
return node
|
||||
|
||||
def assertHasLiveOut(self, node, expected):
|
||||
@ -191,6 +193,73 @@ class LivenessAnalyzerTest(LivenessAnalyzerTestBase):
|
||||
|
||||
self.assertHasLiveOut(fn_body[0], 'a')
|
||||
|
||||
def test_live_out_nested_functions_defined_ahead(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
def foo():
|
||||
return a
|
||||
|
||||
if b:
|
||||
a = []
|
||||
|
||||
return foo
|
||||
|
||||
node = self._parse_and_analyze(test_fn)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertHasLiveOut(fn_body[1], ('a', 'foo'))
|
||||
|
||||
def test_live_out_nested_functions_defined_after(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
if b:
|
||||
a = []
|
||||
|
||||
def foo():
|
||||
return a
|
||||
|
||||
return foo
|
||||
|
||||
node = self._parse_and_analyze(test_fn)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertHasLiveOut(fn_body[0], ('a',))
|
||||
|
||||
def test_live_out_lambda(self):
|
||||
|
||||
def test_fn(a, b):
|
||||
if b:
|
||||
a = []
|
||||
|
||||
foo = lambda: a
|
||||
|
||||
if b:
|
||||
pass
|
||||
|
||||
return foo
|
||||
|
||||
node = self._parse_and_analyze(test_fn)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertHasLiveOut(fn_body[0], ('a', 'b'))
|
||||
self.assertHasLiveOut(fn_body[2], ('foo',))
|
||||
|
||||
def test_live_out_nested_functions_hidden_by_argument(self):
|
||||
|
||||
def test_fn(b):
|
||||
def foo(a):
|
||||
return a
|
||||
|
||||
if b:
|
||||
a = [] # pylint:disable=unused-variable
|
||||
|
||||
return foo
|
||||
|
||||
node = self._parse_and_analyze(test_fn)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertHasLiveOut(fn_body[1], ('foo'))
|
||||
|
||||
def test_live_out_nested_functions_isolation(self):
|
||||
|
||||
def test_fn(b):
|
||||
|
@ -0,0 +1,182 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An analysis that determines the reach of a function definition.
|
||||
|
||||
A function definition is said to reach a statement if that function may exist
|
||||
(and therefore may be called) when that statement executes.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
|
||||
|
||||
class Definition(object):
|
||||
"""Definition objects describe a unique definition of a function."""
|
||||
|
||||
def __init__(self, def_node):
|
||||
self.def_node = def_node
|
||||
|
||||
|
||||
class _NodeState(object):
|
||||
"""Abstraction for the state of the CFG walk for reaching definition analysis.
|
||||
|
||||
This is a value type. Only implements the strictly necessary operators.
|
||||
|
||||
Attributes:
|
||||
value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
|
||||
their possible definitions
|
||||
"""
|
||||
|
||||
def __init__(self, init_from=None):
|
||||
if init_from:
|
||||
self.value = set(init_from)
|
||||
else:
|
||||
self.value = set()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.value != other.value
|
||||
|
||||
def __or__(self, other):
|
||||
assert isinstance(other, _NodeState)
|
||||
result = _NodeState(self.value)
|
||||
result.value.update(other.value)
|
||||
return result
|
||||
|
||||
def __add__(self, value):
|
||||
result = _NodeState(self.value)
|
||||
result.value.add(value)
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return 'NodeState[%s]=%s' % (id(self), repr(self.value))
|
||||
|
||||
|
||||
class Analyzer(cfg.GraphVisitor):
|
||||
"""CFG visitor that determines reaching definitions at statement level."""
|
||||
|
||||
def __init__(self, graph, external_defs):
|
||||
super(Analyzer, self).__init__(graph)
|
||||
# This allows communicating that nodes have extra reaching definitions,
|
||||
# e.g. those that a function closes over.
|
||||
self.external_defs = external_defs
|
||||
|
||||
def init_state(self, _):
|
||||
return _NodeState()
|
||||
|
||||
def visit_node(self, node):
|
||||
prev_defs_out = self.out[node]
|
||||
|
||||
if node is self.graph.entry:
|
||||
defs_in = _NodeState(self.external_defs)
|
||||
else:
|
||||
defs_in = prev_defs_out
|
||||
|
||||
for n in node.prev:
|
||||
defs_in |= self.out[n]
|
||||
|
||||
defs_out = defs_in
|
||||
if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
|
||||
defs_out += node.ast_node
|
||||
|
||||
self.in_[node] = defs_in
|
||||
self.out[node] = defs_out
|
||||
|
||||
return prev_defs_out != defs_out
|
||||
|
||||
|
||||
class TreeAnnotator(transformer.Base):
|
||||
"""AST visitor that annotates each symbol name with its reaching definitions.
|
||||
|
||||
Simultaneously, the visitor runs the dataflow analysis on each function node,
|
||||
accounting for the effect of closures. For example:
|
||||
|
||||
def foo():
|
||||
def f():
|
||||
pass
|
||||
def g():
|
||||
# `def f` reaches here
|
||||
"""
|
||||
|
||||
def __init__(self, source_info, graphs):
|
||||
super(TreeAnnotator, self).__init__(source_info)
|
||||
self.graphs = graphs
|
||||
self.allow_skips = False
|
||||
self.current_analyzer = None
|
||||
|
||||
def _proces_function(self, node):
|
||||
parent_analyzer = self.current_analyzer
|
||||
subgraph = self.graphs[node]
|
||||
|
||||
if (self.current_analyzer is not None
|
||||
and node in self.current_analyzer.graph.index):
|
||||
cfg_node = self.current_analyzer.graph.index[node]
|
||||
defined_in = self.current_analyzer.in_[cfg_node].value
|
||||
else:
|
||||
defined_in = ()
|
||||
|
||||
analyzer = Analyzer(subgraph, defined_in)
|
||||
analyzer.visit_forward()
|
||||
|
||||
self.current_analyzer = analyzer
|
||||
node = self.generic_visit(node)
|
||||
self.current_analyzer = parent_analyzer
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
return self._proces_function(node)
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
return self._proces_function(node)
|
||||
|
||||
def visit(self, node):
|
||||
# This can happen before entering the top level function
|
||||
if (self.current_analyzer is not None
|
||||
and node in self.current_analyzer.graph.index):
|
||||
cfg_node = self.current_analyzer.graph.index[node]
|
||||
anno.setanno(node, anno.Static.DEFINED_FNS_IN,
|
||||
self.current_analyzer.in_[cfg_node].value)
|
||||
|
||||
extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
|
||||
if extra_node is not None:
|
||||
cfg_node = self.current_analyzer.graph.index[extra_node]
|
||||
anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN,
|
||||
self.current_analyzer.in_[cfg_node].value)
|
||||
|
||||
return super(TreeAnnotator, self).visit(node)
|
||||
|
||||
|
||||
def resolve(node, source_info, graphs):
|
||||
"""Resolves reaching definitions for each symbol.
|
||||
|
||||
Args:
|
||||
node: ast.AST
|
||||
source_info: transformer.SourceInfo
|
||||
graphs: Dict[ast.FunctionDef, cfg.Graph]
|
||||
Returns:
|
||||
ast.AST
|
||||
"""
|
||||
visitor = TreeAnnotator(source_info, graphs)
|
||||
node = visitor.visit(node)
|
||||
return node
|
@ -0,0 +1,58 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for reaching_fndefs module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ReachingFndefsAnalyzerTest(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn):
|
||||
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||
node, source = parser.parse_entity(test_fn, future_features=())
|
||||
entity_info = transformer.EntityInfo(
|
||||
name=test_fn.__name__,
|
||||
source_code=source,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace={})
|
||||
node = qual_names.resolve(node)
|
||||
namer = naming.Namer({})
|
||||
ctx = transformer.Context(entity_info, namer, None)
|
||||
node = activity.resolve(node, ctx)
|
||||
graphs = cfg.build(node)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||||
return node
|
||||
|
||||
def assertHasFnDefs(self, node):
|
||||
anno.getanno(node, anno.Static.DEFINED_FNS_IN)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -36,6 +36,7 @@ from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
@ -208,6 +209,7 @@ def _live_tensors(f, attr_name="inputs"):
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_fndefs.resolve(node, ctx, graphs)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
|
||||
op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
|
||||
|
Loading…
Reference in New Issue
Block a user