Internal cleanup: let each converter only perform the analysis it needs. Since the bottleneck is inside the conversion process, this is expected to speed it up noticeably.

PiperOrigin-RevId: 309550095
Change-Id: I7e9a3713b7d9455ca407b847410936de70739534
This commit is contained in:
Dan Moldovan 2020-05-02 05:13:45 -07:00 committed by TensorFlower Gardener
parent 81688c64e0
commit b951b63196
14 changed files with 133 additions and 81 deletions

View File

@ -48,4 +48,5 @@ class AssertTransformer(converter.Base):
def transform(node, ctx):
return AssertTransformer(ctx).visit(node)
node = AssertTransformer(ctx).visit(node)
return node

View File

@ -20,7 +20,9 @@ from __future__ import print_function
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
@ -179,6 +181,9 @@ class BreakTransformer(converter.Base):
def transform(node, ctx):
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
transformer = BreakTransformer(ctx)
node = transformer.visit(node)
return node

View File

@ -29,6 +29,7 @@ import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.utils import ag_logging
@ -218,4 +219,7 @@ def transform(node, ctx):
node: The transformed AST
new_names: set(string), containing any newly-generated names
"""
return CallTreeTransformer(ctx).visit(node)
node = qual_names.resolve(node)
node = CallTreeTransformer(ctx).visit(node)
return node

View File

@ -20,7 +20,9 @@ from __future__ import print_function
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
@ -159,6 +161,8 @@ class ContinueCanonicalizationTransformer(converter.Base):
def transform(node, ctx):
transformer = ContinueCanonicalizationTransformer(ctx)
node = transformer.visit(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = ContinueCanonicalizationTransformer(ctx).visit(node)
return node

View File

@ -24,9 +24,14 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.utils import compat_util
@ -528,9 +533,22 @@ class ControlFlowTransformer(converter.Base):
undefined_assigns=undefined_assigns)
class AnnotatedDef(reaching_definitions.Definition):
def __init__(self):
super(AnnotatedDef, self).__init__()
self.directives = {}
def transform(node, ctx):
transformer = ControlFlowTransformer(ctx)
return transformer.visit(node)
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
node = liveness.resolve(node, ctx, graphs)
node = ControlFlowTransformer(ctx).visit(node)
return node
compat_util.deprecated_py2_support(__name__)

View File

@ -27,9 +27,14 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
# TODO(mdan): Refactor functions to make them smaller.
@ -604,6 +609,19 @@ class ControlFlowTransformer(converter.Base):
opts=opts)
class AnnotatedDef(reaching_definitions.Definition):
def __init__(self):
super(AnnotatedDef, self).__init__()
self.directives = {}
def transform(node, ctx):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs, AnnotatedDef)
node = liveness.resolve(node, ctx, graphs)
node = ControlFlowTransformer(ctx).visit(node)
return node

View File

@ -23,7 +23,9 @@ import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
@ -139,4 +141,7 @@ class FunctionTransformer(converter.Base):
def transform(node, ctx):
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
return FunctionTransformer(ctx).visit(node)

View File

@ -36,7 +36,9 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
@ -235,4 +237,7 @@ class ListTransformer(converter.Base):
def transform(node, ctx):
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
return ListTransformer(ctx).visit(node)

View File

@ -23,7 +23,9 @@ import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
@ -396,8 +398,14 @@ def transform(node, ctx, default_to_null_return=True):
# Note: Technically, these two could be merged into a single walk, but
# keeping them separate helps with readability.
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = ConditionalReturnRewriter(ctx).visit(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
transformer = ReturnStatementsTransformer(
ctx, default_to_null_return=default_to_null_return)
node = transformer.visit(node)

View File

@ -68,14 +68,9 @@ import enum
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.util.tf_export import tf_export
# TODO(mdan): These contexts can be refactored into first class objects.
@ -326,56 +321,3 @@ class Base(transformer.Base):
return super(Base, self).visit(node)
finally:
self._ast_depth -= 1
class AnnotatedDef(reaching_definitions.Definition):
def __init__(self):
super(AnnotatedDef, self).__init__()
self.directives = {}
def standard_analysis(node, context, is_initial=False):
"""Performs a complete static analysis of the given code.
Args:
node: ast.AST
context: converter.EntityContext
is_initial: bool, whether this is the initial analysis done on the input
source code
Returns:
ast.AST, same as node, with the static analysis annotations added
"""
# TODO(mdan): Clear static analysis here.
# TODO(mdan): Consider not running all analyses every time.
# TODO(mdan): Don't return a node because it's modified by reference.
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, context, None)
node = reaching_definitions.resolve(node, context, graphs, AnnotatedDef)
node = liveness.resolve(node, context, graphs)
if is_initial:
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
return node
def apply_(node, context, converter_module):
"""Applies a converter to an AST.
Args:
node: ast.AST
context: converter.EntityContext
converter_module: converter.Base
Returns:
ast.AST, the result of applying converter to node
"""
node = standard_analysis(node, context)
node = converter_module.transform(node, context)
return node

View File

@ -31,12 +31,17 @@ from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.platform import test
@ -137,8 +142,7 @@ class TestCase(test.TestCase):
if not isinstance(converter_module, (list, tuple)):
converter_module = (converter_module,)
for i, m in enumerate(converter_module):
node = converter.standard_analysis(node, ctx, is_initial=not i)
for m in converter_module:
node = m.transform(node, ctx)
with self.compiled(node, namespace, tf_symbols) as result:
@ -177,5 +181,16 @@ class TestCase(test.TestCase):
namespace=namespace)
ctx = transformer.Context(entity_info, namer, program_ctx)
origin_info.resolve_entity(node, source, test_fn)
node = converter.standard_analysis(node, ctx, is_initial=True)
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs)
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
return node, ctx

View File

@ -41,9 +41,14 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.util import tf_inspect
@ -58,24 +63,35 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
# TODO(mdan): Insert list_comprehensions somewhere.
unsupported_features_checker.verify(node)
node = converter.standard_analysis(node, ctx, is_initial=True)
node = converter.apply_(node, ctx, functions)
node = converter.apply_(node, ctx, directives)
node = converter.apply_(node, ctx, break_statements)
# Run initial analysis.
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs)
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
node = functions.transform(node, ctx)
node = directives.transform(node, ctx)
node = break_statements.transform(node, ctx)
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
node = converter.apply_(node, ctx, asserts)
node = asserts.transform(node, ctx)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
node = converter.apply_(node, ctx, continue_statements)
node = converter.apply_(node, ctx, return_statements)
node = continue_statements.transform(node, ctx)
node = return_statements.transform(node, ctx)
if ctx.user.options.uses(converter.Feature.LISTS):
node = converter.apply_(node, ctx, lists)
node = converter.apply_(node, ctx, slices)
node = converter.apply_(node, ctx, call_trees)
node = converter.apply_(node, ctx, control_flow)
node = converter.apply_(node, ctx, conditional_expressions)
node = converter.apply_(node, ctx, logical_expressions)
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
node = call_trees.transform(node, ctx)
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
node = logical_expressions.transform(node, ctx)
return node

View File

@ -45,10 +45,12 @@ class Definition(object):
Attributes:
param_of: Optional[ast.AST]
directives: Dict, optional definition annotations
"""
def __init__(self):
self.param_of = None
self.directives = {}
def __repr__(self):
return '%s[%d]' % (self.__class__.__name__, id(self))
@ -274,7 +276,7 @@ class TreeAnnotator(transformer.Base):
return node
def resolve(node, source_info, graphs, definition_factory):
def resolve(node, source_info, graphs, definition_factory=Definition):
"""Resolves reaching definitions for each symbol.
Args:

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import enum
import gast
@ -28,6 +29,14 @@ from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import templates
class AnalysisLevel(enum.IntEnum):
NONE = 0
ACTIVITY = 1
DEFINEDNESS = 2
LIVENESS = 3
# TODO(znado): Use namedtuple.
class Context(object):
"""Contains information about a source code transformation.