Refactor compatibility upgrade tool to use pasta (an AST based refactoring tool) instead of line based edits.
This allows for some simplification (e.g., we can now remove arguments with a dict entry) and for much more powerful transformations. This passes all tests with no destructive test changes (though there are cosmetic test changes). Also adds transformations for tf.to_dtype -> tf.cast(..., dtype=...). PiperOrigin-RevId: 227934801
This commit is contained in:
parent
75f2e9c266
commit
fe66882827
@ -1,17 +1,21 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_copts", # @unused
|
||||
"tf_cc_test", # @unused
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
py_library(
|
||||
name = "ast_edits",
|
||||
srcs = ["ast_edits.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"@pasta",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@ -27,6 +26,9 @@ import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
|
||||
import pasta
|
||||
import six
|
||||
|
||||
# Some regular expressions we will need for parsing
|
||||
FIND_OPEN = re.compile(r"^\s*(\[).*$")
|
||||
FIND_STRING_CHARS = re.compile(r"['\"]")
|
||||
@ -44,169 +46,173 @@ class APIChangeSpec(object):
|
||||
notifications)
|
||||
* `function_reorders`: maps functions whose argument order has changed to the
|
||||
list of arguments in the new order
|
||||
* `function_handle`: maps function names to custom handlers for the function
|
||||
* `function_warnings`: maps full names of functions to warnings that will be
|
||||
printed out if the function is used. (e.g. tf.nn.convolution())
|
||||
* `unrestricted_function_warnings`: maps names of functions to warnings that
|
||||
will be printed out when the function is used (e.g. foo.convolution()).
|
||||
* `function_keyword_additions`: maps function names to a map of arg->value
|
||||
names that should be passed to the function.
|
||||
* `function_transformers`: maps function names to custom handlers
|
||||
|
||||
For an example, see `TFAPIChangeSpec`.
|
||||
"""
|
||||
|
||||
|
||||
class _FileEditTuple(
|
||||
collections.namedtuple("_FileEditTuple",
|
||||
["comment", "line", "start", "old", "new"])):
|
||||
"""Each edit that is recorded by a _FileEditRecorder.
|
||||
|
||||
Fields:
|
||||
comment: A description of the edit and why it was made.
|
||||
line: The line number in the file where the edit occurs (1-indexed).
|
||||
start: The column number in the file where the edit occurs (0-indexed).
|
||||
old: text string to remove (this must match what was in file).
|
||||
new: text string to add in place of `old`.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class _FileEditRecorder(object):
|
||||
"""Record changes that need to be done to the file."""
|
||||
|
||||
def __init__(self, filename):
|
||||
# all edits are lists of chars
|
||||
self._filename = filename
|
||||
|
||||
self._line_to_edit = collections.defaultdict(list)
|
||||
self._errors = []
|
||||
|
||||
def process(self, text):
|
||||
"""Process a list of strings, each corresponding to the recorded changes.
|
||||
|
||||
Args:
|
||||
text: A list of lines of text (assumed to contain newlines)
|
||||
Returns:
|
||||
A tuple of the modified text and a textual description of what is done.
|
||||
Raises:
|
||||
ValueError: if substitution source location does not have expected text.
|
||||
"""
|
||||
|
||||
change_report = ""
|
||||
|
||||
# Iterate of each line
|
||||
for line, edits in self._line_to_edit.items():
|
||||
offset = 0
|
||||
# sort by column so that edits are processed in order in order to make
|
||||
# indexing adjustments cumulative for changes that change the string
|
||||
# length
|
||||
edits.sort(key=lambda x: x.start)
|
||||
|
||||
# Extract each line to a list of characters, because mutable lists
|
||||
# are editable, unlike immutable strings.
|
||||
char_array = list(text[line - 1])
|
||||
|
||||
# Record a description of the change
|
||||
change_report += "%r Line %d\n" % (self._filename, line)
|
||||
change_report += "-" * 80 + "\n\n"
|
||||
for e in edits:
|
||||
change_report += "%s\n" % e.comment
|
||||
change_report += "\n Old: %s" % (text[line - 1])
|
||||
|
||||
# Make underscore buffers for underlining where in the line the edit was
|
||||
change_list = [" "] * len(text[line - 1])
|
||||
change_list_new = [" "] * len(text[line - 1])
|
||||
|
||||
# Iterate for each edit
|
||||
for e in edits:
|
||||
# Create effective start, end by accounting for change in length due
|
||||
# to previous edits
|
||||
start_eff = e.start + offset
|
||||
end_eff = start_eff + len(e.old)
|
||||
|
||||
# Make sure the edit is changing what it should be changing
|
||||
old_actual = "".join(char_array[start_eff:end_eff])
|
||||
if old_actual != e.old:
|
||||
raise ValueError("Expected text %r but got %r" %
|
||||
("".join(e.old), "".join(old_actual)))
|
||||
# Make the edit
|
||||
char_array[start_eff:end_eff] = list(e.new)
|
||||
|
||||
# Create the underline highlighting of the before and after
|
||||
change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
|
||||
change_list_new[start_eff:end_eff] = "~" * len(e.new)
|
||||
|
||||
# Keep track of how to generate effective ranges
|
||||
offset += len(e.new) - len(e.old)
|
||||
|
||||
# Finish the report comment
|
||||
change_report += " %s\n" % "".join(change_list)
|
||||
text[line - 1] = "".join(char_array)
|
||||
change_report += " New: %s" % (text[line - 1])
|
||||
change_report += " %s\n\n" % "".join(change_list_new)
|
||||
return "".join(text), change_report, self._errors
|
||||
|
||||
def add(self, comment, line, start, old, new, error=None):
|
||||
"""Add a new change that is needed.
|
||||
|
||||
Args:
|
||||
comment: A description of what was changed
|
||||
line: Line number (1 indexed)
|
||||
start: Column offset (0 indexed)
|
||||
old: old text
|
||||
new: new text
|
||||
error: this "edit" is something that cannot be fixed automatically
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
self._line_to_edit[line].append(
|
||||
_FileEditTuple(comment, line, start, old, new))
|
||||
if error:
|
||||
self._errors.append("%s:%d: %s" % (self._filename, line, error))
|
||||
|
||||
|
||||
class _ASTCallVisitor(ast.NodeVisitor):
|
||||
class _PastaEditVisitor(ast.NodeVisitor):
|
||||
"""AST Visitor that processes function calls.
|
||||
|
||||
Updates function calls from old API version to new API version using a given
|
||||
change spec.
|
||||
"""
|
||||
|
||||
def __init__(self, filename, lines, api_change_spec):
|
||||
self._filename = filename
|
||||
self._file_edit = _FileEditRecorder(filename)
|
||||
self._lines = lines
|
||||
def __init__(self, api_change_spec):
|
||||
self._api_change_spec = api_change_spec
|
||||
self._log = [] # Holds 3-tuples: line, col, msg.
|
||||
self._errors = [] # Same structure as _log.
|
||||
self._stack = [] # Allow easy access to parents.
|
||||
|
||||
def process(self, lines):
|
||||
return self._file_edit.process(lines)
|
||||
# Overridden to maintain a stack of nodes to allow for parent access
|
||||
def visit(self, node):
|
||||
self._stack.append(node)
|
||||
super(_PastaEditVisitor, self).visit(node)
|
||||
self._stack.pop()
|
||||
|
||||
def generic_visit(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
@property
|
||||
def errors(self):
|
||||
return self._errors
|
||||
|
||||
def _rename_functions(self, node, full_name):
|
||||
symbol_renames = self._api_change_spec.symbol_renames
|
||||
try:
|
||||
new_name = symbol_renames[full_name]
|
||||
self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
|
||||
node.lineno, node.col_offset, full_name, new_name)
|
||||
except KeyError:
|
||||
pass
|
||||
@property
|
||||
def log(self):
|
||||
return self._log
|
||||
|
||||
def _print_warning_for_function(self, node, full_name):
|
||||
def _format_log(self, log):
|
||||
text = ""
|
||||
for log_entry in log:
|
||||
text += "Line %d:%d: %s\n" % log_entry
|
||||
return text
|
||||
|
||||
def log_text(self):
|
||||
return self._format_log(self.log)
|
||||
|
||||
def add_log(self, lineno, col, msg):
|
||||
self._log.append((lineno, col, msg))
|
||||
print("Line %d:%d: %s" % (lineno, col, msg))
|
||||
|
||||
def add_error(self, lineno, col, msg):
|
||||
# All errors are also added to the regular log.
|
||||
self.add_log(lineno, col, msg)
|
||||
self._errors.append((lineno, col, msg))
|
||||
|
||||
def add_logs(self, logs):
|
||||
"""Record a log and print it.
|
||||
|
||||
The log should be a tuple (lineno, col_offset, msg), which will be printed
|
||||
and then recorded. It is part of the log available in the self.log property.
|
||||
|
||||
Args:
|
||||
logs: The log to add. Must be a tuple (lineno, col_offset, msg).
|
||||
"""
|
||||
self._log.extend(logs)
|
||||
for log in logs:
|
||||
print("Line %d:%d: %s" % log)
|
||||
|
||||
def add_errors(self, errors):
|
||||
"""Record an error and print it.
|
||||
|
||||
The error must be a tuple (lineno, col_offset, msg), which will be printed
|
||||
and then recorded as both a log and an error. It is therefore part of the
|
||||
log available in the self.log as well as the self.errors property.
|
||||
|
||||
Args:
|
||||
errors: The log to add. Must be a tuple (lineno, col_offset, msg).
|
||||
"""
|
||||
self.add_logs(errors)
|
||||
self._errors.extend(errors)
|
||||
|
||||
def _get_applicable_entries(self, transformer_field, full_name, name):
|
||||
"""Get all list entries indexed by name that apply to full_name or name."""
|
||||
# Transformers are indexed to full name, name, or no name
|
||||
# as a performance optimization.
|
||||
function_transformers = getattr(self._api_change_spec,
|
||||
transformer_field, {})
|
||||
|
||||
glob_name = "*." + name if name else None
|
||||
transformers = []
|
||||
if full_name in function_transformers:
|
||||
transformers.append(function_transformers[full_name])
|
||||
if glob_name in function_transformers:
|
||||
transformers.append(function_transformers[glob_name])
|
||||
if "*" in function_transformers:
|
||||
transformers.append(function_transformers["*"])
|
||||
return transformers
|
||||
|
||||
def _get_applicable_dict(self, transformer_field, full_name, name):
|
||||
"""Get all dict entries indexed by name that apply to full_name or name."""
|
||||
# Transformers are indexed to full name, name, or no name
|
||||
# as a performance optimization.
|
||||
function_transformers = getattr(self._api_change_spec,
|
||||
transformer_field, {})
|
||||
|
||||
glob_name = "*." + name if name else None
|
||||
transformers = function_transformers.get("*", {}).copy()
|
||||
transformers.update(function_transformers.get(glob_name, {}))
|
||||
transformers.update(function_transformers.get(full_name, {}))
|
||||
return transformers
|
||||
|
||||
def _get_full_name(self, node):
|
||||
"""Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".
|
||||
|
||||
This is the inverse of _full_name_node.
|
||||
|
||||
Args:
|
||||
node: A Node of type Attribute.
|
||||
|
||||
Returns:
|
||||
a '.'-delimited full-name or None if node was not Attribute or Name.
|
||||
i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
|
||||
"""
|
||||
curr = node
|
||||
items = []
|
||||
while not isinstance(curr, ast.Name):
|
||||
if not isinstance(curr, ast.Attribute):
|
||||
return None
|
||||
items.append(curr.attr)
|
||||
curr = curr.value
|
||||
items.append(curr.id)
|
||||
return ".".join(reversed(items))
|
||||
|
||||
def _full_name_node(self, name, ctx=ast.Load()):
|
||||
"""Make an Attribute or Name node for name.
|
||||
|
||||
Translate a qualified name into nested Attribute nodes (and a Name node).
|
||||
|
||||
Args:
|
||||
name: The name to translate to a node.
|
||||
ctx: What context this name is used in. Defaults to Load()
|
||||
|
||||
Returns:
|
||||
A Name or Attribute node.
|
||||
"""
|
||||
names = name.split(".")
|
||||
names.reverse()
|
||||
node = ast.Name(id=names.pop(), ctx=ast.Load())
|
||||
while names:
|
||||
node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load())
|
||||
|
||||
# Change outermost ctx to the one given to us (inner ones should be Load).
|
||||
node.ctx = ctx
|
||||
return node
|
||||
|
||||
def _maybe_add_warning(self, node, full_name):
|
||||
"""Adds an error to be printed about full_name at node."""
|
||||
function_warnings = self._api_change_spec.function_warnings
|
||||
try:
|
||||
if full_name in function_warnings:
|
||||
warning_message = function_warnings[full_name]
|
||||
warning_message = warning_message.replace("<function name>", full_name)
|
||||
self._file_edit.add(warning_message,
|
||||
node.lineno, node.col_offset, full_name, full_name,
|
||||
error="%s requires manual check." % full_name)
|
||||
except KeyError:
|
||||
pass
|
||||
self.add_error(node.lineno, node.col_offset,
|
||||
"%s requires manual check: %s." % (full_name,
|
||||
warning_message))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _print_warning_for_function_unrestricted(self, node):
|
||||
def _maybe_add_call_warning(self, node, full_name, name):
|
||||
"""Print a warning when specific functions are called.
|
||||
|
||||
The function _print_warning_for_function matches the full name of the called
|
||||
@ -216,92 +222,118 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
|
||||
Args:
|
||||
node: ast.Call object
|
||||
full_name: The precomputed full name of the callable, if one exists, None
|
||||
otherwise.
|
||||
name: The precomputed name of the callable, if one exists, None otherwise.
|
||||
|
||||
Returns:
|
||||
Whether an error was recorded.
|
||||
"""
|
||||
function_warnings = getattr(
|
||||
self._api_change_spec, "unrestricted_function_warnings", {})
|
||||
# Only look for *.-warnings here, the other will be handled by the Attribute
|
||||
# visitor. Also, do not warn for bare functions, only if the call func is
|
||||
# an attribute.
|
||||
warned = False
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
function_name = node.func.attr
|
||||
try:
|
||||
warning_message = function_warnings[function_name]
|
||||
self._file_edit.add(warning_message,
|
||||
node.lineno, node.col_offset, "", "",
|
||||
error="%s requires manual check." % function_name)
|
||||
except KeyError:
|
||||
pass
|
||||
warned = self._maybe_add_warning(node, "*." + name)
|
||||
|
||||
def _get_attribute_full_path(self, node):
|
||||
"""Traverse an attribute to generate a full name e.g. tf.foo.bar.
|
||||
# All arg warnings are handled here, since only we have the args
|
||||
arg_warnings = self._get_applicable_dict("function_arg_warnings",
|
||||
full_name, name)
|
||||
|
||||
Args:
|
||||
node: A Node of type Attribute.
|
||||
used_args = [kw.arg for kw in node.keywords]
|
||||
for arg, warning in arg_warnings.items():
|
||||
if arg in used_args:
|
||||
warned = True
|
||||
warning_message = warning.replace("<function name>", full_name or name)
|
||||
self.add_error(node.lineno, node.col_offset,
|
||||
"%s called with %s argument requires manual check: %s." %
|
||||
(full_name or name, arg, warning_message))
|
||||
|
||||
Returns:
|
||||
a '.'-delimited full-name or None if the tree was not a simple form.
|
||||
i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
|
||||
"""
|
||||
curr = node
|
||||
items = []
|
||||
while not isinstance(curr, ast.Name):
|
||||
if not isinstance(curr, ast.Attribute):
|
||||
return None, None
|
||||
items.append(curr.attr)
|
||||
curr = curr.value
|
||||
items.append(curr.id)
|
||||
return ".".join(reversed(items)), items[0]
|
||||
return warned
|
||||
|
||||
def _find_true_position(self, node):
|
||||
"""Return correct line number and column offset for a given node.
|
||||
def _maybe_rename(self, parent, node, full_name):
|
||||
"""Replace node (Attribute or Name) with a node representing full_name."""
|
||||
new_name = self._api_change_spec.symbol_renames.get(full_name, None)
|
||||
if new_name:
|
||||
self.add_log(node.lineno, node.col_offset,
|
||||
"Renamed %r to %r" % (full_name, new_name))
|
||||
new_node = self._full_name_node(new_name, node.ctx)
|
||||
ast.copy_location(new_node, node)
|
||||
pasta.ast_utils.replace_child(parent, node, new_node)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
This is necessary mainly because ListComp's location reporting reports
|
||||
the next token after the list comprehension list opening.
|
||||
|
||||
Returns:
|
||||
lineno, offset for the given node
|
||||
|
||||
Args:
|
||||
node: Node for which we wish to know the lineno and col_offset
|
||||
"""
|
||||
if isinstance(node, ast.ListComp):
|
||||
# Strangely, ast.ListComp returns the col_offset of the first token
|
||||
# after the '[' token which appears to be a bug. Workaround by
|
||||
# explicitly finding the real start of the list comprehension.
|
||||
line = node.lineno
|
||||
col = node.col_offset
|
||||
# loop over lines
|
||||
while 1:
|
||||
# Reverse the text to and regular expression search for whitespace
|
||||
text = self._lines[line - 1]
|
||||
reversed_preceding_text = text[:col][::-1]
|
||||
# First find if a [ can be found with only whitespace between it and
|
||||
# col.
|
||||
m = FIND_OPEN.match(reversed_preceding_text)
|
||||
if m:
|
||||
new_col_offset = col - m.start(1) - 1
|
||||
return line, new_col_offset
|
||||
def _maybe_change_to_function_call(self, parent, node, full_name):
|
||||
"""Wraps node (typically, an Attribute or Expr) in a Call."""
|
||||
if full_name in self._api_change_spec.change_to_function:
|
||||
if not isinstance(parent, ast.Call):
|
||||
# ast.Call's constructor is really picky about how many arguments it
|
||||
# wants, and also, it changed between Py2 and Py3.
|
||||
if six.PY2:
|
||||
new_node = ast.Call(node, [], [], None, None)
|
||||
else:
|
||||
if (reversed_preceding_text == "" or
|
||||
reversed_preceding_text.isspace()):
|
||||
line = line - 1
|
||||
prev_line = self._lines[line - 1]
|
||||
# TODO(aselle):
|
||||
# this is poor comment detection, but it is good enough for
|
||||
# cases where the comment does not contain string literal starting/
|
||||
# ending characters. If ast gave us start and end locations of the
|
||||
# ast nodes rather than just start, we could use string literal
|
||||
# node ranges to filter out spurious #'s that appear in string
|
||||
# literals.
|
||||
comment_start = prev_line.find("#")
|
||||
if comment_start == -1:
|
||||
col = len(prev_line) - 1
|
||||
elif FIND_STRING_CHARS.search(prev_line[comment_start:]) is None:
|
||||
col = comment_start
|
||||
else:
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
# Most other nodes return proper locations (with notably does not), but
|
||||
# it is not possible to use that in an argument.
|
||||
return node.lineno, node.col_offset
|
||||
new_node = ast.Call(node, [], [])
|
||||
pasta.ast_utils.replace_child(parent, node, new_node)
|
||||
ast.copy_location(new_node, node)
|
||||
self.add_log(node.lineno, node.col_offset,
|
||||
"Changed %r to a function call" % full_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _maybe_add_arg_names(self, node, full_name):
|
||||
"""Make args into keyword args if function called full_name requires it."""
|
||||
function_reorders = self._api_change_spec.function_reorders
|
||||
|
||||
if full_name in function_reorders:
|
||||
reordered = function_reorders[full_name]
|
||||
new_keywords = []
|
||||
for idx, arg in enumerate(node.args):
|
||||
keyword_arg = reordered[idx]
|
||||
new_keywords.append(ast.keyword(arg=keyword_arg, value=arg))
|
||||
|
||||
if new_keywords:
|
||||
self.add_log(node.lineno, node.col_offset,
|
||||
"Added keywords to args of function %r" % full_name)
|
||||
node.args = []
|
||||
node.keywords = new_keywords + (node.keywords or [])
|
||||
return True
|
||||
return False
|
||||
|
||||
def _maybe_modify_args(self, node, full_name, name):
|
||||
"""Rename keyword args if the function called full_name requires it."""
|
||||
renamed_keywords = self._get_applicable_dict("function_keyword_renames",
|
||||
full_name, name)
|
||||
|
||||
if not renamed_keywords:
|
||||
return False
|
||||
|
||||
modified = False
|
||||
new_keywords = []
|
||||
for keyword in node.keywords:
|
||||
argkey = keyword.arg
|
||||
if argkey in renamed_keywords:
|
||||
modified = True
|
||||
if renamed_keywords[argkey] is None:
|
||||
lineno = getattr(keyword, "lineno", node.lineno)
|
||||
col_offset = getattr(keyword, "col_offset", node.col_offset)
|
||||
self.add_log(lineno, col_offset,
|
||||
"Removed argument %s for function %s" % (
|
||||
argkey, full_name or name))
|
||||
else:
|
||||
keyword.arg = renamed_keywords[argkey]
|
||||
lineno = getattr(keyword, "lineno", node.lineno)
|
||||
col_offset = getattr(keyword, "col_offset", node.col_offset)
|
||||
self.add_log(lineno, col_offset,
|
||||
"Renamed keyword argument for %s from %s to %s" % (
|
||||
full_name, argkey, renamed_keywords[argkey]))
|
||||
new_keywords.append(keyword)
|
||||
else:
|
||||
new_keywords.append(keyword)
|
||||
|
||||
if modified:
|
||||
node.keywords = new_keywords
|
||||
return modified
|
||||
|
||||
def visit_Call(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting a call node in the AST.
|
||||
@ -309,104 +341,74 @@ class _ASTCallVisitor(ast.NodeVisitor):
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
self._print_warning_for_function_unrestricted(node)
|
||||
|
||||
# Find a simple attribute name path e.g. "tf.foo.bar"
|
||||
full_name, name = self._get_attribute_full_path(node.func)
|
||||
|
||||
# Make sure the func is marked as being part of a call
|
||||
node.func.is_function_for_call = True
|
||||
assert self._stack[-1] is node
|
||||
|
||||
# Get the name for this call, so we can index stuff with it.
|
||||
full_name = self._get_full_name(node.func)
|
||||
if full_name:
|
||||
# Call special handlers
|
||||
function_handles = self._api_change_spec.function_handle
|
||||
glob_name = "*.{}".format(name)
|
||||
if glob_name in function_handles:
|
||||
function_handles[glob_name](self._file_edit, node, self._lines)
|
||||
if full_name in function_handles:
|
||||
function_handles[full_name](self._file_edit, node, self._lines)
|
||||
name = full_name.split(".")[-1]
|
||||
elif isinstance(node.func, ast.Name):
|
||||
name = node.func.id
|
||||
elif isinstance(node.func, ast.Attribute):
|
||||
name = node.func.attr
|
||||
else:
|
||||
name = None
|
||||
|
||||
# Examine any non-keyword argument and make it into a keyword argument
|
||||
# if reordering required.
|
||||
function_reorders = self._api_change_spec.function_reorders
|
||||
function_keyword_renames = (
|
||||
self._api_change_spec.function_keyword_renames)
|
||||
# Call standard transformers for this node.
|
||||
# Make sure warnings come first, since args or names triggering warnings
|
||||
# may be removed by the other transformations.
|
||||
self._maybe_add_call_warning(node, full_name, name)
|
||||
# Make all args into kwargs
|
||||
self._maybe_add_arg_names(node, full_name)
|
||||
# Argument name changes or deletions
|
||||
self._maybe_modify_args(node, full_name, name)
|
||||
|
||||
if full_name in function_reorders:
|
||||
reordered = function_reorders[full_name]
|
||||
for idx, arg in enumerate(node.args):
|
||||
lineno, col_offset = self._find_true_position(arg)
|
||||
if lineno is None or col_offset is None:
|
||||
self._file_edit.add(
|
||||
"Failed to add keyword %r to reordered function %r" %
|
||||
(reordered[idx], full_name),
|
||||
arg.lineno,
|
||||
arg.col_offset,
|
||||
"",
|
||||
"",
|
||||
error="A necessary keyword argument failed to be inserted.")
|
||||
else:
|
||||
keyword_arg = reordered[idx]
|
||||
if (full_name in function_keyword_renames and
|
||||
keyword_arg in function_keyword_renames[full_name]):
|
||||
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||
self._file_edit.add("Added keyword %r to reordered function %r" %
|
||||
(reordered[idx], full_name), lineno, col_offset,
|
||||
"", keyword_arg + "=")
|
||||
# Call transformers. These have the ability to modify the node, and if they
|
||||
# do, will return the new node they created (or the same node if they just
|
||||
# changed it). The are given the parent, but we will take care of
|
||||
# integrating their changes into the parent if they return a new node.
|
||||
#
|
||||
# These are matched on the old name, since renaming is performed by the
|
||||
# Attribute visitor, which happens later.
|
||||
transformers = self._get_applicable_entries("function_transformers",
|
||||
full_name, name)
|
||||
|
||||
# Examine each keyword argument and convert it to the final renamed form
|
||||
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
||||
function_keyword_renames[full_name])
|
||||
for keyword in node.keywords:
|
||||
argkey = keyword.arg
|
||||
argval = keyword.value
|
||||
parent = self._stack[-2]
|
||||
|
||||
if argkey in renamed_keywords:
|
||||
argval_lineno, argval_col_offset = self._find_true_position(argval)
|
||||
if argval_lineno is not None and argval_col_offset is not None:
|
||||
# TODO(aselle): We should scan backward to find the start of the
|
||||
# keyword key. Unfortunately ast does not give you the location of
|
||||
# keyword keys, so we are forced to infer it from the keyword arg
|
||||
# value.
|
||||
key_start = argval_col_offset - len(argkey) - 1
|
||||
key_end = key_start + len(argkey) + 1
|
||||
if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
|
||||
"="):
|
||||
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||
(argkey,
|
||||
renamed_keywords[argkey]), argval_lineno,
|
||||
argval_col_offset - len(argkey) - 1,
|
||||
argkey + "=", renamed_keywords[argkey] + "=")
|
||||
continue
|
||||
self._file_edit.add(
|
||||
"Failed to rename keyword argument from %r to %r" %
|
||||
(argkey, renamed_keywords[argkey]),
|
||||
argval.lineno,
|
||||
argval.col_offset - len(argkey) - 1,
|
||||
"",
|
||||
"",
|
||||
error="Failed to find keyword lexographically. Fix manually.")
|
||||
for transformer in transformers:
|
||||
logs = []
|
||||
errors = []
|
||||
new_node = transformer(parent, node, full_name, name, logs, errors)
|
||||
self.add_logs(logs)
|
||||
self.add_errors(errors)
|
||||
if new_node:
|
||||
if new_node is not node:
|
||||
pasta.ast_utils.replace_child(parent, node, new_node)
|
||||
node = new_node
|
||||
self._stack[-1] = node
|
||||
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
||||
"""Handle bare Attributes i.e. [tf.foo, tf.bar].
|
||||
"""Handle bare Attributes i.e. [tf.foo, tf.bar]."""
|
||||
assert self._stack[-1] is node
|
||||
|
||||
Args:
|
||||
node: Node that is of type ast.Attribute
|
||||
"""
|
||||
full_name, _ = self._get_attribute_full_path(node)
|
||||
full_name = self._get_full_name(node)
|
||||
if full_name:
|
||||
# Make sure the warning comes first, otherwise the name may have changed
|
||||
self._print_warning_for_function(node, full_name)
|
||||
self._rename_functions(node, full_name)
|
||||
if full_name in self._api_change_spec.change_to_function:
|
||||
if not hasattr(node, "is_function_for_call"):
|
||||
new_text = full_name + "()"
|
||||
self._file_edit.add("Changed %r to %r" % (full_name, new_text),
|
||||
node.lineno, node.col_offset, full_name, new_text)
|
||||
parent = self._stack[-2]
|
||||
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
# Make sure the warning comes first, otherwise the name may have changed
|
||||
self._maybe_add_warning(node, full_name)
|
||||
|
||||
# Once we did a modification, node is invalid and not worth inspecting
|
||||
# further. Also, we only perform modifications for simple nodes, so
|
||||
# There'd be no point in descending further.
|
||||
if self._maybe_rename(parent, node, full_name):
|
||||
return
|
||||
if self._maybe_change_to_function_call(parent, node, full_name):
|
||||
return
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
class ASTCodeUpgrader(object):
|
||||
@ -429,16 +431,42 @@ class ASTCodeUpgrader(object):
|
||||
"""
|
||||
|
||||
# Write to a temporary file, just in case we are doing an implace modify.
|
||||
# pylint: disable=g-backslash-continuation
|
||||
with open(in_filename, "r") as in_file, \
|
||||
tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
|
||||
ret = self.process_opened_file(in_filename, in_file, out_filename,
|
||||
temp_file)
|
||||
# pylint: enable=g-backslash-continuation
|
||||
|
||||
shutil.move(temp_file.name, out_filename)
|
||||
return ret
|
||||
|
||||
# Broad exceptions are required here because ast throws whatever it wants.
|
||||
# pylint: disable=broad-except
|
||||
def _format_errors(self, errors, in_filename):
|
||||
return ["%s:%d:%d: %s" % ((in_filename,) + error) for error in errors]
|
||||
|
||||
def update_string_pasta(self, text, in_filename):
|
||||
"""Updates a file using pasta."""
|
||||
try:
|
||||
t = pasta.parse(text)
|
||||
except (SyntaxError, ValueError, TypeError):
|
||||
log = "Failed to parse.\n\n" + traceback.format_exc()
|
||||
return 0, "", log, []
|
||||
|
||||
visitor = _PastaEditVisitor(self._api_change_spec)
|
||||
visitor.visit(t)
|
||||
|
||||
errors = self._format_errors(visitor.errors, in_filename)
|
||||
return 1, pasta.dump(t), visitor.log_text(), errors
|
||||
|
||||
def _format_log(self, log, in_filename, out_filename):
|
||||
text = "-" * 80 + "\n"
|
||||
text += "Processing file %r\n outputting to %r\n" % (in_filename,
|
||||
out_filename)
|
||||
text += "-" * 80 + "\n\n"
|
||||
text += log
|
||||
text += "-" * 80 + "\n\n"
|
||||
return text
|
||||
|
||||
def process_opened_file(self, in_filename, in_file, out_filename, out_file):
|
||||
"""Process the given python file for incompatible changes.
|
||||
|
||||
@ -453,30 +481,16 @@ class ASTCodeUpgrader(object):
|
||||
Returns:
|
||||
A tuple representing number of files processed, log of actions, errors
|
||||
"""
|
||||
process_errors = []
|
||||
text = "-" * 80 + "\n"
|
||||
text += "Processing file %r\n outputting to %r\n" % (in_filename,
|
||||
out_filename)
|
||||
text += "-" * 80 + "\n\n"
|
||||
|
||||
parsed_ast = None
|
||||
lines = in_file.readlines()
|
||||
try:
|
||||
parsed_ast = ast.parse("".join(lines))
|
||||
except Exception:
|
||||
text += "Failed to parse %r\n\n" % in_filename
|
||||
text += traceback.format_exc()
|
||||
if parsed_ast:
|
||||
visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
|
||||
visitor.visit(parsed_ast)
|
||||
out_text, new_text, process_errors = visitor.process(lines)
|
||||
text += new_text
|
||||
if out_file:
|
||||
out_file.write(out_text)
|
||||
text += "\n"
|
||||
return 1, text, process_errors
|
||||
processed_file, new_file_content, log, process_errors = (
|
||||
self.update_string_pasta("".join(lines), in_filename))
|
||||
|
||||
# pylint: enable=broad-except
|
||||
if out_file and processed_file:
|
||||
out_file.write(new_file_content)
|
||||
|
||||
return (processed_file,
|
||||
self._format_log(log, in_filename, out_filename),
|
||||
process_errors)
|
||||
|
||||
def process_tree(self, root_directory, output_root_directory,
|
||||
copy_other_files):
|
||||
|
@ -39,6 +39,8 @@ following new APIs:
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import ast
|
||||
import pasta
|
||||
import six
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
@ -54,7 +56,6 @@ class NoUpdateSpec(ast_edits.APIChangeSpec):
|
||||
self.function_keyword_renames = {}
|
||||
self.symbol_renames = {}
|
||||
self.function_warnings = {}
|
||||
self.unrestricted_function_warnings = {}
|
||||
self.change_to_function = {}
|
||||
|
||||
|
||||
@ -401,7 +402,8 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
self.unrestricted_function_warnings = {"foo": "not good"}
|
||||
self.function_warnings = {"*.foo": "not good"}
|
||||
|
||||
texts = ["object.foo()", "get_object().foo()",
|
||||
"get_object().foo()", "object.foo().bar()"]
|
||||
for text in texts:
|
||||
@ -416,5 +418,26 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
||||
self.assertNotIn("not good", report)
|
||||
|
||||
|
||||
class ManualEditsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def disabled_test_arg_order(self):
|
||||
"""Tests that generated arg order is sane."""
|
||||
text = "f(a)"
|
||||
t = pasta.parse(text)
|
||||
node = pasta.ast_utils.find_nodes_by_type(t, (ast.Call,))[0]
|
||||
arg = ast.keyword(arg="b", value=ast.Num(n=0))
|
||||
node.keywords.append(arg)
|
||||
|
||||
# This is only needed in Python3, and I think it's a bug (but maybe in ast).
|
||||
arg.value.lineno = 0
|
||||
arg.value.col_offset = 0
|
||||
|
||||
# pasta.dump should never put kwargs before args, even if the col_offset is
|
||||
# messed up.
|
||||
# This fails if run with python3, but works find for python2.
|
||||
# In python3, the dump yields "f(b=0, a)".
|
||||
self.assertEqual(pasta.dump(t), "f(a, b=0)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
@ -175,27 +175,13 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.op_scope": ["values", "name", "default_name"],
|
||||
}
|
||||
|
||||
# Specially handled functions.
|
||||
self.function_handle = {"tf.reverse": self._reverse_handler}
|
||||
|
||||
# Warnings that should be printed if corresponding functions are used.
|
||||
self.function_warnings = {}
|
||||
|
||||
@staticmethod
|
||||
def _reverse_handler(file_edit_recorder, node, lines):
|
||||
del lines
|
||||
# TODO(aselle): Could check for a literal list of bools and try to convert
|
||||
# them to indices.
|
||||
comment = ("ERROR: tf.reverse has had its argument semantics changed "
|
||||
"significantly the converter cannot detect this reliably, so "
|
||||
"you need to inspect this usage manually.\n")
|
||||
file_edit_recorder.add(
|
||||
comment,
|
||||
node.lineno,
|
||||
node.col_offset,
|
||||
"tf.reverse",
|
||||
"tf.reverse",
|
||||
error="tf.reverse requires manual check.")
|
||||
self.function_warnings = {
|
||||
"tf.reverse":
|
||||
"ERROR: tf.reverse has had its argument semantics changed "
|
||||
"significantly. The converter cannot detect this reliably, so "
|
||||
"you need to inspect this usage manually.\n",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -112,7 +112,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
text = "tf.reverse(a, b)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, new_text)
|
||||
self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
|
||||
self.assertIn("tf.reverse requires manual check", errors[0])
|
||||
|
||||
def testListComprehension(self):
|
||||
def _test(input, output): # pylint: disable=redefined-builtin
|
||||
|
@ -18,7 +18,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import ast
|
||||
|
||||
import pasta
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import renames_v2
|
||||
@ -31,7 +33,22 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
def __init__(self):
|
||||
# Maps from a function name to a dictionary that describes how to
|
||||
# map from an old argument keyword to the new argument keyword.
|
||||
# If the new argument is None, it will be removed.
|
||||
# Only keyword args are handled, so make sure to also put any function in
|
||||
# function_reorders to ensure that all args are made into keywords first.
|
||||
self.function_keyword_renames = {
|
||||
"tf.gradients": {
|
||||
"colocate_gradients_with_ops": None,
|
||||
},
|
||||
"tf.hessians": {
|
||||
"colocate_gradients_with_ops": None,
|
||||
},
|
||||
"*.minimize": {
|
||||
"colocate_gradients_with_ops": None,
|
||||
},
|
||||
"*.compute_gradients": {
|
||||
"colocate_gradients_with_ops": None,
|
||||
},
|
||||
"tf.argmin": {
|
||||
"dimension": "axis",
|
||||
},
|
||||
@ -672,14 +689,31 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
# positional arguments yourself, this could do the wrong thing.
|
||||
self.function_reorders = reorders_v2.reorders
|
||||
|
||||
# Specially handled functions.
|
||||
self.function_handle = {
|
||||
"tf.batch_gather": self._batch_gather_handler,
|
||||
"tf.nn.dropout": self._dropout_handler,
|
||||
"tf.gradients": self._colocate_handler("tf.gradients"),
|
||||
"*.minimize": self._colocate_handler("Optimizer.minimize"),
|
||||
"*.compute_gradients":
|
||||
self._colocate_handler("Optimizer.compute_gradients"),
|
||||
# Specially handled functions (pasta version)
|
||||
# Each transformer is a callable which will be called with the arguments
|
||||
# transformer(parent, node, full_name, name, logs, errors)
|
||||
# Where logs and errors are lists to which (line, col, msg) tuples can be
|
||||
# appended, full_name is the FQN of the function called (or None if that is
|
||||
# unknown), name is the name of the function called (or None is that is
|
||||
# unknown). node is an ast.Call node representing this function call, and
|
||||
# parent is its parent in the AST.
|
||||
# The function may modify node (but not parent), and must return
|
||||
# - none, if nothing was modified
|
||||
# - node, if node was modified in place (make sure to use
|
||||
# pasta.ast_utils.replace_child to swap out children, otherwise formatting
|
||||
# may get messy)
|
||||
# - a replacement for node, if the whole call node was replaced. The caller
|
||||
# will take care of changing parent.
|
||||
self.function_transformers = {
|
||||
"tf.nn.dropout": self._dropout_transformer,
|
||||
"tf.batch_gather": self._batch_gather_transformer,
|
||||
"tf.to_bfloat16": self._cast_transformer,
|
||||
"tf.to_complex128": self._cast_transformer,
|
||||
"tf.to_complex64": self._cast_transformer,
|
||||
"tf.to_double": self._cast_transformer,
|
||||
"tf.to_float": self._cast_transformer,
|
||||
"tf.to_int32": self._cast_transformer,
|
||||
"tf.to_int64": self._cast_transformer,
|
||||
}
|
||||
|
||||
decay_function_comment = (
|
||||
@ -748,9 +782,45 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"compat.v1 for backward compatibility. Please update these calls to "
|
||||
"the TF 2.0 versions.")
|
||||
|
||||
export_saved_model_renamed = (
|
||||
"(Manual edit required) Please rename the method export_savedmodel() "
|
||||
"to export_saved_model(). Two things to note:\n\t(1) The argument "
|
||||
"strip_default_attributes has been removed. The function will always "
|
||||
"strip the default attributes from ops. If this breaks your code, "
|
||||
"please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
|
||||
"only effects core estimator. If you are using "
|
||||
"tf.contrib.learn.Estimator, please switch to using core estimator.")
|
||||
|
||||
make_initializable_iterator_deprecation = (
|
||||
"(Manual edit required) The "
|
||||
"`tf.data.Dataset.make_initializable_iterator()` method has been "
|
||||
"removed. If you are using the Estimator API, you can return a dataset "
|
||||
"directly from your input functions without creating an iterator. "
|
||||
"As a last resort, please replace calls to that method on `dataset` "
|
||||
"with a call to "
|
||||
"`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
|
||||
|
||||
make_one_shot_iterator_deprecation = (
|
||||
"(Manual edit required) The "
|
||||
"`tf.data.Dataset.make_one_shot_iterator()` method has been "
|
||||
"removed. If you are using eager execution, you can iterate over "
|
||||
"`dataset` using a Python `for` loop. If you are using the Estimator "
|
||||
"API, you can return a dataset directly from your input functions "
|
||||
"without creating an iterator. As a last resort, please replace calls "
|
||||
"to that method on `dataset` with a call to "
|
||||
"`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
|
||||
|
||||
# Function warnings. <function name> placeholder inside warnings will be
|
||||
# replaced by function name.
|
||||
# You can use *. to add items which do not check the FQN, and apply to e.g.,
|
||||
# methods.
|
||||
self.function_warnings = {
|
||||
"*.export_savedmodel":
|
||||
export_saved_model_renamed,
|
||||
"*.make_initializable_iterator":
|
||||
make_initializable_iterator_deprecation,
|
||||
"*.make_one_shot_iterator":
|
||||
make_one_shot_iterator_deprecation,
|
||||
"tf.assert_greater":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_equal":
|
||||
@ -834,11 +904,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
default_loss_reduction_changed,
|
||||
"tf.estimator.BaselineRegressor":
|
||||
default_loss_reduction_changed,
|
||||
"tf.hessians":
|
||||
"tf.hessians no longer takes "
|
||||
"'colocate_gradients_with_ops' argument. Also, "
|
||||
"arguments have been reordered so that 'name' is the "
|
||||
"last argument.",
|
||||
"tf.nn.conv1d":
|
||||
"WARNING: use_cudnn_on_gpu argument has been removed and \"value\""
|
||||
" was renamed to \"input\"",
|
||||
@ -1064,111 +1129,118 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
metrics_comment,
|
||||
}
|
||||
|
||||
# Warnings that are emitted only if a specific arg is found.
|
||||
self.function_arg_warnings = {
|
||||
"tf.gradients": {
|
||||
"colocate_gradients_with_ops":
|
||||
"tf.gradients no longer takes "
|
||||
"'colocate_gradients_with_ops' argument, it behaves as if it "
|
||||
"was set to True.",
|
||||
},
|
||||
"*.minimize": {
|
||||
"colocate_gradients_with_ops":
|
||||
"Optimizer.minimize no longer takes "
|
||||
"'colocate_gradients_with_ops' argument, it behaves as if it "
|
||||
"was set to True.",
|
||||
},
|
||||
"*.compute_gradients": {
|
||||
"colocate_gradients_with_ops":
|
||||
"Optimizer.compute_gradients no "
|
||||
"longer takes 'colocate_gradients_with_ops' argument, it "
|
||||
"behaves as if it was set to True.",
|
||||
},
|
||||
}
|
||||
|
||||
self.symbol_renames = {
|
||||
name: new_name
|
||||
for name, new_name in self.symbol_renames.items()
|
||||
}
|
||||
|
||||
export_saved_model_renamed = (
|
||||
"(Manual edit required) Please rename the method export_savedmodel() "
|
||||
"to export_saved_model(). Two things to note:\n\t(1) The argument "
|
||||
"strip_default_attributes has been removed. The function will always "
|
||||
"strip the default attributes from ops. If this breaks your code, "
|
||||
"please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
|
||||
"only effects core estimator. If you are using "
|
||||
"tf.contrib.learn.Estimator, please switch to using core estimator.")
|
||||
|
||||
make_initializable_iterator_deprecation = (
|
||||
"(Manual edit required) The "
|
||||
"`tf.data.Dataset.make_initializable_iterator()` method has been "
|
||||
"removed. If you are using the Estimator API, you can return a dataset "
|
||||
"directly from your input functions without creating an iterator. "
|
||||
"As a last resort, please replace calls to that method on `dataset` "
|
||||
"with a call to "
|
||||
"`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
|
||||
|
||||
make_one_shot_iterator_deprecation = (
|
||||
"(Manual edit required) The "
|
||||
"`tf.data.Dataset.make_one_shot_iterator()` method has been "
|
||||
"removed. If you are using eager execution, you can iterate over "
|
||||
"`dataset` using a Python `for` loop. If you are using the Estimator "
|
||||
"API, you can return a dataset directly from your input functions "
|
||||
"without creating an iterator. As a last resort, please replace calls "
|
||||
"to that method on `dataset` with a call to "
|
||||
"`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
|
||||
|
||||
# Specify warnings for functions that aren't restricted to the tf.x.y.z
|
||||
# format. This should only be used for methods with unique names, e.g.
|
||||
# export_savedmodel, which is only defined in Estimator objects.
|
||||
self.unrestricted_function_warnings = {
|
||||
"export_savedmodel": export_saved_model_renamed,
|
||||
"make_initializable_iterator": make_initializable_iterator_deprecation,
|
||||
"make_one_shot_iterator": make_one_shot_iterator_deprecation,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dropout_handler(file_edit_recorder, node, lines):
|
||||
del lines
|
||||
def _dropout_transformer(parent, node, full_name, name, logs, errors):
|
||||
def _replace_keep_prob_node(parent, old_value):
|
||||
"""Replaces old_value with 1-(old_value)."""
|
||||
one = ast.Num(n=1)
|
||||
one.lineno = 0
|
||||
one.col_offset = 0
|
||||
new_value = ast.BinOp(left=one, op=ast.Sub(),
|
||||
right=old_value)
|
||||
# This copies the prefix and suffix on old_value to new_value.
|
||||
pasta.ast_utils.replace_child(parent, old_value, new_value)
|
||||
ast.copy_location(new_value, old_value)
|
||||
# Put parentheses around keep_prob.value (and remove the old prefix/
|
||||
# suffix, they should only be around new_value).
|
||||
pasta.base.formatting.set(old_value, "prefix", "(")
|
||||
pasta.base.formatting.set(old_value, "suffix", ")")
|
||||
|
||||
# Check if we have a keep_prob keyword arg
|
||||
for keep_prob in node.keywords:
|
||||
if keep_prob.arg == "keep_prob":
|
||||
logs.append((node.lineno, node.col_offset,
|
||||
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
||||
"recomputing value. Please check this transformation.\n"))
|
||||
keep_prob.arg = "rate"
|
||||
_replace_keep_prob_node(keep_prob, keep_prob.value)
|
||||
return node
|
||||
|
||||
# Maybe it was a positional arg
|
||||
if len(node.args) < 2:
|
||||
comment = ("ERROR: tf.nn.dropout did not take arguments, so automatic "
|
||||
"transformation was disabled. tf.nn.dropout has changed "
|
||||
"the semantics of the second argument.")
|
||||
file_edit_recorder.add(
|
||||
comment,
|
||||
node.lineno,
|
||||
node.col_offset,
|
||||
"tf.nn.dropout",
|
||||
"tf.nn.dropout",
|
||||
error="tf.nn.dropout requires manual check.")
|
||||
errors.append((node.lineno, node.col_offset,
|
||||
"ERROR: tf.nn.dropout called without arguments, so "
|
||||
"automatic fix was disabled. tf.nn.dropout has changed "
|
||||
"the semantics of the second argument."))
|
||||
else:
|
||||
comment = ("WARNING: tf.nn.dropout has changed the semantics of the "
|
||||
"second argument. Please check the transformation.\n")
|
||||
file_edit_recorder.add(
|
||||
comment,
|
||||
node.args[1].lineno,
|
||||
node.args[1].col_offset,
|
||||
"",
|
||||
"1 - ")
|
||||
_replace_keep_prob_node(node, node.args[1])
|
||||
logs.append((node.lineno, node.col_offset,
|
||||
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
||||
"recomputing value.\n"))
|
||||
errors.append((node.lineno, node.col_offset,
|
||||
"WARNING: tf.nn.dropout has changed the semantics of the "
|
||||
"second argument. Please check the applied transformation."
|
||||
))
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def _colocate_handler(name):
|
||||
def _helper(file_edit_recorder, node, lines):
|
||||
"""Handler for updating colocate arguments."""
|
||||
del lines
|
||||
for keyword in node.keywords:
|
||||
if keyword.arg == "colocate_gradients_with_ops":
|
||||
# TODO(jhseu): Since ast_edit.py does string replacement, there's no
|
||||
# straightforward way to remove the argument. Try to fix before 2.0 is
|
||||
# final.
|
||||
comment = ("For tf.gradients and tf.Optimizer.minimize, "
|
||||
"colocate_gradients_with_op has been removed and now "
|
||||
"defaults to True.")
|
||||
file_edit_recorder.add(
|
||||
comment,
|
||||
node.lineno,
|
||||
node.col_offset,
|
||||
"",
|
||||
"",
|
||||
error="{} requires manual check.".format(name))
|
||||
return _helper
|
||||
def _cast_transformer(parent, node, full_name, name, logs, errors):
|
||||
"""Transforms to_int and to_float to cast(..., dtype=...)."""
|
||||
|
||||
# Find out the dtype to cast to from the function name
|
||||
dtype_str = name[3:]
|
||||
new_arg = ast.keyword(arg="dtype",
|
||||
value=ast.Attribute(value=ast.Name(id="tf",
|
||||
ctx=ast.Load()),
|
||||
attr=dtype_str, ctx=ast.Load()))
|
||||
|
||||
# Python3 ast requires the args for the Attribute, but codegen will mess up
|
||||
# the arg order if we just set them to 0.
|
||||
new_arg.value.lineno = node.lineno
|
||||
new_arg.value.col_offset = node.col_offset+100
|
||||
|
||||
node.keywords.append(new_arg)
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
node.func.attr = "cast"
|
||||
else:
|
||||
assert isinstance(node.func, ast.Name)
|
||||
node.func.id = "cast"
|
||||
|
||||
logs.append((node.lineno, node.col_offset,
|
||||
"Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name,
|
||||
dtype_str)))
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def _batch_gather_handler(file_edit_recorder, node, lines):
|
||||
lineno = node.lineno
|
||||
column = node.col_offset
|
||||
def _batch_gather_transformer(parent, node, full_name, name, logs, errors):
|
||||
# Check if the call already has a batch_dims argument
|
||||
if any([kw.arg == "batch_dims" for kw in node.keywords]):
|
||||
logs.append((node.lineno, node.col_offset, "tf.batch_gather already has "
|
||||
"batch_dims argument. Neat."))
|
||||
return None
|
||||
|
||||
# Find the position to add the batch_dims argument. We add it as the
|
||||
# first argument, since that's easiest. This is safe because we included
|
||||
# batch_gather in self.reordered_function_names, so it will have all
|
||||
# of its arguments changed to keyword arguments.
|
||||
m = re.match(r"tf\s*\.\s*batch_gather\s*\(", lines[lineno - 1][column:])
|
||||
if m is not None:
|
||||
file_edit_recorder.add(
|
||||
"Added keyword argument 'batch_dims=-1' to 'tf.batch_gather'",
|
||||
lineno, column + m.end(), "", "batch_dims=-1, ")
|
||||
else:
|
||||
file_edit_recorder.add(
|
||||
"Unable to add keyword argument 'batch_dims=-1' to 'tf.batch_gather'",
|
||||
lineno, column, "", "",
|
||||
error="Unable to add keyword argument batch_dims=-1 to "
|
||||
"tf.batch_gather; please add it manually.")
|
||||
minus_one = ast.Num(n=-1)
|
||||
minus_one.lineno = 0
|
||||
minus_one.col_offset = 0
|
||||
new_arg = ast.keyword("batch_dims", minus_one)
|
||||
node.keywords.append(new_arg)
|
||||
logs.append((node.lineno, node.col_offset,
|
||||
"Added keyword argument batch_dims=-1 to tf.batch_gather."))
|
||||
return node
|
||||
|
@ -239,8 +239,8 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
}
|
||||
function_warnings = (
|
||||
tf_upgrade_v2.TFAPIChangeSpec().function_warnings)
|
||||
function_handles = (
|
||||
tf_upgrade_v2.TFAPIChangeSpec().function_handle)
|
||||
function_transformers = (
|
||||
tf_upgrade_v2.TFAPIChangeSpec().function_transformers)
|
||||
keyword_renames = (
|
||||
tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
|
||||
|
||||
@ -255,7 +255,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
|
||||
for name in names_v1:
|
||||
tf_name = "tf.%s" % name
|
||||
if tf_name in function_warnings or tf_name in function_handles:
|
||||
if tf_name in function_warnings or tf_name in function_transformers:
|
||||
continue # These require manual change
|
||||
if tf_name in v1_name_exceptions:
|
||||
continue
|
||||
@ -362,15 +362,14 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
|
||||
text = "%s(a, b)\n" % decay
|
||||
_, report, errors, _ = self._upgrade(text)
|
||||
self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
|
||||
self.assertIn("%s requires manual check" % decay, errors[0])
|
||||
self.assertIn("%s has been changed" % decay, report)
|
||||
|
||||
def testPiecewiseDecay(self):
|
||||
text = "tf.train.piecewise_constant_decay(a, b)\n"
|
||||
_, report, errors, _ = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
errors,
|
||||
["test.py:1: tf.train.piecewise_constant_decay requires manual check."])
|
||||
self.assertIn("tf.train.piecewise_constant_decay requires manual check",
|
||||
errors[0])
|
||||
self.assertIn("tf.train.piecewise_constant_decay has been changed", report)
|
||||
|
||||
def testMetrics(self):
|
||||
@ -414,7 +413,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
text = ns + "(a, b)"
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text)
|
||||
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
|
||||
self.assertIn("test.py:1:0: %s requires manual check" % ns, errors[0])
|
||||
self.assertIn(
|
||||
"WARNING: tf.metrics have been converted to object oriented"
|
||||
" versions in TF 2.0 and after. The metric function calls have been "
|
||||
@ -445,7 +444,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
text = ns + "(a, b)"
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text)
|
||||
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
|
||||
self.assertIn("test.py:1:0: %s requires manual check" % ns, errors[0])
|
||||
self.assertIn(
|
||||
"WARNING: tf.losses have been converted to object oriented"
|
||||
" versions in TF 2.0 and after. The loss function calls have been "
|
||||
@ -463,7 +462,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
text = ns + "(a, b)"
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
|
||||
self.assertIn("%s requires manual check" % ns, errors[0])
|
||||
self.assertIn("loss_reduction has been changed", report)
|
||||
|
||||
def testDropout(self):
|
||||
@ -471,15 +470,40 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
new_text,
|
||||
"tf.nn.dropout(x, 1 - keep_prob, name=\"foo\")\n",
|
||||
"tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n",
|
||||
)
|
||||
|
||||
text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
new_text,
|
||||
"tf.nn.dropout(x, rate=1 - (.4), name=\"foo\")\n",
|
||||
)
|
||||
|
||||
text = (
|
||||
"tf.nn.dropout(x, # Stuff before\n"
|
||||
" keep_prob=.4, # Stuff after\n"
|
||||
" name=\"foo\")\n"
|
||||
)
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
new_text,
|
||||
"tf.nn.dropout(x, # Stuff before\n"
|
||||
" rate=1 - (.4), # Stuff after\n"
|
||||
" name=\"foo\")\n",
|
||||
)
|
||||
|
||||
text = "tf.nn.dropout(x)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, text)
|
||||
self.assertIn("tf.nn.dropout called without arguments", errors[0])
|
||||
|
||||
def testDropoutExpr(self):
|
||||
text = "tf.nn.dropout(x, 1 - func(3 + 4.), name=\"foo\")\n"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
errors,
|
||||
["test.py:1: tf.nn.dropout requires manual check."]
|
||||
new_text,
|
||||
"tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n",
|
||||
)
|
||||
|
||||
def testCountNonZeroChanges(self):
|
||||
@ -543,9 +567,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
|
||||
text = "tf.gradients(a, colocate_gradients_with_ops=False)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
self.assertEqual(errors, ["test.py:1: tf.gradients requires manual check."])
|
||||
self.assertEqual("tf.gradients(a)\n", new_text)
|
||||
self.assertIn("tf.gradients", errors[0])
|
||||
self.assertIn("requires manual check", errors[0])
|
||||
|
||||
def testColocateGradientsWithOpsMinimize(self):
|
||||
text = "optimizer.minimize(a, foo=False)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
@ -553,10 +579,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
|
||||
text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
self.assertEqual(errors,
|
||||
["test.py:1: Optimizer.minimize requires manual check."])
|
||||
self.assertEqual("optimizer.minimize(a)\n", new_text)
|
||||
self.assertIn("requires manual check", errors[0])
|
||||
self.assertIn("minimize", errors[0])
|
||||
|
||||
def testColocateGradientsWithOpsComputeGradients(self):
|
||||
text = "optimizer.compute_gradients(a, foo=False)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
@ -564,10 +591,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
|
||||
text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
self.assertEqual(errors,
|
||||
["test.py:1: Optimizer.compute_gradients "
|
||||
"requires manual check."])
|
||||
self.assertEqual("optimizer.compute_gradients(a)\n", new_text)
|
||||
self.assertIn("requires manual check", errors[0])
|
||||
self.assertIn("compute_gradients", errors[0])
|
||||
|
||||
def testExportSavedModelRename(self):
|
||||
text = "self.est.export_savedmodel(path)"
|
||||
@ -673,7 +699,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn(
|
||||
"tf.nn.softmax_cross_entropy_with_logits requires manual check.",
|
||||
"tf.nn.softmax_cross_entropy_with_logits requires manual check",
|
||||
errors[0])
|
||||
self.assertIn(
|
||||
"tf.nn.softmax_cross_entropy_with_logits behavior has changed. ",
|
||||
@ -817,27 +843,24 @@ tf.print('abc')
|
||||
|
||||
def testBatchGather(self):
|
||||
text = "tf.batch_gather(foo, bar)"
|
||||
expected_text = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
|
||||
expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
|
||||
expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn(new_text, [expected_text1, expected_text2])
|
||||
|
||||
text = "tf.batch_gather(params=foo, indices=bar)"
|
||||
expected_text = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
|
||||
expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
|
||||
expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn(new_text, [expected_text1, expected_text2])
|
||||
|
||||
text = "tf.batch_gather ( foo, bar)"
|
||||
expected_text = "tf.gather (batch_dims=-1, params=foo, indices=bar)"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
text = "(tf.batch_gather\n(foo, bar))"
|
||||
expected_text = "(tf.gather\n(params=foo, indices=bar))"
|
||||
expected_errors = ["test.py:1: Unable to add keyword argument batch_dims=-1"
|
||||
" to tf.batch_gather; please add it manually."]
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(errors, expected_errors)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
def testCast(self):
|
||||
for dtype in ["int32", "int64", "float", "double",
|
||||
"complex64", "complex128", "bfloat16"]:
|
||||
text = "tf.to_%s(x, name='test')" % dtype
|
||||
expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
|
@ -169,6 +169,7 @@ filegroup(
|
||||
"@local_config_sycl//sycl:LICENSE.text",
|
||||
"@nasm//:LICENSE",
|
||||
"@nsync//:LICENSE",
|
||||
"@pasta//:LICENSE",
|
||||
"@pcre//:LICENCE",
|
||||
"@png_archive//:LICENSE",
|
||||
"@protobuf_archive//:LICENSE",
|
||||
|
@ -29,6 +29,7 @@ load("//third_party/jpeg:workspace.bzl", jpeg = "repo")
|
||||
load("//third_party/nasm:workspace.bzl", nasm = "repo")
|
||||
load("//third_party/kissfft:workspace.bzl", kissfft = "repo")
|
||||
load("//third_party/keras_applications_archive:workspace.bzl", keras_applications = "repo")
|
||||
load("//third_party/pasta:workspace.bzl", pasta = "repo")
|
||||
|
||||
def initialize_third_party():
|
||||
""" Load third party repositories. See above load() statements. """
|
||||
@ -41,6 +42,7 @@ def initialize_third_party():
|
||||
kissfft()
|
||||
jpeg()
|
||||
nasm()
|
||||
pasta()
|
||||
|
||||
# Sanitize a dependency so that it works correctly from code that includes
|
||||
# TensorFlow as a submodule.
|
||||
|
1
third_party/pasta/BUILD
vendored
Normal file
1
third_party/pasta/BUILD
vendored
Normal file
@ -0,0 +1 @@
|
||||
# Empty BUILD file to force build system to see this directory at all.
|
29
third_party/pasta/BUILD.bazel
vendored
Normal file
29
third_party/pasta/BUILD.bazel
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
# Description:
|
||||
# AST-based python refactoring.
|
||||
|
||||
licenses(["notice"]) # Apache2
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "pasta",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"augment/__init__.py",
|
||||
"augment/errors.py",
|
||||
"augment/import_utils.py",
|
||||
"augment/inline.py",
|
||||
"augment/rename.py",
|
||||
"base/__init__.py",
|
||||
"base/annotate.py",
|
||||
"base/ast_constants.py",
|
||||
"base/ast_utils.py",
|
||||
"base/codegen.py",
|
||||
"base/formatting.py",
|
||||
"base/scope.py",
|
||||
"base/test_utils.py",
|
||||
"base/token_generator.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
13
third_party/pasta/BUILD.system
vendored
Normal file
13
third_party/pasta/BUILD.system
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
# Description: Pasta, AST based python refactoring.
|
||||
|
||||
licenses(["notice"]) # Apache2
|
||||
|
||||
filegroup(
|
||||
name = "LICENSE",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pasta",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
16
third_party/pasta/workspace.bzl
vendored
Normal file
16
third_party/pasta/workspace.bzl
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
"""Loads pasta python package."""
|
||||
|
||||
load("//third_party:repo.bzl", "third_party_http_archive")
|
||||
|
||||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "pasta",
|
||||
urls = [
|
||||
"https://mirror.bazel.build/github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
|
||||
"https://github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
|
||||
],
|
||||
strip_prefix = "pasta-c3d72cdee6fc806251949e912510444d58d7413c/pasta",
|
||||
sha256 = "b5905f9cecc4b28363c563f3c4cb0545288bd35f7cc72c55066e97e53befc084",
|
||||
build_file = "//third_party/pasta:BUILD.bazel",
|
||||
system_build_file = "//third_party/pasta:BUILD.system",
|
||||
)
|
Loading…
Reference in New Issue
Block a user