Refactor compatibility upgrade tool to use pasta (an AST-based refactoring tool) instead of line-based edits.

This allows for some simplification (e.g., we can now remove arguments with a dict entry) and for much more powerful transformations. This passes all tests with no destructive test changes (though there are cosmetic test changes). Also adds transformations for tf.to_dtype -> tf.cast(..., dtype=...).

PiperOrigin-RevId: 227934801
commit fe66882827
parent 75f2e9c266
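At a high level, the new approach parses each input file into a formatting-preserving syntax tree with pasta, rewrites the tree with a node visitor, and prints it back out. A minimal sketch of that flow (illustrative only; the real driver is update_string_pasta in ast_edits.py below):

    import pasta

    # Sketch of the new pipeline, assuming a visitor like _PastaEditVisitor below.
    source = "x = tf.to_float(y)  # some comment\n"
    tree = pasta.parse(source)   # AST annotated with the original formatting
    # ... walk `tree` and mutate nodes in place ...
    print(pasta.dump(tree))      # prints the source back, layout and comments intact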
@@ -1,17 +1,21 @@
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
 load(
     "//tensorflow:tensorflow.bzl",
     "tf_copts",  # @unused
     "tf_cc_test",  # @unused
 )

+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
 py_library(
     name = "ast_edits",
     srcs = ["ast_edits.py"],
     srcs_version = "PY2AND3",
+    deps = [
+        "@pasta",
+        "@six_archive//:six",
+    ],
 )

 py_test(
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 import ast
-import collections
 import os
 import re
 import shutil
@@ -27,6 +26,9 @@ import sys
 import tempfile
 import traceback

+import pasta
+import six
+
 # Some regular expressions we will need for parsing
 FIND_OPEN = re.compile(r"^\s*(\[).*$")
 FIND_STRING_CHARS = re.compile(r"['\"]")
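pasta is used here instead of the plain ast module because it round-trips formatting: parsing a file and dumping it unchanged reproduces the original text, comments and whitespace included. A quick illustration of that property (a sketch, assuming pasta's round-trip guarantee holds for the snippet):

    import pasta

    src = "tf.foo( a,  b )  # this comment and spacing survive\n"
    assert pasta.dump(pasta.parse(src)) == src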
@@ -44,169 +46,173 @@ class APIChangeSpec(object):
     notifications)
   * `function_reorders`: maps functions whose argument order has changed to the
     list of arguments in the new order
-  * `function_handle`: maps function names to custom handlers for the function
   * `function_warnings`: maps full names of functions to warnings that will be
     printed out if the function is used. (e.g. tf.nn.convolution())
-  * `unrestricted_function_warnings`: maps names of functions to warnings that
-    will be printed out when the function is used (e.g. foo.convolution()).
-  * `function_keyword_additions`: maps function names to a map of arg->value
-    names that should be passed to the function.
+  * `function_transformers`: maps function names to custom handlers

   For an example, see `TFAPIChangeSpec`.
   """


-class _FileEditTuple(
-    collections.namedtuple("_FileEditTuple",
-                           ["comment", "line", "start", "old", "new"])):
-  """Each edit that is recorded by a _FileEditRecorder.
-
-  Fields:
-    comment: A description of the edit and why it was made.
-    line: The line number in the file where the edit occurs (1-indexed).
-    start: The column number in the file where the edit occurs (0-indexed).
-    old: text string to remove (this must match what was in file).
-    new: text string to add in place of `old`.
-  """
-
-  __slots__ = ()
-
-
-class _FileEditRecorder(object):
-  """Record changes that need to be done to the file."""
-
-  def __init__(self, filename):
-    # all edits are lists of chars
-    self._filename = filename
-
-    self._line_to_edit = collections.defaultdict(list)
-    self._errors = []
-
-  def process(self, text):
-    """Process a list of strings, each corresponding to the recorded changes.
-
-    Args:
-      text: A list of lines of text (assumed to contain newlines)
-    Returns:
-      A tuple of the modified text and a textual description of what is done.
-    Raises:
-      ValueError: if substitution source location does not have expected text.
-    """
-
-    change_report = ""
-
-    # Iterate of each line
-    for line, edits in self._line_to_edit.items():
-      offset = 0
-      # sort by column so that edits are processed in order in order to make
-      # indexing adjustments cumulative for changes that change the string
-      # length
-      edits.sort(key=lambda x: x.start)
-
-      # Extract each line to a list of characters, because mutable lists
-      # are editable, unlike immutable strings.
-      char_array = list(text[line - 1])
-
-      # Record a description of the change
-      change_report += "%r Line %d\n" % (self._filename, line)
-      change_report += "-" * 80 + "\n\n"
-      for e in edits:
-        change_report += "%s\n" % e.comment
-      change_report += "\n    Old: %s" % (text[line - 1])
-
-      # Make underscore buffers for underlining where in the line the edit was
-      change_list = [" "] * len(text[line - 1])
-      change_list_new = [" "] * len(text[line - 1])
-
-      # Iterate for each edit
-      for e in edits:
-        # Create effective start, end by accounting for change in length due
-        # to previous edits
-        start_eff = e.start + offset
-        end_eff = start_eff + len(e.old)
-
-        # Make sure the edit is changing what it should be changing
-        old_actual = "".join(char_array[start_eff:end_eff])
-        if old_actual != e.old:
-          raise ValueError("Expected text %r but got %r" %
-                           ("".join(e.old), "".join(old_actual)))
-        # Make the edit
-        char_array[start_eff:end_eff] = list(e.new)
-
-        # Create the underline highlighting of the before and after
-        change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
-        change_list_new[start_eff:end_eff] = "~" * len(e.new)
-
-        # Keep track of how to generate effective ranges
-        offset += len(e.new) - len(e.old)
-
-      # Finish the report comment
-      change_report += "    %s\n" % "".join(change_list)
-      text[line - 1] = "".join(char_array)
-      change_report += "    New: %s" % (text[line - 1])
-      change_report += "    %s\n\n" % "".join(change_list_new)
-    return "".join(text), change_report, self._errors
-
-  def add(self, comment, line, start, old, new, error=None):
-    """Add a new change that is needed.
-
-    Args:
-      comment: A description of what was changed
-      line: Line number (1 indexed)
-      start: Column offset (0 indexed)
-      old: old text
-      new: new text
-      error: this "edit" is something that cannot be fixed automatically
-    Returns:
-      None
-    """
-
-    self._line_to_edit[line].append(
-        _FileEditTuple(comment, line, start, old, new))
-    if error:
-      self._errors.append("%s:%d: %s" % (self._filename, line, error))
-
-
-class _ASTCallVisitor(ast.NodeVisitor):
+class _PastaEditVisitor(ast.NodeVisitor):
   """AST Visitor that processes function calls.

   Updates function calls from old API version to new API version using a given
   change spec.
   """

-  def __init__(self, filename, lines, api_change_spec):
-    self._filename = filename
-    self._file_edit = _FileEditRecorder(filename)
-    self._lines = lines
+  def __init__(self, api_change_spec):
     self._api_change_spec = api_change_spec
+    self._log = []  # Holds 3-tuples: line, col, msg.
+    self._errors = []  # Same structure as _log.
+    self._stack = []  # Allow easy access to parents.

-  def process(self, lines):
-    return self._file_edit.process(lines)
+  # Overridden to maintain a stack of nodes to allow for parent access
+  def visit(self, node):
+    self._stack.append(node)
+    super(_PastaEditVisitor, self).visit(node)
+    self._stack.pop()

-  def generic_visit(self, node):
-    ast.NodeVisitor.generic_visit(self, node)
+  @property
+  def errors(self):
+    return self._errors

-  def _rename_functions(self, node, full_name):
-    symbol_renames = self._api_change_spec.symbol_renames
-    try:
-      new_name = symbol_renames[full_name]
-      self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
-                          node.lineno, node.col_offset, full_name, new_name)
-    except KeyError:
-      pass
+  @property
+  def log(self):
+    return self._log

-  def _print_warning_for_function(self, node, full_name):
+  def _format_log(self, log):
+    text = ""
+    for log_entry in log:
+      text += "Line %d:%d: %s\n" % log_entry
+    return text
+
+  def log_text(self):
+    return self._format_log(self.log)
+
+  def add_log(self, lineno, col, msg):
+    self._log.append((lineno, col, msg))
+    print("Line %d:%d: %s" % (lineno, col, msg))
+
+  def add_error(self, lineno, col, msg):
+    # All errors are also added to the regular log.
+    self.add_log(lineno, col, msg)
+    self._errors.append((lineno, col, msg))
+
+  def add_logs(self, logs):
+    """Record a log and print it.
+
+    The log should be a tuple (lineno, col_offset, msg), which will be printed
+    and then recorded. It is part of the log available in the self.log property.
+
+    Args:
+      logs: The logs to add. Must be a list of tuples (lineno, col_offset, msg).
+    """
+    self._log.extend(logs)
+    for log in logs:
+      print("Line %d:%d: %s" % log)
+
+  def add_errors(self, errors):
+    """Record an error and print it.
+
+    The error must be a tuple (lineno, col_offset, msg), which will be printed
+    and then recorded as both a log and an error. It is therefore part of the
+    log available in the self.log as well as the self.errors property.
+
+    Args:
+      errors: The errors to add. Must be a list of tuples (lineno, col_offset,
+        msg).
+    """
+    self.add_logs(errors)
+    self._errors.extend(errors)
+
+  def _get_applicable_entries(self, transformer_field, full_name, name):
+    """Get all list entries indexed by name that apply to full_name or name."""
+    # Transformers are indexed to full name, name, or no name
+    # as a performance optimization.
+    function_transformers = getattr(self._api_change_spec,
+                                    transformer_field, {})
+
+    glob_name = "*." + name if name else None
+    transformers = []
+    if full_name in function_transformers:
+      transformers.append(function_transformers[full_name])
+    if glob_name in function_transformers:
+      transformers.append(function_transformers[glob_name])
+    if "*" in function_transformers:
+      transformers.append(function_transformers["*"])
+    return transformers
+
+  def _get_applicable_dict(self, transformer_field, full_name, name):
+    """Get all dict entries indexed by name that apply to full_name or name."""
+    # Transformers are indexed to full name, name, or no name
+    # as a performance optimization.
+    function_transformers = getattr(self._api_change_spec,
+                                    transformer_field, {})
+
+    glob_name = "*." + name if name else None
+    transformers = function_transformers.get("*", {}).copy()
+    transformers.update(function_transformers.get(glob_name, {}))
+    transformers.update(function_transformers.get(full_name, {}))
+    return transformers
+
+  def _get_full_name(self, node):
+    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".
+
+    This is the inverse of _full_name_node.
+
+    Args:
+      node: A Node of type Attribute.
+
+    Returns:
+      a '.'-delimited full-name or None if node was not Attribute or Name.
+      i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
+    """
+    curr = node
+    items = []
+    while not isinstance(curr, ast.Name):
+      if not isinstance(curr, ast.Attribute):
+        return None
+      items.append(curr.attr)
+      curr = curr.value
+    items.append(curr.id)
+    return ".".join(reversed(items))
+
+  def _full_name_node(self, name, ctx=ast.Load()):
+    """Make an Attribute or Name node for name.
+
+    Translate a qualified name into nested Attribute nodes (and a Name node).
+
+    Args:
+      name: The name to translate to a node.
+      ctx: What context this name is used in. Defaults to Load()
+
+    Returns:
+      A Name or Attribute node.
+    """
+    names = name.split(".")
+    names.reverse()
+    node = ast.Name(id=names.pop(), ctx=ast.Load())
+    while names:
+      node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load())
+
+    # Change outermost ctx to the one given to us (inner ones should be Load).
+    node.ctx = ctx
+    return node
+
+  def _maybe_add_warning(self, node, full_name):
     """Adds an error to be printed about full_name at node."""
     function_warnings = self._api_change_spec.function_warnings
-    try:
+    if full_name in function_warnings:
       warning_message = function_warnings[full_name]
       warning_message = warning_message.replace("<function name>", full_name)
-      self._file_edit.add(warning_message,
-                          node.lineno, node.col_offset, full_name, full_name,
-                          error="%s requires manual check." % full_name)
-    except KeyError:
-      pass
+      self.add_error(node.lineno, node.col_offset,
+                     "%s requires manual check: %s." % (full_name,
+                                                        warning_message))
+      return True
+    else:
+      return False

-  def _print_warning_for_function_unrestricted(self, node):
+  def _maybe_add_call_warning(self, node, full_name, name):
     """Print a warning when specific functions are called.

     The function _print_warning_for_function matches the full name of the called
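_get_applicable_dict above merges spec entries keyed three ways, with more specific keys winning. A hypothetical spec excerpt showing the precedence (the entries are invented for illustration):

    # Hypothetical function_keyword_renames entries:
    spec = {
        "*": {"a": "x"},             # applies to every call
        "*.minimize": {"b": "y"},    # applies to any .minimize() method call
        "tf.gradients": {"c": "z"},  # applies to this fully-qualified name only
    }
    # For full_name="tf.gradients", name="gradients", the merged result is
    # {"a": "x", "c": "z"}: glob entries are overridden by full-name entries.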
@@ -216,92 +222,118 @@ class _ASTCallVisitor(ast.NodeVisitor):

     Args:
       node: ast.Call object
+      full_name: The precomputed full name of the callable, if one exists, None
+        otherwise.
+      name: The precomputed name of the callable, if one exists, None otherwise.
+
+    Returns:
+      Whether an error was recorded.
     """
-    function_warnings = getattr(
-        self._api_change_spec, "unrestricted_function_warnings", {})
+    # Only look for *.-warnings here, the other will be handled by the Attribute
+    # visitor. Also, do not warn for bare functions, only if the call func is
+    # an attribute.
+    warned = False
     if isinstance(node.func, ast.Attribute):
-      function_name = node.func.attr
-      try:
-        warning_message = function_warnings[function_name]
-        self._file_edit.add(warning_message,
-                            node.lineno, node.col_offset, "", "",
-                            error="%s requires manual check." % function_name)
-      except KeyError:
-        pass
+      warned = self._maybe_add_warning(node, "*." + name)

-  def _get_attribute_full_path(self, node):
-    """Traverse an attribute to generate a full name e.g. tf.foo.bar.
-
-    Args:
-      node: A Node of type Attribute.
-
-    Returns:
-      a '.'-delimited full-name or None if the tree was not a simple form.
-      i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
-    """
-    curr = node
-    items = []
-    while not isinstance(curr, ast.Name):
-      if not isinstance(curr, ast.Attribute):
-        return None, None
-      items.append(curr.attr)
-      curr = curr.value
-    items.append(curr.id)
-    return ".".join(reversed(items)), items[0]
+    # All arg warnings are handled here, since only we have the args
+    arg_warnings = self._get_applicable_dict("function_arg_warnings",
+                                             full_name, name)
+
+    used_args = [kw.arg for kw in node.keywords]
+    for arg, warning in arg_warnings.items():
+      if arg in used_args:
+        warned = True
+        warning_message = warning.replace("<function name>", full_name or name)
+        self.add_error(node.lineno, node.col_offset,
+                       "%s called with %s argument requires manual check: %s." %
+                       (full_name or name, arg, warning_message))
+
+    return warned

-  def _find_true_position(self, node):
-    """Return correct line number and column offset for a given node.
-
-    This is necessary mainly because ListComp's location reporting reports
-    the next token after the list comprehension list opening.
-
-    Returns:
-      lineno, offset for the given node
-
-    Args:
-      node: Node for which we wish to know the lineno and col_offset
-    """
-    if isinstance(node, ast.ListComp):
-      # Strangely, ast.ListComp returns the col_offset of the first token
-      # after the '[' token which appears to be a bug. Workaround by
-      # explicitly finding the real start of the list comprehension.
-      line = node.lineno
-      col = node.col_offset
-      # loop over lines
-      while 1:
-        # Reverse the text to and regular expression search for whitespace
-        text = self._lines[line - 1]
-        reversed_preceding_text = text[:col][::-1]
-        # First find if a [ can be found with only whitespace between it and
-        # col.
-        m = FIND_OPEN.match(reversed_preceding_text)
-        if m:
-          new_col_offset = col - m.start(1) - 1
-          return line, new_col_offset
-        else:
-          if (reversed_preceding_text == "" or
-              reversed_preceding_text.isspace()):
-            line = line - 1
-            prev_line = self._lines[line - 1]
-            # TODO(aselle):
-            # this is poor comment detection, but it is good enough for
-            # cases where the comment does not contain string literal starting/
-            # ending characters. If ast gave us start and end locations of the
-            # ast nodes rather than just start, we could use string literal
-            # node ranges to filter out spurious #'s that appear in string
-            # literals.
-            comment_start = prev_line.find("#")
-            if comment_start == -1:
-              col = len(prev_line) - 1
-            elif FIND_STRING_CHARS.search(prev_line[comment_start:]) is None:
-              col = comment_start
-            else:
-              return None, None
-          else:
-            return None, None
-    # Most other nodes return proper locations (with notably does not), but
-    # it is not possible to use that in an argument.
-    return node.lineno, node.col_offset
+  def _maybe_rename(self, parent, node, full_name):
+    """Replace node (Attribute or Name) with a node representing full_name."""
+    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
+    if new_name:
+      self.add_log(node.lineno, node.col_offset,
+                   "Renamed %r to %r" % (full_name, new_name))
+      new_node = self._full_name_node(new_name, node.ctx)
+      ast.copy_location(new_node, node)
+      pasta.ast_utils.replace_child(parent, node, new_node)
+      return True
+    else:
+      return False
+
+  def _maybe_change_to_function_call(self, parent, node, full_name):
+    """Wraps node (typically, an Attribute or Expr) in a Call."""
+    if full_name in self._api_change_spec.change_to_function:
+      if not isinstance(parent, ast.Call):
+        # ast.Call's constructor is really picky about how many arguments it
+        # wants, and also, it changed between Py2 and Py3.
+        if six.PY2:
+          new_node = ast.Call(node, [], [], None, None)
+        else:
+          new_node = ast.Call(node, [], [])
+        pasta.ast_utils.replace_child(parent, node, new_node)
+        ast.copy_location(new_node, node)
+        self.add_log(node.lineno, node.col_offset,
+                     "Changed %r to a function call" % full_name)
+        return True
+    return False
+
+  def _maybe_add_arg_names(self, node, full_name):
+    """Make args into keyword args if function called full_name requires it."""
+    function_reorders = self._api_change_spec.function_reorders
+
+    if full_name in function_reorders:
+      reordered = function_reorders[full_name]
+      new_keywords = []
+      for idx, arg in enumerate(node.args):
+        keyword_arg = reordered[idx]
+        new_keywords.append(ast.keyword(arg=keyword_arg, value=arg))
+
+      if new_keywords:
+        self.add_log(node.lineno, node.col_offset,
+                     "Added keywords to args of function %r" % full_name)
+        node.args = []
+        node.keywords = new_keywords + (node.keywords or [])
+        return True
+    return False
+
+  def _maybe_modify_args(self, node, full_name, name):
+    """Rename keyword args if the function called full_name requires it."""
+    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
+                                                 full_name, name)
+
+    if not renamed_keywords:
+      return False
+
+    modified = False
+    new_keywords = []
+    for keyword in node.keywords:
+      argkey = keyword.arg
+      if argkey in renamed_keywords:
+        modified = True
+        if renamed_keywords[argkey] is None:
+          lineno = getattr(keyword, "lineno", node.lineno)
+          col_offset = getattr(keyword, "col_offset", node.col_offset)
+          self.add_log(lineno, col_offset,
+                       "Removed argument %s for function %s" % (
+                           argkey, full_name or name))
+        else:
+          keyword.arg = renamed_keywords[argkey]
+          lineno = getattr(keyword, "lineno", node.lineno)
+          col_offset = getattr(keyword, "col_offset", node.col_offset)
+          self.add_log(lineno, col_offset,
+                       "Renamed keyword argument for %s from %s to %s" % (
+                           full_name, argkey, renamed_keywords[argkey]))
+          new_keywords.append(keyword)
+      else:
+        new_keywords.append(keyword)
+
+    if modified:
+      node.keywords = new_keywords
+    return modified

   def visit_Call(self, node):  # pylint: disable=invalid-name
     """Handle visiting a call node in the AST.
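Taken together, _maybe_add_arg_names and _maybe_modify_args first turn positional arguments into keyword arguments and then rename (or drop) them. An illustrative before/after, assuming a reorder entry of ["input", "dimension"] and the dimension -> axis rename for tf.argmin (the reorder list is an assumption for the example):

    # before:                        tf.argmin(x, 0)
    # after _maybe_add_arg_names:    tf.argmin(input=x, dimension=0)
    # after _maybe_modify_args:      tf.argmin(input=x, axis=0)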
@@ -309,104 +341,74 @@ class _ASTCallVisitor(ast.NodeVisitor):
     Args:
       node: Current Node
     """
-    self._print_warning_for_function_unrestricted(node)
+    assert self._stack[-1] is node

-    # Find a simple attribute name path e.g. "tf.foo.bar"
-    full_name, name = self._get_attribute_full_path(node.func)
-
-    # Make sure the func is marked as being part of a call
-    node.func.is_function_for_call = True
-
+    # Get the name for this call, so we can index stuff with it.
+    full_name = self._get_full_name(node.func)
     if full_name:
-      # Call special handlers
-      function_handles = self._api_change_spec.function_handle
-      glob_name = "*.{}".format(name)
-      if glob_name in function_handles:
-        function_handles[glob_name](self._file_edit, node, self._lines)
-      if full_name in function_handles:
-        function_handles[full_name](self._file_edit, node, self._lines)
-
-      # Examine any non-keyword argument and make it into a keyword argument
-      # if reordering required.
-      function_reorders = self._api_change_spec.function_reorders
-      function_keyword_renames = (
-          self._api_change_spec.function_keyword_renames)
-
-      if full_name in function_reorders:
-        reordered = function_reorders[full_name]
-        for idx, arg in enumerate(node.args):
-          lineno, col_offset = self._find_true_position(arg)
-          if lineno is None or col_offset is None:
-            self._file_edit.add(
-                "Failed to add keyword %r to reordered function %r" %
-                (reordered[idx], full_name),
-                arg.lineno,
-                arg.col_offset,
-                "",
-                "",
-                error="A necessary keyword argument failed to be inserted.")
-          else:
-            keyword_arg = reordered[idx]
-            if (full_name in function_keyword_renames and
-                keyword_arg in function_keyword_renames[full_name]):
-              keyword_arg = function_keyword_renames[full_name][keyword_arg]
-            self._file_edit.add("Added keyword %r to reordered function %r" %
-                                (reordered[idx], full_name), lineno, col_offset,
-                                "", keyword_arg + "=")
-
-      # Examine each keyword argument and convert it to the final renamed form
-      renamed_keywords = ({} if full_name not in function_keyword_renames else
-                          function_keyword_renames[full_name])
-      for keyword in node.keywords:
-        argkey = keyword.arg
-        argval = keyword.value
-
-        if argkey in renamed_keywords:
-          argval_lineno, argval_col_offset = self._find_true_position(argval)
-          if argval_lineno is not None and argval_col_offset is not None:
-            # TODO(aselle): We should scan backward to find the start of the
-            # keyword key. Unfortunately ast does not give you the location of
-            # keyword keys, so we are forced to infer it from the keyword arg
-            # value.
-            key_start = argval_col_offset - len(argkey) - 1
-            key_end = key_start + len(argkey) + 1
-            if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
-                "="):
-              self._file_edit.add("Renamed keyword argument from %r to %r" %
-                                  (argkey,
-                                   renamed_keywords[argkey]), argval_lineno,
-                                  argval_col_offset - len(argkey) - 1,
-                                  argkey + "=", renamed_keywords[argkey] + "=")
-              continue
-          self._file_edit.add(
-              "Failed to rename keyword argument from %r to %r" %
-              (argkey, renamed_keywords[argkey]),
-              argval.lineno,
-              argval.col_offset - len(argkey) - 1,
-              "",
-              "",
-              error="Failed to find keyword lexographically. Fix manually.")
-
-    ast.NodeVisitor.generic_visit(self, node)
+      name = full_name.split(".")[-1]
+    elif isinstance(node.func, ast.Name):
+      name = node.func.id
+    elif isinstance(node.func, ast.Attribute):
+      name = node.func.attr
+    else:
+      name = None
+
+    # Call standard transformers for this node.
+    # Make sure warnings come first, since args or names triggering warnings
+    # may be removed by the other transformations.
+    self._maybe_add_call_warning(node, full_name, name)
+    # Make all args into kwargs
+    self._maybe_add_arg_names(node, full_name)
+    # Argument name changes or deletions
+    self._maybe_modify_args(node, full_name, name)
+
+    # Call transformers. These have the ability to modify the node, and if they
+    # do, will return the new node they created (or the same node if they just
+    # changed it). They are given the parent, but we will take care of
+    # integrating their changes into the parent if they return a new node.
+    #
+    # These are matched on the old name, since renaming is performed by the
+    # Attribute visitor, which happens later.
+    transformers = self._get_applicable_entries("function_transformers",
+                                                full_name, name)
+
+    parent = self._stack[-2]
+
+    for transformer in transformers:
+      logs = []
+      errors = []
+      new_node = transformer(parent, node, full_name, name, logs, errors)
+      self.add_logs(logs)
+      self.add_errors(errors)
+      if new_node:
+        if new_node is not node:
+          pasta.ast_utils.replace_child(parent, node, new_node)
+          node = new_node
+          self._stack[-1] = node
+
+    self.generic_visit(node)

   def visit_Attribute(self, node):  # pylint: disable=invalid-name
-    """Handle bare Attributes i.e. [tf.foo, tf.bar].
-
-    Args:
-      node: Node that is of type ast.Attribute
-    """
-    full_name, _ = self._get_attribute_full_path(node)
+    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
+    assert self._stack[-1] is node
+
+    full_name = self._get_full_name(node)
     if full_name:
-      # Make sure the warning comes first, otherwise the name may have changed
-      self._print_warning_for_function(node, full_name)
-      self._rename_functions(node, full_name)
-      if full_name in self._api_change_spec.change_to_function:
-        if not hasattr(node, "is_function_for_call"):
-          new_text = full_name + "()"
-          self._file_edit.add("Changed %r to %r" % (full_name, new_text),
-                              node.lineno, node.col_offset, full_name, new_text)
-
-    ast.NodeVisitor.generic_visit(self, node)
+      parent = self._stack[-2]
+
+      # Make sure the warning comes first, otherwise the name may have changed
+      self._maybe_add_warning(node, full_name)
+
+      # Once we did a modification, node is invalid and not worth inspecting
+      # further. Also, we only perform modifications for simple nodes, so
+      # there'd be no point in descending further.
+      if self._maybe_rename(parent, node, full_name):
+        return
+      if self._maybe_change_to_function_call(parent, node, full_name):
+        return
+
+    self.generic_visit(node)


 class ASTCodeUpgrader(object):
@@ -429,16 +431,42 @@ class ASTCodeUpgrader(object):
     """
     # Write to a temporary file, just in case we are doing an in-place modify.
+    # pylint: disable=g-backslash-continuation
     with open(in_filename, "r") as in_file, \
         tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
       ret = self.process_opened_file(in_filename, in_file, out_filename,
                                      temp_file)
+    # pylint: enable=g-backslash-continuation

     shutil.move(temp_file.name, out_filename)
     return ret

-  # Broad exceptions are required here because ast throws whatever it wants.
-  # pylint: disable=broad-except
+  def _format_errors(self, errors, in_filename):
+    return ["%s:%d:%d: %s" % ((in_filename,) + error) for error in errors]
+
+  def update_string_pasta(self, text, in_filename):
+    """Updates a file using pasta."""
+    try:
+      t = pasta.parse(text)
+    except (SyntaxError, ValueError, TypeError):
+      log = "Failed to parse.\n\n" + traceback.format_exc()
+      return 0, "", log, []
+
+    visitor = _PastaEditVisitor(self._api_change_spec)
+    visitor.visit(t)
+
+    errors = self._format_errors(visitor.errors, in_filename)
+    return 1, pasta.dump(t), visitor.log_text(), errors
+
+  def _format_log(self, log, in_filename, out_filename):
+    text = "-" * 80 + "\n"
+    text += "Processing file %r\n outputting to %r\n" % (in_filename,
+                                                         out_filename)
+    text += "-" * 80 + "\n\n"
+    text += log
+    text += "-" * 80 + "\n\n"
+    return text

   def process_opened_file(self, in_filename, in_file, out_filename, out_file):
     """Process the given python file for incompatible changes.
@@ -453,30 +481,16 @@ class ASTCodeUpgrader(object):
     Returns:
       A tuple representing number of files processed, log of actions, errors
     """
-    process_errors = []
-    text = "-" * 80 + "\n"
-    text += "Processing file %r\n outputting to %r\n" % (in_filename,
-                                                         out_filename)
-    text += "-" * 80 + "\n\n"
-
-    parsed_ast = None
     lines = in_file.readlines()
-    try:
-      parsed_ast = ast.parse("".join(lines))
-    except Exception:
-      text += "Failed to parse %r\n\n" % in_filename
-      text += traceback.format_exc()
-    if parsed_ast:
-      visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
-      visitor.visit(parsed_ast)
-      out_text, new_text, process_errors = visitor.process(lines)
-      text += new_text
-      if out_file:
-        out_file.write(out_text)
-    text += "\n"
-    return 1, text, process_errors
-
-  # pylint: enable=broad-except
+    processed_file, new_file_content, log, process_errors = (
+        self.update_string_pasta("".join(lines), in_filename))
+
+    if out_file and processed_file:
+      out_file.write(new_file_content)
+
+    return (processed_file,
+            self._format_log(log, in_filename, out_filename),
+            process_errors)

   def process_tree(self, root_directory, output_root_directory,
                    copy_other_files):
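The rewritten process_opened_file delegates all parsing and rewriting to update_string_pasta, which returns a (processed, new_text, log, errors) tuple. A hedged usage sketch (the spec class name is a stand-in, and the constructor signature is assumed):

    upgrader = ASTCodeUpgrader(SomeAPIChangeSpec())  # SomeAPIChangeSpec: stand-in
    count, new_text, log, errors = upgrader.update_string_pasta(
        "tf.to_float(x)\n", "test.py")
    # count is 0 if parsing failed (log then holds a traceback); otherwise it
    # is 1 and new_text is the rewritten source produced by pasta.dump.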
@@ -39,6 +39,8 @@ following new APIs:
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import ast
+import pasta
 import six
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test as test_lib
@@ -54,7 +56,6 @@ class NoUpdateSpec(ast_edits.APIChangeSpec):
     self.function_keyword_renames = {}
     self.symbol_renames = {}
     self.function_warnings = {}
-    self.unrestricted_function_warnings = {}
     self.change_to_function = {}


@@ -401,7 +402,8 @@ class TestAstEdits(test_util.TensorFlowTestCase):

       def __init__(self):
         NoUpdateSpec.__init__(self)
-        self.unrestricted_function_warnings = {"foo": "not good"}
+        self.function_warnings = {"*.foo": "not good"}
+
     texts = ["object.foo()", "get_object().foo()",
              "get_object().foo()", "object.foo().bar()"]
     for text in texts:
@@ -416,5 +418,26 @@ class TestAstEdits(test_util.TensorFlowTestCase):
       self.assertNotIn("not good", report)


+class ManualEditsTest(test_util.TensorFlowTestCase):
+
+  def disabled_test_arg_order(self):
+    """Tests that generated arg order is sane."""
+    text = "f(a)"
+    t = pasta.parse(text)
+    node = pasta.ast_utils.find_nodes_by_type(t, (ast.Call,))[0]
+    arg = ast.keyword(arg="b", value=ast.Num(n=0))
+    node.keywords.append(arg)
+
+    # This is only needed in Python3, and I think it's a bug (but maybe in ast).
+    arg.value.lineno = 0
+    arg.value.col_offset = 0
+
+    # pasta.dump should never put kwargs before args, even if the col_offset is
+    # messed up.
+    # This fails if run with python3, but works fine for python2.
+    # In python3, the dump yields "f(b=0, a)".
+    self.assertEqual(pasta.dump(t), "f(a, b=0)")
+
+
 if __name__ == "__main__":
   test_lib.main()
@@ -175,27 +175,13 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "tf.op_scope": ["values", "name", "default_name"],
     }

-    # Specially handled functions.
-    self.function_handle = {"tf.reverse": self._reverse_handler}
-
     # Warnings that should be printed if corresponding functions are used.
-    self.function_warnings = {}
-
-  @staticmethod
-  def _reverse_handler(file_edit_recorder, node, lines):
-    del lines
-    # TODO(aselle): Could check for a literal list of bools and try to convert
-    # them to indices.
-    comment = ("ERROR: tf.reverse has had its argument semantics changed "
-               "significantly the converter cannot detect this reliably, so "
-               "you need to inspect this usage manually.\n")
-    file_edit_recorder.add(
-        comment,
-        node.lineno,
-        node.col_offset,
-        "tf.reverse",
-        "tf.reverse",
-        error="tf.reverse requires manual check.")
+    self.function_warnings = {
+        "tf.reverse":
+            "ERROR: tf.reverse has had its argument semantics changed "
+            "significantly. The converter cannot detect this reliably, so "
+            "you need to inspect this usage manually.\n",
+    }


 if __name__ == "__main__":
@@ -112,7 +112,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
     text = "tf.reverse(a, b)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, new_text)
-    self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
+    self.assertIn("tf.reverse requires manual check", errors[0])

   def testListComprehension(self):
     def _test(input, output):  # pylint: disable=redefined-builtin
@@ -18,7 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import re
+import ast
+
+import pasta

 from tensorflow.tools.compatibility import ast_edits
 from tensorflow.tools.compatibility import renames_v2
@@ -31,7 +33,22 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
   def __init__(self):
     # Maps from a function name to a dictionary that describes how to
     # map from an old argument keyword to the new argument keyword.
+    # If the new argument is None, it will be removed.
+    # Only keyword args are handled, so make sure to also put any function in
+    # function_reorders to ensure that all args are made into keywords first.
     self.function_keyword_renames = {
+        "tf.gradients": {
+            "colocate_gradients_with_ops": None,
+        },
+        "tf.hessians": {
+            "colocate_gradients_with_ops": None,
+        },
+        "*.minimize": {
+            "colocate_gradients_with_ops": None,
+        },
+        "*.compute_gradients": {
+            "colocate_gradients_with_ops": None,
+        },
         "tf.argmin": {
             "dimension": "axis",
         },
@@ -672,14 +689,31 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
     # positional arguments yourself, this could do the wrong thing.
     self.function_reorders = reorders_v2.reorders

-    # Specially handled functions.
-    self.function_handle = {
-        "tf.batch_gather": self._batch_gather_handler,
-        "tf.nn.dropout": self._dropout_handler,
-        "tf.gradients": self._colocate_handler("tf.gradients"),
-        "*.minimize": self._colocate_handler("Optimizer.minimize"),
-        "*.compute_gradients":
-            self._colocate_handler("Optimizer.compute_gradients"),
+    # Specially handled functions (pasta version)
+    # Each transformer is a callable which will be called with the arguments
+    #   transformer(parent, node, full_name, name, logs, errors)
+    # where logs and errors are lists to which (line, col, msg) tuples can be
+    # appended, full_name is the FQN of the function called (or None if that is
+    # unknown), and name is the name of the function called (or None if that is
+    # unknown). node is an ast.Call node representing this function call, and
+    # parent is its parent in the AST.
+    # The function may modify node (but not parent), and must return
+    # - None, if nothing was modified
+    # - node, if node was modified in place (make sure to use
+    #   pasta.ast_utils.replace_child to swap out children, otherwise formatting
+    #   may get messy)
+    # - a replacement for node, if the whole call node was replaced. The caller
+    #   will take care of changing parent.
+    self.function_transformers = {
+        "tf.nn.dropout": self._dropout_transformer,
+        "tf.batch_gather": self._batch_gather_transformer,
+        "tf.to_bfloat16": self._cast_transformer,
+        "tf.to_complex128": self._cast_transformer,
+        "tf.to_complex64": self._cast_transformer,
+        "tf.to_double": self._cast_transformer,
+        "tf.to_float": self._cast_transformer,
+        "tf.to_int32": self._cast_transformer,
+        "tf.to_int64": self._cast_transformer,
     }

     decay_function_comment = (
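As a concrete (hypothetical) instance of the transformer contract spelled out in the comment above, a do-nothing transformer that only logs would look like this; the name is invented for illustration:

    @staticmethod
    def _log_only_transformer(parent, node, full_name, name, logs, errors):
      """Hypothetical transformer obeying the contract documented above."""
      # `node` is the ast.Call being visited; `parent` must not be modified here.
      logs.append((node.lineno, node.col_offset,
                   "Saw a call to %s" % (full_name or name)))
      # Return None (nothing changed), `node` (changed in place), or a new node
      # that should replace `node`; the caller patches it into `parent`.
      return None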
@@ -748,9 +782,45 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "compat.v1 for backward compatibility. Please update these calls to "
         "the TF 2.0 versions.")

+    export_saved_model_renamed = (
+        "(Manual edit required) Please rename the method export_savedmodel() "
+        "to export_saved_model(). Two things to note:\n\t(1) The argument "
+        "strip_default_attributes has been removed. The function will always "
+        "strip the default attributes from ops. If this breaks your code, "
+        "please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
+        "only affects core estimator. If you are using "
+        "tf.contrib.learn.Estimator, please switch to using core estimator.")
+
+    make_initializable_iterator_deprecation = (
+        "(Manual edit required) The "
+        "`tf.data.Dataset.make_initializable_iterator()` method has been "
+        "removed. If you are using the Estimator API, you can return a dataset "
+        "directly from your input functions without creating an iterator. "
+        "As a last resort, please replace calls to that method on `dataset` "
+        "with a call to "
+        "`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
+
+    make_one_shot_iterator_deprecation = (
+        "(Manual edit required) The "
+        "`tf.data.Dataset.make_one_shot_iterator()` method has been "
+        "removed. If you are using eager execution, you can iterate over "
+        "`dataset` using a Python `for` loop. If you are using the Estimator "
+        "API, you can return a dataset directly from your input functions "
+        "without creating an iterator. As a last resort, please replace calls "
+        "to that method on `dataset` with a call to "
+        "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
+
     # Function warnings. <function name> placeholder inside warnings will be
     # replaced by function name.
+    # You can use *. to add items which do not check the FQN, and apply to e.g.,
+    # methods.
     self.function_warnings = {
+        "*.export_savedmodel":
+            export_saved_model_renamed,
+        "*.make_initializable_iterator":
+            make_initializable_iterator_deprecation,
+        "*.make_one_shot_iterator":
+            make_one_shot_iterator_deprecation,
         "tf.assert_greater":
             assert_return_type_comment,
         "tf.assert_equal":
@@ -834,11 +904,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
             default_loss_reduction_changed,
         "tf.estimator.BaselineRegressor":
             default_loss_reduction_changed,
-        "tf.hessians":
-            "tf.hessians no longer takes "
-            "'colocate_gradients_with_ops' argument. Also, "
-            "arguments have been reordered so that 'name' is the "
-            "last argument.",
         "tf.nn.conv1d":
             "WARNING: use_cudnn_on_gpu argument has been removed and \"value\""
             " was renamed to \"input\"",
@ -1064,111 +1129,118 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
metrics_comment,
|
metrics_comment,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Warnings that are emitted only if a specific arg is found.
|
||||||
|
self.function_arg_warnings = {
|
||||||
|
"tf.gradients": {
|
||||||
|
"colocate_gradients_with_ops":
|
||||||
|
"tf.gradients no longer takes "
|
||||||
|
"'colocate_gradients_with_ops' argument, it behaves as if it "
|
||||||
|
"was set to True.",
|
||||||
|
},
|
||||||
|
"*.minimize": {
|
||||||
|
"colocate_gradients_with_ops":
|
||||||
|
"Optimizer.minimize no longer takes "
|
||||||
|
"'colocate_gradients_with_ops' argument, it behaves as if it "
|
||||||
|
"was set to True.",
|
||||||
|
},
|
||||||
|
"*.compute_gradients": {
|
||||||
|
"colocate_gradients_with_ops":
|
||||||
|
"Optimizer.compute_gradients no "
|
||||||
|
"longer takes 'colocate_gradients_with_ops' argument, it "
|
||||||
|
"behaves as if it was set to True.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
self.symbol_renames = {
|
self.symbol_renames = {
|
||||||
name: new_name
|
name: new_name
|
||||||
for name, new_name in self.symbol_renames.items()
|
for name, new_name in self.symbol_renames.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
export_saved_model_renamed = (
|
|
||||||
"(Manual edit required) Please rename the method export_savedmodel() "
|
|
||||||
"to export_saved_model(). Two things to note:\n\t(1) The argument "
|
|
||||||
"strip_default_attributes has been removed. The function will always "
|
|
||||||
"strip the default attributes from ops. If this breaks your code, "
|
|
||||||
"please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
|
|
||||||
"only effects core estimator. If you are using "
|
|
||||||
"tf.contrib.learn.Estimator, please switch to using core estimator.")
|
|
||||||
|
|
||||||
make_initializable_iterator_deprecation = (
|
|
||||||
"(Manual edit required) The "
|
|
||||||
"`tf.data.Dataset.make_initializable_iterator()` method has been "
|
|
||||||
"removed. If you are using the Estimator API, you can return a dataset "
|
|
||||||
"directly from your input functions without creating an iterator. "
|
|
||||||
"As a last resort, please replace calls to that method on `dataset` "
|
|
||||||
"with a call to "
|
|
||||||
"`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
|
|
||||||
|
|
||||||
make_one_shot_iterator_deprecation = (
|
|
||||||
"(Manual edit required) The "
|
|
||||||
"`tf.data.Dataset.make_one_shot_iterator()` method has been "
|
|
||||||
"removed. If you are using eager execution, you can iterate over "
|
|
||||||
"`dataset` using a Python `for` loop. If you are using the Estimator "
|
|
||||||
"API, you can return a dataset directly from your input functions "
|
|
||||||
"without creating an iterator. As a last resort, please replace calls "
|
|
||||||
"to that method on `dataset` with a call to "
|
|
||||||
"`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
|
|
||||||
|
|
||||||
# Specify warnings for functions that aren't restricted to the tf.x.y.z
|
|
||||||
# format. This should only be used for methods with unique names, e.g.
|
|
||||||
# export_savedmodel, which is only defined in Estimator objects.
|
|
||||||
self.unrestricted_function_warnings = {
|
|
||||||
"export_savedmodel": export_saved_model_renamed,
|
|
||||||
"make_initializable_iterator": make_initializable_iterator_deprecation,
|
|
||||||
"make_one_shot_iterator": make_one_shot_iterator_deprecation,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _dropout_handler(file_edit_recorder, node, lines):
|
def _dropout_transformer(parent, node, full_name, name, logs, errors):
|
||||||
del lines
|
def _replace_keep_prob_node(parent, old_value):
|
||||||
|
"""Replaces old_value with 1-(old_value)."""
|
||||||
|
one = ast.Num(n=1)
|
||||||
|
one.lineno = 0
|
||||||
|
one.col_offset = 0
|
||||||
|
new_value = ast.BinOp(left=one, op=ast.Sub(),
|
||||||
|
right=old_value)
|
||||||
|
# This copies the prefix and suffix on old_value to new_value.
|
||||||
|
pasta.ast_utils.replace_child(parent, old_value, new_value)
|
||||||
|
ast.copy_location(new_value, old_value)
|
||||||
|
# Put parentheses around keep_prob.value (and remove the old prefix/
|
||||||
|
# suffix, they should only be around new_value).
|
||||||
|
pasta.base.formatting.set(old_value, "prefix", "(")
|
||||||
|
pasta.base.formatting.set(old_value, "suffix", ")")
|
||||||
|
|
||||||
|
# Check if we have a keep_prob keyword arg
|
||||||
|
for keep_prob in node.keywords:
|
||||||
|
if keep_prob.arg == "keep_prob":
|
||||||
|
logs.append((node.lineno, node.col_offset,
|
||||||
|
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
||||||
|
"recomputing value. Please check this transformation.\n"))
|
||||||
|
keep_prob.arg = "rate"
|
||||||
|
_replace_keep_prob_node(keep_prob, keep_prob.value)
|
||||||
|
return node
|
||||||
|
|
||||||
|
# Maybe it was a positional arg
|
||||||
if len(node.args) < 2:
|
if len(node.args) < 2:
|
||||||
comment = ("ERROR: tf.nn.dropout did not take arguments, so automatic "
|
errors.append((node.lineno, node.col_offset,
|
||||||
"transformation was disabled. tf.nn.dropout has changed "
|
"ERROR: tf.nn.dropout called without arguments, so "
|
||||||
"the semantics of the second argument.")
|
"automatic fix was disabled. tf.nn.dropout has changed "
|
||||||
file_edit_recorder.add(
|
"the semantics of the second argument."))
|
||||||
comment,
|
|
||||||
node.lineno,
|
|
||||||
node.col_offset,
|
|
||||||
"tf.nn.dropout",
|
|
||||||
"tf.nn.dropout",
|
|
||||||
error="tf.nn.dropout requires manual check.")
|
|
||||||
else:
|
else:
|
||||||
comment = ("WARNING: tf.nn.dropout has changed the semantics of the "
|
_replace_keep_prob_node(node, node.args[1])
|
||||||
"second argument. Please check the transformation.\n")
|
logs.append((node.lineno, node.col_offset,
|
||||||
file_edit_recorder.add(
|
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
||||||
comment,
|
"recomputing value.\n"))
|
||||||
node.args[1].lineno,
|
errors.append((node.lineno, node.col_offset,
|
||||||
node.args[1].col_offset,
|
"WARNING: tf.nn.dropout has changed the semantics of the "
|
||||||
"",
|
"second argument. Please check the applied transformation."
|
||||||
"1 - ")
|
))
|
||||||
|
return node
|
||||||
|
|
   @staticmethod
-  def _colocate_handler(name):
-    def _helper(file_edit_recorder, node, lines):
-      """Handler for updating colocate arguments."""
-      del lines
-      for keyword in node.keywords:
-        if keyword.arg == "colocate_gradients_with_ops":
-          # TODO(jhseu): Since ast_edit.py does string replacement, there's no
-          # straightforward way to remove the argument. Try to fix before 2.0 is
-          # final.
-          comment = ("For tf.gradients and tf.Optimizer.minimize, "
-                     "colocate_gradients_with_op has been removed and now "
-                     "defaults to True.")
-          file_edit_recorder.add(
-              comment,
-              node.lineno,
-              node.col_offset,
-              "",
-              "",
-              error="{} requires manual check.".format(name))
-    return _helper
+  def _cast_transformer(parent, node, full_name, name, logs, errors):
+    """Transforms to_int and to_float to cast(..., dtype=...)."""
+
+    # Find out the dtype to cast to from the function name
+    dtype_str = name[3:]
+    new_arg = ast.keyword(arg="dtype",
+                          value=ast.Attribute(value=ast.Name(id="tf",
+                                                             ctx=ast.Load()),
+                                              attr=dtype_str, ctx=ast.Load()))
+
+    # Python3 ast requires the args for the Attribute, but codegen will mess up
+    # the arg order if we just set them to 0.
+    new_arg.value.lineno = node.lineno
+    new_arg.value.col_offset = node.col_offset+100
+
+    node.keywords.append(new_arg)
+    if isinstance(node.func, ast.Attribute):
+      node.func.attr = "cast"
+    else:
+      assert isinstance(node.func, ast.Name)
+      node.func.id = "cast"
+
+    logs.append((node.lineno, node.col_offset,
+                 "Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name,
+                                                                    dtype_str)))
+    return node
 
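`_cast_transformer` derives the target dtype from the function name itself: `name[3:]` strips the `to_` prefix, so `tf.to_float` becomes `tf.cast(..., dtype=tf.float)`. A standalone sketch of the same AST surgery using only the stdlib `ast` module (the source string is a hypothetical input):

import ast

src = "tf.to_float(x, name='test')"
tree = ast.parse(src)
call = tree.body[0].value        # the tf.to_float(...) Call node
dtype_str = call.func.attr[3:]   # "to_float" -> "float"
call.keywords.append(ast.keyword(
    arg="dtype",
    value=ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()),
                        attr=dtype_str, ctx=ast.Load())))
call.func.attr = "cast"          # tf.to_float(...) -> tf.cast(...)
ast.fix_missing_locations(tree)  # fill in lineno/col_offset for new nodes
print(ast.dump(call))

In the real transformer the locations are set by hand (note the `col_offset+100` trick) so that pasta's codegen emits the new keyword after the existing arguments.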
   @staticmethod
-  def _batch_gather_handler(file_edit_recorder, node, lines):
-    lineno = node.lineno
-    column = node.col_offset
-
-    # Find the position to add the batch_dims argument. We add it as the
-    # first argument, since that's easiest. This is safe because we included
-    # batch_gather in self.reordered_function_names, so it will have all
-    # of its arguments changed to keyword arguments.
-    m = re.match(r"tf\s*\.\s*batch_gather\s*\(", lines[lineno - 1][column:])
-    if m is not None:
-      file_edit_recorder.add(
-          "Added keyword argument 'batch_dims=-1' to 'tf.batch_gather'",
-          lineno, column + m.end(), "", "batch_dims=-1, ")
-    else:
-      file_edit_recorder.add(
-          "Unable to add keyword argument 'batch_dims=-1' to 'tf.batch_gather'",
-          lineno, column, "", "",
-          error="Unable to add keyword argument batch_dims=-1 to "
-                "tf.batch_gather; please add it manually.")
+  def _batch_gather_transformer(parent, node, full_name, name, logs, errors):
+    # Check if the call already has a batch_dims argument
+    if any([kw.arg == "batch_dims" for kw in node.keywords]):
+      logs.append((node.lineno, node.col_offset, "tf.batch_gather already has "
+                   "batch_dims argument. Neat."))
+      return None
+
+    minus_one = ast.Num(n=-1)
+    minus_one.lineno = 0
+    minus_one.col_offset = 0
+    new_arg = ast.keyword("batch_dims", minus_one)
+    node.keywords.append(new_arg)
+    logs.append((node.lineno, node.col_offset,
+                 "Added keyword argument batch_dims=-1 to tf.batch_gather."))
+    return node
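Because the new transformer edits the call's keyword list directly, it no longer needs the old regex match against the source line, so oddly formatted calls such as `tf.batch_gather ( foo, bar)` or calls split across lines (the old `else:` failure branch) are handled for free. A minimal sketch of the same check-and-append on a plain `ast` tree (hypothetical input):

import ast

src = "tf.batch_gather(params=foo, indices=bar)"
call = ast.parse(src).body[0].value
if not any(kw.arg == "batch_dims" for kw in call.keywords):
    call.keywords.append(ast.keyword(arg="batch_dims", value=ast.Num(n=-1)))
print([kw.arg for kw in call.keywords])  # ['params', 'indices', 'batch_dims']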
@@ -239,8 +239,8 @@ class TestUpgrade(test_util.TensorFlowTestCase):
     }
     function_warnings = (
         tf_upgrade_v2.TFAPIChangeSpec().function_warnings)
-    function_handles = (
-        tf_upgrade_v2.TFAPIChangeSpec().function_handle)
+    function_transformers = (
+        tf_upgrade_v2.TFAPIChangeSpec().function_transformers)
     keyword_renames = (
         tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
 
@@ -255,7 +255,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
 
     for name in names_v1:
       tf_name = "tf.%s" % name
-      if tf_name in function_warnings or tf_name in function_handles:
+      if tf_name in function_warnings or tf_name in function_transformers:
         continue  # These require manual change
       if tf_name in v1_name_exceptions:
         continue
@@ -362,15 +362,14 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
 
       text = "%s(a, b)\n" % decay
       _, report, errors, _ = self._upgrade(text)
-      self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
+      self.assertIn("%s requires manual check" % decay, errors[0])
       self.assertIn("%s has been changed" % decay, report)
 
   def testPiecewiseDecay(self):
     text = "tf.train.piecewise_constant_decay(a, b)\n"
     _, report, errors, _ = self._upgrade(text)
-    self.assertEqual(
-        errors,
-        ["test.py:1: tf.train.piecewise_constant_decay requires manual check."])
+    self.assertIn("tf.train.piecewise_constant_decay requires manual check",
+                  errors[0])
     self.assertIn("tf.train.piecewise_constant_decay has been changed", report)
 
   def testMetrics(self):
@@ -414,7 +413,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
       text = ns + "(a, b)"
       _, report, errors, new_text = self._upgrade(text)
       self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text)
-      self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
+      self.assertIn("test.py:1:0: %s requires manual check" % ns, errors[0])
       self.assertIn(
           "WARNING: tf.metrics have been converted to object oriented"
           " versions in TF 2.0 and after. The metric function calls have been "
@@ -445,7 +444,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
       text = ns + "(a, b)"
       _, report, errors, new_text = self._upgrade(text)
       self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text)
-      self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
+      self.assertIn("test.py:1:0: %s requires manual check" % ns, errors[0])
       self.assertIn(
           "WARNING: tf.losses have been converted to object oriented"
           " versions in TF 2.0 and after. The loss function calls have been "
@@ -463,7 +462,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
     text = ns + "(a, b)"
     _, report, errors, new_text = self._upgrade(text)
     self.assertEqual(text, new_text)
-    self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
+    self.assertIn("%s requires manual check" % ns, errors[0])
     self.assertIn("loss_reduction has been changed", report)
 
   def testDropout(self):
@@ -471,15 +470,40 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
     _, unused_report, unused_errors, new_text = self._upgrade(text)
     self.assertEqual(
         new_text,
-        "tf.nn.dropout(x, 1 - keep_prob, name=\"foo\")\n",
+        "tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n",
+    )
+
+    text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(
+        new_text,
+        "tf.nn.dropout(x, rate=1 - (.4), name=\"foo\")\n",
+    )
+
+    text = (
+        "tf.nn.dropout(x, # Stuff before\n"
+        " keep_prob=.4, # Stuff after\n"
+        " name=\"foo\")\n"
+    )
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(
+        new_text,
+        "tf.nn.dropout(x, # Stuff before\n"
+        " rate=1 - (.4), # Stuff after\n"
+        " name=\"foo\")\n",
     )
 
     text = "tf.nn.dropout(x)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, text)
+    self.assertIn("tf.nn.dropout called without arguments", errors[0])
+
+  def testDropoutExpr(self):
+    text = "tf.nn.dropout(x, 1 - func(3 + 4.), name=\"foo\")\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
     self.assertEqual(
-        errors,
-        ["test.py:1: tf.nn.dropout requires manual check."]
+        new_text,
+        "tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n",
     )
 
   def testCountNonZeroChanges(self):
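The new `testDropoutExpr` case pins down why the transformer parenthesizes the old expression: `-` is left-associative, so emitting `1 - 1 - func(3 + 4.)` instead of `1 - (1 - func(3 + 4.))` would compute a different rate. A quick arithmetic check (this `func` is a hypothetical stand-in for a user expression):

def func(v):
    return 0.25 * v  # arbitrary user expression

keep_prob = 1 - func(3 + 4.)       # the original second argument
rate_ok = 1 - (1 - func(3 + 4.))   # the parenthesized rewrite
rate_bad = 1 - 1 - func(3 + 4.)    # what dropping the parens would mean
assert rate_ok == 1 - keep_prob    # correct: rate == 1 - keep_prob
assert rate_bad != rate_ok         # left-associative '-' changes the value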
@@ -543,9 +567,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
 
     text = "tf.gradients(a, colocate_gradients_with_ops=False)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
-    self.assertEqual(text, new_text)
-    self.assertEqual(errors, ["test.py:1: tf.gradients requires manual check."])
+    self.assertEqual("tf.gradients(a)\n", new_text)
+    self.assertIn("tf.gradients", errors[0])
+    self.assertIn("requires manual check", errors[0])
 
   def testColocateGradientsWithOpsMinimize(self):
     text = "optimizer.minimize(a, foo=False)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
     self.assertEqual(text, new_text)
@@ -553,10 +579,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
 
     text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
-    self.assertEqual(text, new_text)
-    self.assertEqual(errors,
-                     ["test.py:1: Optimizer.minimize requires manual check."])
+    self.assertEqual("optimizer.minimize(a)\n", new_text)
+    self.assertIn("requires manual check", errors[0])
+    self.assertIn("minimize", errors[0])
 
   def testColocateGradientsWithOpsComputeGradients(self):
     text = "optimizer.compute_gradients(a, foo=False)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
     self.assertEqual(text, new_text)
@@ -564,10 +591,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
 
     text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
     _, unused_report, errors, new_text = self._upgrade(text)
-    self.assertEqual(text, new_text)
-    self.assertEqual(errors,
-                     ["test.py:1: Optimizer.compute_gradients "
-                      "requires manual check."])
+    self.assertEqual("optimizer.compute_gradients(a)\n", new_text)
+    self.assertIn("requires manual check", errors[0])
+    self.assertIn("compute_gradients", errors[0])
 
   def testExportSavedModelRename(self):
     text = "self.est.export_savedmodel(path)"
@@ -673,7 +699,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
     _, report, errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, expected_text)
     self.assertIn(
-        "tf.nn.softmax_cross_entropy_with_logits requires manual check.",
+        "tf.nn.softmax_cross_entropy_with_logits requires manual check",
         errors[0])
     self.assertIn(
         "tf.nn.softmax_cross_entropy_with_logits behavior has changed. ",
@@ -817,27 +843,24 @@ tf.print('abc')
 
   def testBatchGather(self):
     text = "tf.batch_gather(foo, bar)"
-    expected_text = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
+    expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
+    expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
     _, unused_report, unused_errors, new_text = self._upgrade(text)
-    self.assertEqual(new_text, expected_text)
+    self.assertIn(new_text, [expected_text1, expected_text2])
 
     text = "tf.batch_gather(params=foo, indices=bar)"
-    expected_text = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
+    expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
+    expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
     _, unused_report, unused_errors, new_text = self._upgrade(text)
-    self.assertEqual(new_text, expected_text)
-
-    text = "tf.batch_gather ( foo, bar)"
-    expected_text = "tf.gather (batch_dims=-1, params=foo, indices=bar)"
-    _, unused_report, unused_errors, new_text = self._upgrade(text)
-    self.assertEqual(new_text, expected_text)
-
-    text = "(tf.batch_gather\n(foo, bar))"
-    expected_text = "(tf.gather\n(params=foo, indices=bar))"
-    expected_errors = ["test.py:1: Unable to add keyword argument batch_dims=-1"
-                       " to tf.batch_gather; please add it manually."]
-    _, unused_report, errors, new_text = self._upgrade(text)
-    self.assertEqual(errors, expected_errors)
-    self.assertEqual(new_text, expected_text)
+    self.assertIn(new_text, [expected_text1, expected_text2])
+
+  def testCast(self):
+    for dtype in ["int32", "int64", "float", "double",
+                  "complex64", "complex128", "bfloat16"]:
+      text = "tf.to_%s(x, name='test')" % dtype
+      expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
+      _, unused_report, unused_errors, new_text = self._upgrade(text)
+      self.assertEqual(expected_text, new_text)
 
 
 class TestUpgradeFiles(test_util.TensorFlowTestCase):
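The new `testCast` loop exercises the whole `tf.to_*` family end to end through the upgrader. A hedged sketch of driving the same pipeline programmatically, assuming `ASTCodeUpgrader.process_opened_file` keeps the signature these tests use through their `_upgrade()` helper:

import six

from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade_v2

# Run the v2 upgrader over an in-memory file, the way the tests do.
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
in_file = six.StringIO("tf.to_int32(x, name='test')\n")
out_file = six.StringIO()
_, report, errors = upgrader.process_opened_file(
    "test.py", in_file, "test.py", out_file)
print(out_file.getvalue())  # tf.cast(x, name='test', dtype=tf.int32)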
@@ -169,6 +169,7 @@ filegroup(
         "@local_config_sycl//sycl:LICENSE.text",
         "@nasm//:LICENSE",
         "@nsync//:LICENSE",
+        "@pasta//:LICENSE",
         "@pcre//:LICENCE",
         "@png_archive//:LICENSE",
         "@protobuf_archive//:LICENSE",
@@ -29,6 +29,7 @@ load("//third_party/jpeg:workspace.bzl", jpeg = "repo")
 load("//third_party/nasm:workspace.bzl", nasm = "repo")
 load("//third_party/kissfft:workspace.bzl", kissfft = "repo")
 load("//third_party/keras_applications_archive:workspace.bzl", keras_applications = "repo")
+load("//third_party/pasta:workspace.bzl", pasta = "repo")
 
 def initialize_third_party():
     """ Load third party repositories. See above load() statements. """
@@ -41,6 +42,7 @@ def initialize_third_party():
     kissfft()
     jpeg()
     nasm()
+    pasta()
 
 # Sanitize a dependency so that it works correctly from code that includes
 # TensorFlow as a submodule.
new file: third_party/pasta/BUILD (vendored, 1 line)
@@ -0,0 +1 @@
+# Empty BUILD file to force build system to see this directory at all.
new file: third_party/pasta/BUILD.bazel (vendored, 29 lines)
@@ -0,0 +1,29 @@
+# Description:
+#   AST-based python refactoring.
+
+licenses(["notice"])  # Apache2
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "pasta",
+    srcs = [
+        "__init__.py",
+        "augment/__init__.py",
+        "augment/errors.py",
+        "augment/import_utils.py",
+        "augment/inline.py",
+        "augment/rename.py",
+        "base/__init__.py",
+        "base/annotate.py",
+        "base/ast_constants.py",
+        "base/ast_utils.py",
+        "base/codegen.py",
+        "base/formatting.py",
+        "base/scope.py",
+        "base/test_utils.py",
+        "base/token_generator.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+)
new file: third_party/pasta/BUILD.system (vendored, 13 lines)
@@ -0,0 +1,13 @@
+# Description: Pasta, AST based python refactoring.
+
+licenses(["notice"])  # Apache2
+
+filegroup(
+    name = "LICENSE",
+    visibility = ["//visibility:public"],
+)
+
+py_library(
+    name = "pasta",
+    visibility = ["//visibility:public"],
+)
new file: third_party/pasta/workspace.bzl (vendored, 16 lines)
@@ -0,0 +1,16 @@
+"""Loads pasta python package."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+    third_party_http_archive(
+        name = "pasta",
+        urls = [
+            "https://mirror.bazel.build/github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
+            "https://github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
+        ],
+        strip_prefix = "pasta-c3d72cdee6fc806251949e912510444d58d7413c/pasta",
+        sha256 = "b5905f9cecc4b28363c563f3c4cb0545288bd35f7cc72c55066e97e53befc084",
+        build_file = "//third_party/pasta:BUILD.bazel",
+        system_build_file = "//third_party/pasta:BUILD.system",
+    )
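The archive is pinned to a single pasta commit by URL and sha256, so bumping the dependency means updating both fields together. A hypothetical maintenance snippet (not part of this change) for recomputing the checksum that `third_party_http_archive` expects:

import hashlib

try:
    from urllib.request import urlopen  # Python 3
except ImportError:
    from urllib2 import urlopen  # Python 2

url = ("https://github.com/google/pasta/archive/"
       "c3d72cdee6fc806251949e912510444d58d7413c.tar.gz")
print(hashlib.sha256(urlopen(url).read()).hexdigest())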