STT-tensorflow/tensorflow/tools/compatibility/ast_edits.py

1115 lines
40 KiB
Python

# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upgrader for Python scripts according to an API change specification."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import collections
import os
import re
import shutil
import sys
import tempfile
import traceback
import pasta
import six
from six.moves import range
# Some regular expressions we will need for parsing
# FIND_OPEN matches a line whose first non-whitespace character is "[" and
# captures that bracket.
FIND_OPEN = re.compile(r"^\s*(\[).*$")
# FIND_STRING_CHARS matches a single or a double quote character.
FIND_STRING_CHARS = re.compile(r"['\"]")

# Severity labels. They are stored as the first element of every log tuple
# `(severity, lineno, col_offset, msg)` recorded by the visitors below.
INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"

# Describes one import rename: `new_name` replaces the matched import prefix,
# and modules starting with any of `excluded_prefixes` are left untouched.
ImportRename = collections.namedtuple(
    "ImportRename", ["new_name", "excluded_prefixes"])
def full_name_node(name, ctx=ast.Load()):
  """Build the AST expression that spells out a dotted name.

  A plain name such as "foo" becomes a single `ast.Name`; a qualified name such
  as "a.b.c" becomes `ast.Attribute` nodes nested around the `ast.Name` for "a".

  Args:
    name: The (possibly dotted) name to turn into a node.
    ctx: Expression context for the outermost node. Defaults to Load().

  Returns:
    An `ast.Name` or `ast.Attribute` node.
  """
  components = six.ensure_str(name).split(".")
  # The leftmost component is the root of the chain; every following component
  # wraps what has been built so far in one more attribute access.
  node = ast.Name(id=components[0], ctx=ast.Load())
  for attribute in components[1:]:
    node = ast.Attribute(value=node, attr=attribute, ctx=ast.Load())
  # Only the outermost node takes the caller's context; the nested ones stay
  # plain loads.
  node.ctx = ctx
  return node
def get_arg_value(node, arg_name, arg_pos=None):
  """Look up one argument of an `ast.Call`, by keyword or by position.

  Keyword arguments are searched first (when `arg_name` is given), then the
  positional arguments (when `arg_pos` is given). The contents of `*args` and
  `**kwargs` cannot be inspected; on Python 3.5+ starred positional entries are
  skipped and do not count towards the position.

  Args:
    node: The ast.Call node to extract arg values from.
    arg_name: The name of the argument to extract, or None to skip the keyword
      lookup.
    arg_pos: The position of the argument (in case it's passed as a positional
      argument), or None to skip the positional lookup.

  Returns:
    A tuple (arg_present, arg_value): a boolean telling whether the argument
    was found, and the node holding its value (None when it was not found).
  """
  if arg_name is not None:
    for keyword in node.keywords:
      if keyword.arg == arg_name:
        return (True, keyword.value)

  if arg_pos is None:
    return (False, None)

  has_starred_nodes = sys.version_info[:2] >= (3, 5)
  # Starred entries cannot be resolved to a single value, so drop them before
  # counting positions.
  plain_args = [
      candidate for candidate in node.args
      if not (has_starred_nodes and isinstance(candidate, ast.Starred))
  ]
  for position, candidate in enumerate(plain_args):
    if position == arg_pos:
      return (True, candidate)
  return (False, None)
def uses_star_args_in_call(node):
  """Tell whether an `ast.Call` passes variable-length positional `*args`.

  Handles the Python 3.5+ call node layout, where starred arguments live in
  `node.args`, as well as the older layout with a separate `starargs` field.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the call uses starred positional args, False otherwise.
  """
  if sys.version_info[:2] < (3, 5):
    # Older ASTs keep the *args expression in a dedicated field.
    return bool(node.starargs)
  return any(isinstance(arg, ast.Starred) for arg in node.args)
def uses_star_kwargs_in_call(node):
  """Tell whether an `ast.Call` passes variable-length `**kwargs`.

  Handles the Python 3.5+ call node layout, where `**kwargs` shows up as a
  keyword whose `arg` is None, as well as the older layout with a separate
  `kwargs` field.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the call uses `**kwargs`, False otherwise.
  """
  if sys.version_info[:2] < (3, 5):
    # Older ASTs keep the **kwargs expression in a dedicated field.
    return bool(node.kwargs)
  return any(keyword.arg is None for keyword in node.keywords)
def uses_star_args_or_kwargs_in_call(node):
  """Tell whether an `ast.Call` passes `*args` or `**kwargs`.

  Works with the Python 3.5+ call node layout as well as the older one, since
  both underlying checks do.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the call uses starred positional args or `**kwargs`, False if it
    uses neither.
  """
  variadic_checks = (uses_star_args_in_call, uses_star_kwargs_in_call)
  # any() stops at the first check that succeeds, like `or` would.
  return any(check(node) for check in variadic_checks)
def excluded_from_module_rename(module, import_rename_spec):
  """Decide whether a module import is exempt from renaming.

  Args:
    module: (string) module name.
    import_rename_spec: ImportRename instance.

  Returns:
    True if `module` starts with one of the spec's excluded prefixes and so
    must keep its name, False otherwise.
  """
  return any(
      module.startswith(prefix)
      for prefix in import_rename_spec.excluded_prefixes)
class APIChangeSpec(object):
  """Base class describing the transformations an upgrade has to perform.

  Subclasses must provide the following fields:
  * `function_keyword_renames`: maps function names to a map of old -> new
    argument names
  * `symbol_renames`: maps function names to new function names
  * `change_to_function`: a set of function names that have changed (for
    notifications)
  * `function_reorders`: maps functions whose argument order has changed to the
    list of arguments in the new order
  * `function_warnings`: maps full names of functions to warnings that will be
    printed out if the function is used. (e.g. tf.nn.convolution())
  * `function_transformers`: maps function names to custom handlers
  * `module_deprecations`: maps module names to warnings that will be printed
    if the module is still used after all other transformations have run
  * `import_renames`: maps import name (must be a short name without '.')
    to ImportRename instance.

  For an example, see `TFAPIChangeSpec`.
  """

  def preprocess(self, root_node):  # pylint: disable=unused-argument
    """Preprocess a parse tree.

    The base implementation hands the tree back untouched.

    Args:
      root_node: Root of the parse tree.

    Returns:
      A tuple (preprocessed node, logs, errors); both lists are empty here.
    """
    logs = []
    errors = []
    return root_node, logs, errors

  def clear_preprocessing(self):
    """Restore this APIChangeSpec to its state before preprocessing a file.

    Needed when preprocessing a file changed any rewriting rules. The base
    implementation has nothing to undo.
    """
    return None
class NoUpdateSpec(APIChangeSpec):
  """An API change specification that leaves every symbol as it is."""

  # Every rule table of the spec; all of them start out (and stay) empty.
  _EMPTY_FIELDS = (
      "function_handle",
      "function_reorders",
      "function_keyword_renames",
      "symbol_renames",
      "function_warnings",
      "change_to_function",
      "module_deprecations",
      "function_transformers",
      "import_renames",
  )

  def __init__(self):
    for field_name in self._EMPTY_FIELDS:
      # A fresh dict per field, so the tables never alias each other.
      setattr(self, field_name, {})
class _PastaEditVisitor(ast.NodeVisitor):
  """AST Visitor that processes function calls.

  Updates function calls from old API version to new API version using a given
  change spec.
  """

  def __init__(self, api_change_spec):
    self._api_change_spec = api_change_spec
    self._log = []   # Holds 4-tuples: severity, line, col, msg.
    self._stack = []  # Allow easy access to parents.

  # Overridden to maintain a stack of nodes to allow for parent access
  def visit(self, node):
    self._stack.append(node)
    super(_PastaEditVisitor, self).visit(node)
    self._stack.pop()

  @property
  def errors(self):
    """Log entries with severity ERROR."""
    return [log for log in self._log if log[0] == ERROR]

  @property
  def warnings(self):
    """Log entries with severity WARNING."""
    return [log for log in self._log if log[0] == WARNING]

  @property
  def warnings_and_errors(self):
    """Log entries with severity WARNING or ERROR."""
    return [log for log in self._log if log[0] in (WARNING, ERROR)]

  @property
  def info(self):
    """Log entries with severity INFO."""
    return [log for log in self._log if log[0] == INFO]

  @property
  def log(self):
    """All log entries recorded so far, in order."""
    return self._log

  def add_log(self, severity, lineno, col, msg):
    """Record a single log entry and print it."""
    self._log.append((severity, lineno, col, msg))
    print("%s line %d:%d: %s" % (severity, lineno, col, msg))

  def add_logs(self, logs):
    """Record several logs and print them.

    Each log should be a tuple `(severity, lineno, col_offset, msg)`, which will
    be printed and recorded. It is part of the log available in the `self.log`
    property.

    Args:
      logs: The logs to add. Must be a list of tuples
        `(severity, lineno, col_offset, msg)`.
    """
    self._log.extend(logs)
    for log in logs:
      print("%s line %d:%d: %s" % log)

  def _get_applicable_entries(self, transformer_field, full_name, name):
    """Get all list entries indexed by name that apply to full_name or name."""
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + six.ensure_str(name) if name else None
    transformers = []
    if full_name in function_transformers:
      transformers.append(function_transformers[full_name])
    if glob_name in function_transformers:
      transformers.append(function_transformers[glob_name])
    if "*" in function_transformers:
      transformers.append(function_transformers["*"])
    return transformers

  def _get_applicable_dict(self, transformer_field, full_name, name):
    """Get all dict entries indexed by name that apply to full_name or name."""
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + six.ensure_str(name) if name else None
    # Later updates win: "*" is overridden by "*.name", which is in turn
    # overridden by the full name.
    transformers = function_transformers.get("*", {}).copy()
    transformers.update(function_transformers.get(glob_name, {}))
    transformers.update(function_transformers.get(full_name, {}))
    return transformers

  def _get_full_name(self, node):
    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".

    This is the inverse of `full_name_node`.

    Args:
      node: A Node of type Attribute.

    Returns:
      a '.'-delimited full-name or None if node was not Attribute or Name.
      i.e. `(foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
    """
    curr = node
    items = []
    while not isinstance(curr, ast.Name):
      if not isinstance(curr, ast.Attribute):
        return None
      items.append(curr.attr)
      curr = curr.value
    items.append(curr.id)
    return ".".join(reversed(items))

  def _maybe_add_warning(self, node, full_name):
    """Adds an error to be printed about full_name at node."""
    function_warnings = self._api_change_spec.function_warnings
    if full_name in function_warnings:
      level, message = function_warnings[full_name]
      message = six.ensure_str(message).replace("<function name>", full_name)
      self.add_log(level, node.lineno, node.col_offset,
                   "%s requires manual check. %s" % (full_name, message))
      return True
    else:
      return False

  def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name):
    """Adds a warning if full_name is a deprecated module."""
    warnings = self._api_change_spec.module_deprecations
    if full_name in warnings:
      level, message = warnings[full_name]
      message = six.ensure_str(message).replace("<function name>",
                                                six.ensure_str(whole_name))
      self.add_log(level, node.lineno, node.col_offset,
                   "Using member %s in deprecated module %s. %s" % (whole_name,
                                                                    full_name,
                                                                    message))
      return True
    else:
      return False

  def _maybe_add_call_warning(self, node, full_name, name):
    """Print a warning when specific functions are called with selected args.

    The function `_maybe_add_warning` matches the full name of the called
    function, e.g., tf.foo.bar(). This function matches the function name that
    is called, as long as the function is an attribute. For example,
    `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`.

    Args:
      node: ast.Call object
      full_name: The precomputed full name of the callable, if one exists, None
        otherwise.
      name: The precomputed name of the callable, if one exists, None otherwise.

    Returns:
      Whether an error was recorded.
    """
    # Only look for *.-warnings here, the other will be handled by the Attribute
    # visitor. Also, do not warn for bare functions, only if the call func is
    # an attribute.
    warned = False
    if isinstance(node.func, ast.Attribute):
      warned = self._maybe_add_warning(node, "*." + six.ensure_str(name))

    # All arg warnings are handled here, since only we have the args
    arg_warnings = self._get_applicable_dict("function_arg_warnings",
                                             full_name, name)

    variadic_args = uses_star_args_or_kwargs_in_call(node)

    for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()):
      # get_arg_value always returns a (non-empty, hence truthy) tuple, so the
      # presence flag has to be extracted before it is combined with
      # `variadic_args`; *args/**kwargs may hide the argument we look for.
      present = get_arg_value(node, kwarg, arg)[0] or variadic_args
      if present:
        warned = True
        warning_message = six.ensure_str(warning).replace(
            "<function name>", six.ensure_str(full_name or name))
        template = "%s called with %s argument, requires manual check: %s"
        if variadic_args:
          template = ("%s called with *args or **kwargs that may include %s, "
                      "requires manual check: %s")
        self.add_log(level, node.lineno, node.col_offset,
                     template % (full_name or name, kwarg, warning_message))

    return warned

  def _maybe_rename(self, parent, node, full_name):
    """Replace node (Attribute or Name) with a node representing full_name."""
    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
    if new_name:
      self.add_log(INFO, node.lineno, node.col_offset,
                   "Renamed %r to %r" % (full_name, new_name))
      new_node = full_name_node(new_name, node.ctx)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      return True
    else:
      return False

  def _maybe_change_to_function_call(self, parent, node, full_name):
    """Wraps node (typically, an Attribute or Expr) in a Call."""
    if full_name in self._api_change_spec.change_to_function:
      if not isinstance(parent, ast.Call):
        # ast.Call's constructor is really picky about how many arguments it
        # wants, and also, it changed between Py2 and Py3.
        if six.PY2:
          new_node = ast.Call(node, [], [], None, None)
        else:
          new_node = ast.Call(node, [], [])
        pasta.ast_utils.replace_child(parent, node, new_node)
        ast.copy_location(new_node, node)
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Changed %r to a function call" % full_name)
        return True
    return False

  def _maybe_add_arg_names(self, node, full_name):
    """Make args into keyword args if function called full_name requires it."""
    function_reorders = self._api_change_spec.function_reorders

    if full_name in function_reorders:
      if uses_star_args_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "re-ordering the call arguments, but it was passed "
                     "variable-length positional *args. The upgrade "
                     "script cannot handle these automatically." % full_name)

      reordered = function_reorders[full_name]
      new_keywords = []
      idx = 0
      for arg in node.args:
        if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
          continue  # Can't move Starred to keywords
        keyword_arg = reordered[idx]
        keyword = ast.keyword(arg=keyword_arg, value=arg)
        new_keywords.append(keyword)
        idx += 1

      if new_keywords:
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Added keywords to args of function %r" % full_name)
        node.args = []
        node.keywords = new_keywords + (node.keywords or [])
        return True
    return False

  def _maybe_modify_args(self, node, full_name, name):
    """Rename keyword args if the function called full_name requires it."""
    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
                                                 full_name, name)

    if not renamed_keywords:
      return False

    if uses_star_kwargs_in_call(node):
      self.add_log(WARNING, node.lineno, node.col_offset,
                   "(Manual check required) upgrading %s may require "
                   "renaming or removing call arguments, but it was passed "
                   "variable-length *args or **kwargs. The upgrade "
                   "script cannot handle these automatically." %
                   (full_name or name))

    modified = False
    new_keywords = []
    for keyword in node.keywords:
      argkey = keyword.arg
      if argkey in renamed_keywords:
        modified = True
        # keyword nodes do not carry a position on every Python version, so
        # fall back to the position of the call itself.
        lineno = getattr(keyword, "lineno", node.lineno)
        col_offset = getattr(keyword, "col_offset", node.col_offset)
        if renamed_keywords[argkey] is None:
          # A rename target of None means the argument is dropped: it is
          # simply not carried over into new_keywords.
          self.add_log(INFO, lineno, col_offset,
                       "Removed argument %s for function %s" % (
                           argkey, full_name or name))
        else:
          keyword.arg = renamed_keywords[argkey]
          # Fall back to the short name: full_name is None when the rule was
          # matched through a "*.name" or "*" entry on a non-dotted callee.
          self.add_log(INFO, lineno, col_offset,
                       "Renamed keyword argument for %s from %s to %s" % (
                           full_name or name, argkey,
                           renamed_keywords[argkey]))
          new_keywords.append(keyword)
      else:
        new_keywords.append(keyword)

    if modified:
      node.keywords = new_keywords
    return modified

  def visit_Call(self, node):  # pylint: disable=invalid-name
    """Handle visiting a call node in the AST.

    Args:
      node: Current Node
    """
    assert self._stack[-1] is node

    # Get the name for this call, so we can index stuff with it.
    full_name = self._get_full_name(node.func)
    if full_name:
      name = full_name.split(".")[-1]
    elif isinstance(node.func, ast.Name):
      name = node.func.id
    elif isinstance(node.func, ast.Attribute):
      name = node.func.attr
    else:
      name = None

    # Call standard transformers for this node.
    # Make sure warnings come first, since args or names triggering warnings
    # may be removed by the other transformations.
    self._maybe_add_call_warning(node, full_name, name)
    # Make all args into kwargs
    self._maybe_add_arg_names(node, full_name)
    # Argument name changes or deletions
    self._maybe_modify_args(node, full_name, name)

    # Call transformers. These have the ability to modify the node, and if they
    # do, will return the new node they created (or the same node if they just
    # changed it). The are given the parent, but we will take care of
    # integrating their changes into the parent if they return a new node.
    #
    # These are matched on the old name, since renaming is performed by the
    # Attribute visitor, which happens later.
    transformers = self._get_applicable_entries("function_transformers",
                                                full_name, name)

    parent = self._stack[-2]

    if transformers:
      if uses_star_args_or_kwargs_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "modifying call arguments, but it was passed "
                     "variable-length *args or **kwargs. The upgrade "
                     "script cannot handle these automatically." %
                     (full_name or name))

    for transformer in transformers:
      logs = []
      new_node = transformer(parent, node, full_name, name, logs)
      self.add_logs(logs)
      if new_node and new_node is not node:
        pasta.ast_utils.replace_child(parent, node, new_node)
        node = new_node
        # Keep the parent stack in sync with the replaced node.
        self._stack[-1] = node

    self.generic_visit(node)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if full_name:
      parent = self._stack[-2]

      # Make sure the warning comes first, otherwise the name may have changed
      self._maybe_add_warning(node, full_name)

      # Once we did a modification, node is invalid and not worth inspecting
      # further. Also, we only perform modifications for simple nodes, so
      # There'd be no point in descending further.
      if self._maybe_rename(parent, node, full_name):
        return
      if self._maybe_change_to_function_call(parent, node, full_name):
        return

      # The isinstance check is enough -- a bare Attribute is never root.
      i = 2
      while isinstance(self._stack[-i], ast.Attribute):
        i += 1
      whole_name = pasta.dump(self._stack[-(i-1)])

      self._maybe_add_module_deprecation_warning(node, full_name, whole_name)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    new_aliases = []
    import_updated = False
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    max_submodule_depth = getattr(self._api_change_spec, "max_submodule_depth",
                                  1)
    inserts_after_imports = getattr(self._api_change_spec,
                                    "inserts_after_imports", {})

    # This loop processes imports in the format
    # import foo as f, bar as b
    for import_alias in node.names:
      all_import_components = six.ensure_str(import_alias.name).split(".")
      # Look for rename, starting with longest import levels.
      found_update = False
      for i in reversed(list(range(1, max_submodule_depth + 1))):
        import_component = all_import_components[0]
        for j in range(1, min(i, len(all_import_components))):
          import_component += "." + six.ensure_str(all_import_components[j])
        import_rename_spec = import_renames.get(import_component, None)

        if not import_rename_spec or excluded_from_module_rename(
            import_alias.name, import_rename_spec):
          continue

        new_name = (
            import_rename_spec.new_name +
            import_alias.name[len(import_component):])

        # If current import is
        #   import foo
        # then new import should preserve imported name:
        #   import new_foo as foo
        # This happens when module has just one component.
        new_asname = import_alias.asname
        if not new_asname and "." not in import_alias.name:
          new_asname = import_alias.name

        new_alias = ast.alias(name=new_name, asname=new_asname)
        new_aliases.append(new_alias)
        import_updated = True
        found_update = True

        # Insert any followup lines that should happen after this import.
        full_import = (import_alias.name, import_alias.asname)
        insert_offset = 1
        for line_to_insert in inserts_after_imports.get(full_import, []):
          assert self._stack[-1] is node
          parent = self._stack[-2]

          new_line_node = pasta.parse(line_to_insert)
          ast.copy_location(new_line_node, node)
          parent.body.insert(
              parent.body.index(node) + insert_offset, new_line_node)
          insert_offset += 1

          # Insert a newline after the import if necessary
          old_suffix = pasta.base.formatting.get(node, "suffix")
          if old_suffix is None:
            old_suffix = os.linesep
          if os.linesep not in old_suffix:
            pasta.base.formatting.set(node, "suffix",
                                      six.ensure_str(old_suffix) + os.linesep)

          # Apply indentation to new node.
          pasta.base.formatting.set(new_line_node, "prefix",
                                    pasta.base.formatting.get(node, "prefix"))
          pasta.base.formatting.set(new_line_node, "suffix", os.linesep)
          # Log the inserted source text; formatting the parsed node itself
          # would only print its object repr.
          self.add_log(
              INFO, node.lineno, node.col_offset,
              "Adding `%s` after import of %s" %
              (line_to_insert, import_alias.name))
        # Find one match, break
        if found_update:
          break
      # No rename is found for all levels
      if not found_update:
        new_aliases.append(import_alias)  # no change needed

    # Replace the node if at least one import needs to be updated.
    if import_updated:
      assert self._stack[-1] is node
      parent = self._stack[-2]

      new_node = ast.Import(new_aliases)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      self.add_log(
          INFO, node.lineno, node.col_offset,
          "Changed import from %r to %r." %
          (pasta.dump(node), pasta.dump(new_node)))

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    # Look for rename based on first component of from-import.
    # i.e. based on foo in foo.bar.
    from_import_first_component = six.ensure_str(from_import).split(".")[0]
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    import_rename_spec = import_renames.get(from_import_first_component, None)
    if not import_rename_spec:
      self.generic_visit(node)
      return

    # Split module aliases into the ones that require import update
    # and those that don't. For e.g. if we want to rename "a" to "b"
    # unless we import "a.c" in the following:
    # from a import c, d
    # we want to update import for "d" but not for "c".
    updated_aliases = []
    same_aliases = []
    for import_alias in node.names:
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      if excluded_from_module_rename(full_module_name, import_rename_spec):
        same_aliases.append(import_alias)
      else:
        updated_aliases.append(import_alias)

    if not updated_aliases:
      self.generic_visit(node)
      return

    assert self._stack[-1] is node
    parent = self._stack[-2]

    # Replace first component of from-import with new name.
    new_from_import = (
        import_rename_spec.new_name +
        from_import[len(from_import_first_component):])
    updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
    ast.copy_location(updated_node, node)
    pasta.ast_utils.replace_child(parent, node, updated_node)

    # If some imports had to stay the same, add another import for them.
    additional_import_log = ""
    if same_aliases:
      same_node = ast.ImportFrom(from_import, same_aliases, node.level,
                                 col_offset=node.col_offset, lineno=node.lineno)
      ast.copy_location(same_node, node)
      parent.body.insert(parent.body.index(updated_node), same_node)
      # Apply indentation to new node.
      pasta.base.formatting.set(
          same_node, "prefix",
          pasta.base.formatting.get(updated_node, "prefix"))
      additional_import_log = " and %r" % pasta.dump(same_node)

    self.add_log(
        INFO, node.lineno, node.col_offset,
        "Changed import from %r to %r%s." %
        (pasta.dump(node),
         pasta.dump(updated_node),
         additional_import_log))

    self.generic_visit(node)
class AnalysisResult(object):
  """This class represents an analysis result and how it should be logged.

  This class must provide the following fields:
  * `log_level`: The log level to which this detection should be logged
  * `log_message`: The message that should be logged for this detection

  Both fields are read by `PastaAnalyzeVisitor` when it records a detection.

  For an example, see `VersionedTFImport`.
  """
class APIAnalysisSpec(object):
  """This class defines how `AnalysisResult`s should be generated.

  It specifies how to map imports and symbols to `AnalysisResult`s.

  This class must provide the following fields:
  * `symbols_to_detect`: maps function names to `AnalysisResult`s
  * `imports_to_detect`: maps imports represented as (full module name, alias)
    tuples to `AnalysisResult`s

  For an example, see `TFAPIImportAnalysisSpec`.
  """
class PastaAnalyzeVisitor(_PastaEditVisitor):
  """AST Visitor that detects specific API usage and never edits the tree.

  It runs before any rewriting takes place, to find out whether symbols are
  used that require changing imports or disabling rewriting altogether.
  """

  def __init__(self, api_analysis_spec):
    # The edit machinery of the base class is neutralised with a spec that
    # changes nothing.
    super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
    self._api_analysis_spec = api_analysis_spec
    self._results = []  # Holds AnalysisResult objects

  @property
  def results(self):
    """The `AnalysisResult`s collected so far."""
    return self._results

  def add_result(self, analysis_result):
    """Append one `AnalysisResult` to the collected results."""
    self._results.append(analysis_result)

  def _record_detection(self, detection, node):
    """Store a detection (if there is one) and log it at node's position."""
    if not detection:
      return
    self.add_result(detection)
    self.add_log(
        detection.log_level, node.lineno, node.col_offset,
        detection.log_message)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    full_name = self._get_full_name(node)
    if full_name:
      symbols = self._api_analysis_spec.symbols_to_detect
      self._record_detection(symbols.get(full_name, None), node)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    for import_alias in node.names:
      # Detection is keyed on the full import name together with its alias.
      lookup_key = (import_alias.name, import_alias.asname)
      imports = self._api_analysis_spec.imports_to_detect
      self._record_detection(imports.get(lookup_key, None), node)

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if node.module:
      for import_alias in node.names:
        # Detection is keyed on "<module>.<imported name>" plus the alias.
        qualified_name = "%s.%s" % (node.module, import_alias.name)
        lookup_key = (qualified_name, import_alias.asname)
        imports = self._api_analysis_spec.imports_to_detect
        self._record_detection(imports.get(lookup_key, None), node)

    self.generic_visit(node)
class ASTCodeUpgrader(object):
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

  def process_file(self,
                   in_filename,
                   out_filename,
                   no_change_to_outfile_on_error=False):
    """Process the given python file for incompatible changes.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
      no_change_to_outfile_on_error: not modify the output file on errors

    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    # Write to a temporary file, just in case we are doing an implace modify.
    temp_name = None
    succeeded = False
    try:
      # pylint: disable=g-backslash-continuation
      with open(in_filename, "r") as in_file, \
          tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
        temp_name = temp_file.name
        ret = self.process_opened_file(in_filename, in_file, out_filename,
                                       temp_file)
      # pylint: enable=g-backslash-continuation
      succeeded = True
    finally:
      # The temporary file is created with delete=False, so if processing
      # raised nobody else would remove it.
      if (not succeeded and temp_name is not None and
          os.path.exists(temp_name)):
        os.remove(temp_name)

    if no_change_to_outfile_on_error and ret[0] == 0:
      os.remove(temp_name)
    else:
      shutil.move(temp_name, out_filename)
    return ret

  def format_log(self, log, in_filename):
    """Format one `(severity, lineno, col, msg)` log tuple as a string.

    Args:
      log: The log tuple to format.
      in_filename: If truthy, it is prepended to the result as "<name>:".

    Returns:
      The formatted log line.
    """
    log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
    if in_filename:
      return six.ensure_str(in_filename) + ":" + log_string
    else:
      return log_string

  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta.

    Args:
      text: Source code to upgrade.
      in_filename: Name used to prefix the formatted warnings and errors.

    Returns:
      A tuple (files processed (0 or 1), new source text, formatted logs,
      formatted warnings and errors). On a parse failure the count is 0, the
      text is empty and the log holds the traceback.
    """
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    try:
      t, preprocess_logs, preprocess_errors = (
          self._api_change_spec.preprocess(t))

      visitor = _PastaEditVisitor(self._api_change_spec)
      visitor.visit(t)
    finally:
      # Preprocessing may have altered the rewriting rules; always restore
      # the spec so a failure here cannot leak state into the next file.
      self._api_change_spec.clear_preprocessing()

    logs = [self.format_log(log, None) for log in (preprocess_logs +
                                                   visitor.log)]
    errors = [self.format_log(error, in_filename)
              for error in (preprocess_errors +
                            visitor.warnings_and_errors)]
    return 1, pasta.dump(t), logs, errors

  def _format_log(self, log, in_filename, out_filename):
    """Wrap a list of log lines in a per-file report section."""
    text = six.ensure_str("-" * 80) + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += six.ensure_str("-" * 80) + "\n\n"
    text += "\n".join(log) + "\n"
    text += six.ensure_str("-" * 80) + "\n\n"
    return text

  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)

    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    lines = in_file.readlines()
    processed_file, new_file_content, log, process_errors = (
        self.update_string_pasta("".join(lines), in_filename))

    # Nothing is written when parsing failed (processed_file == 0).
    if out_file and processed_file:
      out_file.write(new_file_content)

    return (processed_file,
            self._format_log(log, in_filename, out_filename),
            process_errors)

  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files.

    Note that only Python files are upgraded. If you have custom code in other
    languages, you will need to manually upgrade those. When the output
    directory equals the input directory the tree is upgraded in place.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as base.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """

    if output_root_directory == root_directory:
      return self.process_tree_inplace(root_directory)

    # make sure output directory doesn't exist
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # make sure output directory does not overlap with root_directory
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
      print("Output directory %r same as input directory %r" %
            (output_root_directory, root_directory))
      sys.exit(1)

    # Collect list of files to process (we do this to correctly handle if the
    # user puts the output directory in some sub directory of the input dir)
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [f for f in file_list if six.ensure_str(f).endswith(".py")]
      copy_files = [
          f for f in file_list if not six.ensure_str(f).endswith(".py")
      ]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
                                       os.path.relpath(fullpath,
                                                       root_directory))
        files_to_process.append((fullpath, fullpath_output))
      if copy_other_files:
        for filename in copy_files:
          fullpath = os.path.join(dir_name, filename)
          fullpath_output = os.path.join(output_root_directory,
                                         os.path.relpath(
                                             fullpath, root_directory))
          files_to_copy.append((fullpath, fullpath_output))

    file_count = 0
    tree_errors = {}
    report = ""
    report += six.ensure_str(("=" * 80)) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += six.ensure_str(("=" * 80)) + "\n"

    for input_path, output_path in files_to_process:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)

      if os.path.islink(input_path):
        link_target = os.readlink(input_path)
        link_target_output = os.path.join(
            output_root_directory, os.path.relpath(link_target, root_directory))
        if (link_target, link_target_output) in files_to_process:
          # Create a link to the new location of the target file
          os.symlink(link_target_output, output_path)
        else:
          # Terminate the line so the next report entry starts on its own.
          report += "Copying symlink %s without modifying its target %s\n" % (
              input_path, link_target)
          os.symlink(link_target, output_path)
        continue

      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors[input_path] = l_errors
      report += l_report

    for input_path, output_path in files_to_copy:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors

  def process_tree_inplace(self, root_directory):
    """Process a directory of python files in place.

    Args:
      root_directory: Directory to walk and process. Symlinked files are
        skipped and only mentioned in the report.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """
    files_to_process = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [
          os.path.join(dir_name, f)
          for f in file_list
          if six.ensure_str(f).endswith(".py")
      ]
      files_to_process += py_files

    file_count = 0
    tree_errors = {}
    report = ""
    report += six.ensure_str(("=" * 80)) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += six.ensure_str(("=" * 80)) + "\n"

    for path in files_to_process:
      if os.path.islink(path):
        report += "Skipping symlink %s.\n" % path
        continue
      file_count += 1
      _, l_report, l_errors = self.process_file(path, path)
      tree_errors[path] = l_errors
      report += l_report

    return file_count, report, tree_errors