Internal cleanup: let each converter only perform the analysis it needs. Since the bottleneck is inside the conversion process, this is expected to speed it up noticeably.
PiperOrigin-RevId: 309550095 Change-Id: I7e9a3713b7d9455ca407b847410936de70739534
This commit is contained in:
parent
81688c64e0
commit
b951b63196
tensorflow/python/autograph
@ -48,4 +48,5 @@ class AssertTransformer(converter.Base):
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
return AssertTransformer(ctx).visit(node)
|
||||
node = AssertTransformer(ctx).visit(node)
|
||||
return node
|
||||
|
@ -20,7 +20,9 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
|
||||
|
||||
|
||||
@ -179,6 +181,9 @@ class BreakTransformer(converter.Base):
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
transformer = BreakTransformer(ctx)
|
||||
node = transformer.visit(node)
|
||||
return node
|
||||
|
@ -29,6 +29,7 @@ import gast
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
|
||||
@ -218,4 +219,7 @@ def transform(node, ctx):
|
||||
node: The transformed AST
|
||||
new_names: set(string), containing any newly-generated names
|
||||
"""
|
||||
return CallTreeTransformer(ctx).visit(node)
|
||||
node = qual_names.resolve(node)
|
||||
|
||||
node = CallTreeTransformer(ctx).visit(node)
|
||||
return node
|
||||
|
@ -20,7 +20,9 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
|
||||
|
||||
|
||||
@ -159,6 +161,8 @@ class ContinueCanonicalizationTransformer(converter.Base):
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
transformer = ContinueCanonicalizationTransformer(ctx)
|
||||
node = transformer.visit(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
node = ContinueCanonicalizationTransformer(ctx).visit(node)
|
||||
return node
|
||||
|
@ -24,9 +24,14 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.utils import compat_util
|
||||
|
||||
|
||||
@ -528,9 +533,22 @@ class ControlFlowTransformer(converter.Base):
|
||||
undefined_assigns=undefined_assigns)
|
||||
|
||||
|
||||
class AnnotatedDef(reaching_definitions.Definition):
|
||||
|
||||
def __init__(self):
|
||||
super(AnnotatedDef, self).__init__()
|
||||
self.directives = {}
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
transformer = ControlFlowTransformer(ctx)
|
||||
return transformer.visit(node)
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
|
||||
node = ControlFlowTransformer(ctx).visit(node)
|
||||
return node
|
||||
|
||||
|
||||
compat_util.deprecated_py2_support(__name__)
|
||||
|
@ -27,9 +27,14 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
|
||||
|
||||
# TODO(mdan): Refactor functions to make them smaller.
|
||||
@ -604,6 +609,19 @@ class ControlFlowTransformer(converter.Base):
|
||||
opts=opts)
|
||||
|
||||
|
||||
class AnnotatedDef(reaching_definitions.Definition):
|
||||
|
||||
def __init__(self):
|
||||
super(AnnotatedDef, self).__init__()
|
||||
self.directives = {}
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
|
||||
node = liveness.resolve(node, ctx, graphs)
|
||||
|
||||
node = ControlFlowTransformer(ctx).visit(node)
|
||||
return node
|
||||
|
@ -23,7 +23,9 @@ import gast
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
|
||||
|
||||
@ -139,4 +141,7 @@ class FunctionTransformer(converter.Base):
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
return FunctionTransformer(ctx).visit(node)
|
||||
|
@ -36,7 +36,9 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
|
||||
|
||||
|
||||
@ -235,4 +237,7 @@ class ListTransformer(converter.Base):
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
return ListTransformer(ctx).visit(node)
|
||||
|
@ -23,7 +23,9 @@ import gast
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
|
||||
|
||||
|
||||
@ -396,8 +398,14 @@ def transform(node, ctx, default_to_null_return=True):
|
||||
# Note: Technically, these two could be merged into a single walk, but
|
||||
# keeping them separate helps with readability.
|
||||
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
node = ConditionalReturnRewriter(ctx).visit(node)
|
||||
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
|
||||
transformer = ReturnStatementsTransformer(
|
||||
ctx, default_to_null_return=default_to_null_return)
|
||||
node = transformer.visit(node)
|
||||
|
@ -68,14 +68,9 @@ import enum
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import liveness
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# TODO(mdan): These contexts can be refactored into first class objects.
|
||||
@ -326,56 +321,3 @@ class Base(transformer.Base):
|
||||
return super(Base, self).visit(node)
|
||||
finally:
|
||||
self._ast_depth -= 1
|
||||
|
||||
|
||||
class AnnotatedDef(reaching_definitions.Definition):
|
||||
|
||||
def __init__(self):
|
||||
super(AnnotatedDef, self).__init__()
|
||||
self.directives = {}
|
||||
|
||||
|
||||
def standard_analysis(node, context, is_initial=False):
|
||||
"""Performs a complete static analysis of the given code.
|
||||
|
||||
Args:
|
||||
node: ast.AST
|
||||
context: converter.EntityContext
|
||||
is_initial: bool, whether this is the initial analysis done on the input
|
||||
source code
|
||||
|
||||
Returns:
|
||||
ast.AST, same as node, with the static analysis annotations added
|
||||
"""
|
||||
# TODO(mdan): Clear static analysis here.
|
||||
# TODO(mdan): Consider not running all analyses every time.
|
||||
# TODO(mdan): Don't return a node because it's modified by reference.
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, context, None)
|
||||
node = reaching_definitions.resolve(node, context, graphs, AnnotatedDef)
|
||||
node = liveness.resolve(node, context, graphs)
|
||||
if is_initial:
|
||||
anno.dup(
|
||||
node,
|
||||
{
|
||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
||||
},
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
def apply_(node, context, converter_module):
|
||||
"""Applies a converter to an AST.
|
||||
|
||||
Args:
|
||||
node: ast.AST
|
||||
context: converter.EntityContext
|
||||
converter_module: converter.Base
|
||||
|
||||
Returns:
|
||||
ast.AST, the result of applying converter to node
|
||||
"""
|
||||
node = standard_analysis(node, context)
|
||||
node = converter_module.transform(node, context)
|
||||
return node
|
||||
|
@ -31,12 +31,17 @@ from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -137,8 +142,7 @@ class TestCase(test.TestCase):
|
||||
|
||||
if not isinstance(converter_module, (list, tuple)):
|
||||
converter_module = (converter_module,)
|
||||
for i, m in enumerate(converter_module):
|
||||
node = converter.standard_analysis(node, ctx, is_initial=not i)
|
||||
for m in converter_module:
|
||||
node = m.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, namespace, tf_symbols) as result:
|
||||
@ -177,5 +181,16 @@ class TestCase(test.TestCase):
|
||||
namespace=namespace)
|
||||
ctx = transformer.Context(entity_info, namer, program_ctx)
|
||||
origin_info.resolve_entity(node, source, test_fn)
|
||||
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
anno.dup(
|
||||
node,
|
||||
{
|
||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
||||
},
|
||||
)
|
||||
|
||||
return node, ctx
|
||||
|
@ -41,9 +41,14 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cache
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transpiler
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.util import tf_inspect
|
||||
@ -58,24 +63,35 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
||||
# TODO(mdan): Insert list_comprehensions somewhere.
|
||||
unsupported_features_checker.verify(node)
|
||||
|
||||
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||
node = converter.apply_(node, ctx, functions)
|
||||
node = converter.apply_(node, ctx, directives)
|
||||
node = converter.apply_(node, ctx, break_statements)
|
||||
# Run initial analysis.
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
anno.dup(
|
||||
node,
|
||||
{
|
||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
||||
},
|
||||
)
|
||||
|
||||
node = functions.transform(node, ctx)
|
||||
node = directives.transform(node, ctx)
|
||||
node = break_statements.transform(node, ctx)
|
||||
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
||||
node = converter.apply_(node, ctx, asserts)
|
||||
node = asserts.transform(node, ctx)
|
||||
# Note: sequencing continue canonicalization before for loop one avoids
|
||||
# dealing with the extra loop increment operation that the for
|
||||
# canonicalization creates.
|
||||
node = converter.apply_(node, ctx, continue_statements)
|
||||
node = converter.apply_(node, ctx, return_statements)
|
||||
node = continue_statements.transform(node, ctx)
|
||||
node = return_statements.transform(node, ctx)
|
||||
if ctx.user.options.uses(converter.Feature.LISTS):
|
||||
node = converter.apply_(node, ctx, lists)
|
||||
node = converter.apply_(node, ctx, slices)
|
||||
node = converter.apply_(node, ctx, call_trees)
|
||||
node = converter.apply_(node, ctx, control_flow)
|
||||
node = converter.apply_(node, ctx, conditional_expressions)
|
||||
node = converter.apply_(node, ctx, logical_expressions)
|
||||
node = lists.transform(node, ctx)
|
||||
node = slices.transform(node, ctx)
|
||||
node = call_trees.transform(node, ctx)
|
||||
node = control_flow.transform(node, ctx)
|
||||
node = conditional_expressions.transform(node, ctx)
|
||||
node = logical_expressions.transform(node, ctx)
|
||||
return node
|
||||
|
||||
|
||||
|
@ -45,10 +45,12 @@ class Definition(object):
|
||||
|
||||
Attributes:
|
||||
param_of: Optional[ast.AST]
|
||||
directives: Dict, optional definition annotations
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.param_of = None
|
||||
self.directives = {}
|
||||
|
||||
def __repr__(self):
|
||||
return '%s[%d]' % (self.__class__.__name__, id(self))
|
||||
@ -274,7 +276,7 @@ class TreeAnnotator(transformer.Base):
|
||||
return node
|
||||
|
||||
|
||||
def resolve(node, source_info, graphs, definition_factory):
|
||||
def resolve(node, source_info, graphs, definition_factory=Definition):
|
||||
"""Resolves reaching definitions for each symbol.
|
||||
|
||||
Args:
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import enum
|
||||
|
||||
import gast
|
||||
|
||||
@ -28,6 +29,14 @@ from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
|
||||
|
||||
class AnalysisLevel(enum.IntEnum):
|
||||
|
||||
NONE = 0
|
||||
ACTIVITY = 1
|
||||
DEFINEDNESS = 2
|
||||
LIVENESS = 3
|
||||
|
||||
|
||||
# TODO(znado): Use namedtuple.
|
||||
class Context(object):
|
||||
"""Contains information about a source code transformation.
|
||||
|
Loading…
Reference in New Issue
Block a user