Fix liveness analysis for variables closed over by functions. Previously, these variables were only live at the statement where the closing function was defined. After this change, they are live at every statement that may be reached from the function's definition. This is more consistent with Python's late-binding semantics for closures.

For backward compatibility, lambda functions continue to be treated as before: any variables they close over are only live in the statement that contains the lambda. This is consistent with the typical usage of lambdas as arguments to higher-order functions such as map, reduce, etc.
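
A minimal sketch of the resulting behavior (mirroring the examples added to the limitations doc in this change; the `@tf.function` wrapping and function names are illustrative only):

```
import tensorflow as tf


@tf.function
def closure_over_loop_var():
  # `f` closes over `a`, so `a` is now considered live at every statement
  # reachable from the definition of `f`. It therefore becomes a TF loop
  # variable and must be defined before the loop.
  a = 0

  def f():
    tf.print(a)

  for i in tf.range(3):
    a = i
  f()  # Prints 2 -- late binding sees the final value of `a`.


@tf.function
def lambda_over_loop_var():
  # Lambdas keep the previous behavior: `a` is only live at the statement
  # containing the lambda, so it is treated as local to the loop below.
  a = 0
  l = lambda: tf.print(a)
  for i in tf.range(3):
    a = i
  l()  # Prints 0.
```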

PiperOrigin-RevId: 310334670
Change-Id: I04da533e0017178156851d43050cd81b05a49704
Dan Moldovan 2020-05-07 04:38:07 -07:00 committed by TensorFlower Gardener
parent a8b9d64276
commit 0103d8ee27
12 changed files with 528 additions and 130 deletions

View File

@ -32,6 +32,7 @@ from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.autograph.utils import compat_util
@ -554,7 +555,8 @@ def transform(node, ctx):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
node = reaching_definitions.resolve(node, ctx, graphs)
node = reaching_fndefs.resolve(node, ctx, graphs)
node = liveness.resolve(node, ctx, graphs)
node = ControlFlowTransformer(ctx).visit(node)

View File

@ -35,6 +35,7 @@ from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
# TODO(mdan): Refactor functions to make them smaller.
@ -630,7 +631,8 @@ def transform(node, ctx):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
node = reaching_definitions.resolve(node, ctx, graphs)
node = reaching_fndefs.resolve(node, ctx, graphs)
node = liveness.resolve(node, ctx, graphs)
node = ControlFlowTransformer(ctx).visit(node)

View File

@ -16,6 +16,88 @@ should not be confused with TensorFlow variables.
Key Term: A TensorFlow loop variable (or loop variable for short) refers to a
value (typically a `tf.Tensor`) modified by a loop. See `tf.while_loop`.
### Undefined and None values in TensorFlow
TensorFlow does not support undefined or `None` values. All tensors must have
a value.
Example:
```
x = tf.cond(
    tf.random.uniform(()) > 0.5,
    lambda: tf.constant(1),
    lambda: None)  # Error -- a Tensor cannot be None
```
The same restriction carries over to AutoGraph. If a variable is created inside
control flow and used afterwards, then it must be defined before the control
flow statement:
```
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  x = None
tf.print(x)  # Error -- x may be None here
```
For this reason, AutoGraph forbids variables to be defined in only one branch
of a TensorFlow conditional, if the variable is used afterwards:
```
del x
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  pass
tf.print(x)  # Error -- x may be undefined here
```
Note that if the variable is not used after the control flow statement, then it
is considered local to the control flow block, and is not subject to these
restrictions.
```
del x
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)  # Okay -- x does not need to be returned from the TF cond
else:
  pass
```
Similarly, variables may not be defined inside a TensorFlow loop, unless they
are local to the loop. A variable is local to the loop if (1) it's not used
after the loop and (2) the value from a previous iteration is not used in the
next iteration:
```
del x
while tf.random.uniform(()) > 0.5: # Error -- x must be defined before the loop
  x = tf.constant(1)
tf.print(x)
```
```
del x
while tf.random.uniform(()) > 0.5: # Okay -- x is local to the loop
  x = tf.constant(1)
```
Avoid these limitations by defining a default value before the control flow
statement:
```
x = tf.constant(0)
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
tf.print(x)  # Okay -- x is either 0 or 1
```
Note: `None` values and undefined symbols are allowed in Eager control flow,
because Eager execution uses Python control flow, rather than TensorFlow
control flow ops.
### Indirect modifications and hidden side effects in TensorFlow control flow
Key Point: We recommend using a functional programming style, immutable Python
@ -187,6 +269,62 @@ objects, but it does support basic collection objects such as `list`, `dict`,
`tuple`, `namedtuple` and their subclasses. Design your objects as subclasses
of [namedtuple](https://docs.python.org/3/library/collections.html#collections.namedtuple).
#### Variables closed over by lambda functions
AutoGraph assumes that variables which local functions close over may be used
anywhere in the parent function, because in general it is possible to hide a
function call in almost any Python statement. For this reason, these variables
are accounted for within TensorFlow loops.
For example, the following code correctly captures `a` in the TensorFlow loop
variables:
```
a = 0
def f():
  tf.print(a)
for i in tf.range(3):
  a = i
f()  # Prints 2
```
A consequence is that these variables must be defined before the loop (see
"Undefined and None values" above). So the following code will raise an error,
even if the variable is never used after the loop:
```
def f():
  tf.print(a)
for i in tf.range(3):  # Error -- `a` must be defined before the loop.
  a = i
```
However, lambda functions are handled differently, for reasons of backward
compatibility. Lambda functions are assumed to be used only in the statement
where they are defined, or at least in the same block.
```
a = 0
foo(lambda: a) # This lambda is not expected to be called anywhere else.
for i in tf.range(3): # Okay -- `a` is local to the loop.
  a = i
```
For this reason, the following code will not work as expected with TensorFlow
loops:
```
a = 0
l = lambda: tf.print(a)
for i in tf.range(3):
  a = i  # `a` is considered local to the loop
l() # Prints 0!
```
Note that these restrictions only apply to TensorFlow loops; Python loops
correctly handle closures in all cases.
### Python collections in TensorFlow control flow
Key Point: Use TensorFlow collection classes instead of Python collections.
@ -489,69 +627,6 @@ while tf.random.uniform(()) > 0.5:
  x = tf.constant((1, 2, 3))  # Error -- inconsistent shapes: (), (3,)
```
### Undefined and None values in TensorFlow
TensorFlow does not support undefined and `None` values. All tensors must have
a value.
Example:
```
x = tf.cond(
    tf.random.uniform(()) > 0.5,
    lambda: tf.constant(1),
    lambda: None)  # Error -- a Tensor cannot be None
```
The same restriction carries over in AutoGraph, but only if the symbol is used
after the conditional (otherwise AutoGraph avoids making it a return value
of the `tf.cond`):
```
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  x = None
tf.print(x) # Error -- x may be None here
```
A related but less obvious restriction in AutoGraph forbids symbols to be
defined in only one branch of TensorFlow control flow, if the symbol is
used afterwards:
```
del x
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
else:
  pass
tf.print(x) # Error -- x may be undefined here
```
Similarly, variables defined in a loop may not be used outside the loop, again
if the symbol is used afterwards:
```
del x
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
tf.print(x) # Error -- x may be undefined here
```
Avoid these limitations by defining a default value before the control flow
statement:
```
x = tf.constant()
if tf.random.uniform(()) > 0.5:
  x = tf.constant(1)
tf.print(x) # Okay -- x is either 0 or 1
```
Note: `None` values and undefined symbols are allowed in Eager control flow,
because Eager execution uses Python control flow, rather than TensorFlow
control flow ops.
### Access to source code
Key point: AutoGraph can only handle functions whose source code can be

View File

@ -93,6 +93,9 @@ class Static(NoValue):
ORIG_DEFINITIONS = (
'The value of DEFINITIONS that applied to the original code before any'
' conversion.')
DEFINED_FNS_IN = (
'Local function definitions that may exist when exiting the node. See'
' reaching_fndefs.py')
DEFINED_VARS_IN = (
'Symbols defined when entering the node. See reaching_definitions.py.')
LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')

View File

@ -23,6 +23,7 @@ py_library(
"annos.py",
"liveness.py",
"reaching_definitions.py",
"reaching_fndefs.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],

View File

@ -617,9 +617,23 @@ class ActivityAnalyzer(transformer.Base):
# TODO(mdan): Do remove it, it's confusing.
self._enter_scope(False)
node.body = self.visit(node.body)
# The lambda body can contain nodes of types normally not found as
# statements, and may not have the SCOPE annotation needed by the CFG.
# So we attach one if necessary.
if not anno.hasanno(node.body, anno.Static.SCOPE):
anno.setanno(node.body, anno.Static.SCOPE, self.scope)
self._exit_and_record_scope(node, NodeAnno.BODY_SCOPE)
lambda_scope = self.scope
self._exit_and_record_scope(node, NodeAnno.ARGS_AND_BODY_SCOPE)
# Exception: lambdas are assumed to be used in the place where
# they are defined. Therefore, their activity is passed on to the
# calling statement.
self.scope.read.update(lambda_scope.read - lambda_scope.bound)
return node
def visit_With(self, node):

View File

@ -393,6 +393,31 @@ class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
self.assertScopeIs(scope, ('x', 'y'), ('y',))
self.assertSymbolSetsAre(('x', 'y'), scope.bound, 'BOUND')
def test_nested_lambda(self):
def test_fn(a):
return lambda x: (x * a)
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a',), ())
return_node = node.body[0]
scope = anno.getanno(return_node, anno.Static.SCOPE)
self.assertScopeIs(scope, ('a',), ())
lam_def_node = return_node.value
scope = anno.getanno(lam_def_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'x'), ())
scope = anno.getanno(lam_def_node, NodeAnno.ARGS_AND_BODY_SCOPE)
self.assertScopeIs(scope, ('a', 'x'), ())
self.assertSymbolSetsAre(('x',), scope.bound, 'BOUND')
def test_nested_function_arg_defaults(self):
def test_fn(a):

View File

@ -42,9 +42,6 @@ class Analyzer(cfg.GraphVisitor):
def __init__(self, graph, include_annotations):
super(Analyzer, self).__init__(graph)
# This allows communicating that nodes generate extra symbols,
# e.g. those that a function definition closes over.
self.extra_gen = {}
self.include_annotations = include_annotations
def init_state(self, _):
@ -56,7 +53,7 @@ class Analyzer(cfg.GraphVisitor):
if anno.hasanno(node.ast_node, anno.Static.SCOPE):
node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
gen = node_scope.read | self.extra_gen.get(node.ast_node, frozenset())
gen = node_scope.read
if not self.include_annotations:
gen -= node_scope.annotations
# TODO(mdan): verify whether composites' parents need to be added.
@ -69,6 +66,18 @@ class Analyzer(cfg.GraphVisitor):
live_out |= self.in_[n]
live_in = gen | (live_out - kill)
reaching_functions = anno.getanno(
node.ast_node, anno.Static.DEFINED_FNS_IN)
for fn_ast_node in reaching_functions:
if isinstance(fn_ast_node, gast.Lambda):
# Exception: lambda functions are assumed to be used only in the
# place where they are defined, and not later.
continue
fn_scope = anno.getanno(fn_ast_node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
# Any closure of a reaching function definition is conservatively
# considered live.
live_in |= (fn_scope.read - fn_scope.bound)
else:
assert self.can_ignore(node), (node.ast_node, node)
@ -84,7 +93,7 @@ class Analyzer(cfg.GraphVisitor):
return prev_live_in != live_in
class WholeTreeAnalyzer(transformer.Base):
class TreeAnnotator(transformer.Base):
"""Runs liveness analysis on each of the functions defined in the AST.
If a function defined other local functions, those will have separate CFGs.
@ -94,7 +103,7 @@ class WholeTreeAnalyzer(transformer.Base):
subfunction. For example:
def foo():
# baz is live here
# baz is live from here on
def bar():
print(baz)
@ -103,63 +112,14 @@ class WholeTreeAnalyzer(transformer.Base):
"""
def __init__(self, source_info, graphs, include_annotations):
super(WholeTreeAnalyzer, self).__init__(source_info)
super(TreeAnnotator, self).__init__(source_info)
self.include_annotations = include_annotations
self.allow_skips = False
self.graphs = graphs
self.current_analyzer = None
self.analyzers = {}
def visit_FunctionDef(self, node):
parent_analyzer = self.current_analyzer
subgraph = self.graphs[node]
# Postorder tree processing makes this a bit complicated:
# 1. construct an analyzer object and put it on stack
# 2. recursively walk the subtree; this will initialize the analyzer's
# in_ state properly (done in a block below)
# 3. run the final analysis
analyzer = Analyzer(subgraph, self.include_annotations)
self.current_analyzer = analyzer
node = self.generic_visit(node)
analyzer.visit_reverse()
if parent_analyzer is not None:
# Wire the state between the two subgraphs' analyzers.
child_in_state = analyzer.in_[subgraph.entry]
# Exception: symbols modified in the child function are local to it
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
for qn in body_scope.modified:
# Note: a function modifying the symbol doesn't make that symbol
# live at the function's entry. In fact when that happens it is
# probably a case of undefined assignment, like this:
#
# bar = 0
# def foo():
# print(bar) # bar is undefined here!
# bar = 1
#
# Hence we use discard and not remove below.
child_in_state.discard(qn)
parent_analyzer.extra_gen[node] = frozenset(child_in_state,)
self.analyzers[node] = analyzer
self.current_analyzer = parent_analyzer
return node
class Annotator(transformer.Base):
"""AST visitor that annotates each control flow block with live symbols."""
# Note: additional nodes may be added as needed.
def __init__(self, source_info, cross_function_analyzer):
super(Annotator, self).__init__(source_info)
self.cross_function_analyzer = cross_function_analyzer
self.current_analyzer = None
def visit(self, node):
node = super(Annotator, self).visit(node)
node = super(TreeAnnotator, self).visit(node)
if (self.current_analyzer is not None and
isinstance(node, gast.stmt) and
node in self.current_analyzer.graph.index):
@ -168,14 +128,23 @@ class Annotator(transformer.Base):
frozenset(self.current_analyzer.in_[cfg_node]))
return node
def visit_FunctionDef(self, node):
def _analyze_function(self, node, is_lambda):
parent_analyzer = self.current_analyzer
self.current_analyzer = self.cross_function_analyzer.analyzers[node]
analyzer = Analyzer(self.graphs[node], self.include_annotations)
analyzer.visit_reverse()
self.current_analyzer = analyzer
node = self.generic_visit(node)
self.current_analyzer = parent_analyzer
return node
def visit_Lambda(self, node):
return self._analyze_function(node, is_lambda=True)
def visit_FunctionDef(self, node):
return self._analyze_function(node, is_lambda=False)
def _block_statement_live_out(self, node):
successors = self.current_analyzer.graph.stmt_next[node]
stmt_live_out = set()
@ -246,9 +215,5 @@ def resolve(node, source_info, graphs, include_annotations=True):
Returns:
ast.AST
"""
cross_function_analyzer = WholeTreeAnalyzer(
source_info, graphs, include_annotations)
node = cross_function_analyzer.visit(node)
visitor = Annotator(source_info, cross_function_analyzer)
node = visitor.visit(node)
node = TreeAnnotator(source_info, graphs, include_annotations).visit(node)
return node

View File

@ -26,6 +26,7 @@ from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.platform import test
@ -49,7 +50,8 @@ class LivenessAnalyzerTestBase(test.TestCase):
ctx = transformer.Context(entity_info, namer, None)
node = activity.resolve(node, ctx)
graphs = cfg.build(node)
liveness.resolve(node, ctx, graphs)
node = reaching_fndefs.resolve(node, ctx, graphs)
node = liveness.resolve(node, ctx, graphs)
return node
def assertHasLiveOut(self, node, expected):
@ -191,6 +193,73 @@ class LivenessAnalyzerTest(LivenessAnalyzerTestBase):
self.assertHasLiveOut(fn_body[0], 'a')
def test_live_out_nested_functions_defined_ahead(self):
def test_fn(a, b):
def foo():
return a
if b:
a = []
return foo
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasLiveOut(fn_body[1], ('a', 'foo'))
def test_live_out_nested_functions_defined_after(self):
def test_fn(a, b):
if b:
a = []
def foo():
return a
return foo
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasLiveOut(fn_body[0], ('a',))
def test_live_out_lambda(self):
def test_fn(a, b):
if b:
a = []
foo = lambda: a
if b:
pass
return foo
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasLiveOut(fn_body[0], ('a', 'b'))
self.assertHasLiveOut(fn_body[2], ('foo',))
def test_live_out_nested_functions_hidden_by_argument(self):
def test_fn(b):
def foo(a):
return a
if b:
a = [] # pylint:disable=unused-variable
return foo
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasLiveOut(fn_body[1], ('foo'))
def test_live_out_nested_functions_isolation(self):
def test_fn(b):

View File

@ -0,0 +1,182 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An analysis that determines the reach of a function definition.
A function definition is said to reach a statement if that function may exist
(and therefore may be called) when that statement executes.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer
class Definition(object):
"""Definition objects describe a unique definition of a function."""
def __init__(self, def_node):
self.def_node = def_node
class _NodeState(object):
"""Abstraction for the state of the CFG walk for reaching definition analysis.
This is a value type. Only implements the strictly necessary operators.
Attributes:
value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
their possible definitions
"""
def __init__(self, init_from=None):
if init_from:
self.value = set(init_from)
else:
self.value = set()
def __eq__(self, other):
return self.value == other.value
def __ne__(self, other):
return self.value != other.value
def __or__(self, other):
assert isinstance(other, _NodeState)
result = _NodeState(self.value)
result.value.update(other.value)
return result
def __add__(self, value):
result = _NodeState(self.value)
result.value.add(value)
return result
def __repr__(self):
return 'NodeState[%s]=%s' % (id(self), repr(self.value))
class Analyzer(cfg.GraphVisitor):
"""CFG visitor that determines reaching definitions at statement level."""
def __init__(self, graph, external_defs):
super(Analyzer, self).__init__(graph)
# This allows communicating that nodes have extra reaching definitions,
# e.g. those that a function closes over.
self.external_defs = external_defs
def init_state(self, _):
return _NodeState()
def visit_node(self, node):
prev_defs_out = self.out[node]
if node is self.graph.entry:
defs_in = _NodeState(self.external_defs)
else:
defs_in = prev_defs_out
for n in node.prev:
defs_in |= self.out[n]
defs_out = defs_in
if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
defs_out += node.ast_node
self.in_[node] = defs_in
self.out[node] = defs_out
return prev_defs_out != defs_out
class TreeAnnotator(transformer.Base):
"""AST visitor that annotates each symbol name with its reaching definitions.
Simultaneously, the visitor runs the dataflow analysis on each function node,
accounting for the effect of closures. For example:
def foo():
def f():
pass
def g():
# `def f` reaches here
"""
def __init__(self, source_info, graphs):
super(TreeAnnotator, self).__init__(source_info)
self.graphs = graphs
self.allow_skips = False
self.current_analyzer = None
def _proces_function(self, node):
parent_analyzer = self.current_analyzer
subgraph = self.graphs[node]
if (self.current_analyzer is not None
and node in self.current_analyzer.graph.index):
cfg_node = self.current_analyzer.graph.index[node]
defined_in = self.current_analyzer.in_[cfg_node].value
else:
defined_in = ()
analyzer = Analyzer(subgraph, defined_in)
analyzer.visit_forward()
self.current_analyzer = analyzer
node = self.generic_visit(node)
self.current_analyzer = parent_analyzer
return node
def visit_FunctionDef(self, node):
return self._proces_function(node)
def visit_Lambda(self, node):
return self._proces_function(node)
def visit(self, node):
# This can happen before entering the top level function
if (self.current_analyzer is not None
and node in self.current_analyzer.graph.index):
cfg_node = self.current_analyzer.graph.index[node]
anno.setanno(node, anno.Static.DEFINED_FNS_IN,
self.current_analyzer.in_[cfg_node].value)
extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
if extra_node is not None:
cfg_node = self.current_analyzer.graph.index[extra_node]
anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN,
self.current_analyzer.in_[cfg_node].value)
return super(TreeAnnotator, self).visit(node)
def resolve(node, source_info, graphs):
"""Resolves reaching definitions for each symbol.
Args:
node: ast.AST
source_info: transformer.SourceInfo
graphs: Dict[ast.FunctionDef, cfg.Graph]
Returns:
ast.AST
"""
visitor = TreeAnnotator(source_info, graphs)
node = visitor.visit(node)
return node

View File

@ -0,0 +1,58 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reaching_fndefs module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.platform import test
class ReachingFndefsAnalyzerTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
# TODO(mdan): Use a custom FunctionTransformer here.
node, source = parser.parse_entity(test_fn, future_features=())
entity_info = transformer.EntityInfo(
name=test_fn.__name__,
source_code=source,
source_file=None,
future_features=(),
namespace={})
node = qual_names.resolve(node)
namer = naming.Namer({})
ctx = transformer.Context(entity_info, namer, None)
node = activity.resolve(node, ctx)
graphs = cfg.build(node)
node = reaching_definitions.resolve(node, ctx, graphs)
node = reaching_fndefs.resolve(node, ctx, graphs)
return node
def assertHasFnDefs(self, node):
anno.getanno(node, anno.Static.DEFINED_FNS_IN)
if __name__ == '__main__':
test.main()

View File

@ -36,6 +36,7 @@ from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
@ -208,6 +209,7 @@ def _live_tensors(f, attr_name="inputs"):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_fndefs.resolve(node, ctx, graphs)
node = liveness.resolve(node, ctx, graphs)
op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)