Refactor compatibility upgrade tool to use pasta (an AST based refactoring tool) instead of line based edits.

This allows for some simplification (e.g., we can now remove arguments with a dict entry) and for much more powerful transformations.

This passes all tests with no destructive test changes (though there are cosmetic test changes).

Also adds transformations for tf.to_dtype -> tf.cast(..., dtype=...).

PiperOrigin-RevId: 227934801
This commit is contained in:
Martin Wicke 2019-01-04 16:51:20 -08:00 committed by TensorFlower Gardener
parent 75f2e9c266
commit fe66882827
13 changed files with 689 additions and 505 deletions

View File

@ -1,17 +1,21 @@
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"tf_copts", # @unused "tf_copts", # @unused
"tf_cc_test", # @unused "tf_cc_test", # @unused
) )
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
py_library( py_library(
name = "ast_edits", name = "ast_edits",
srcs = ["ast_edits.py"], srcs = ["ast_edits.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [
"@pasta",
"@six_archive//:six",
],
) )
py_test( py_test(

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import ast import ast
import collections
import os import os
import re import re
import shutil import shutil
@ -27,6 +26,9 @@ import sys
import tempfile import tempfile
import traceback import traceback
import pasta
import six
# Some regular expressions we will need for parsing # Some regular expressions we will need for parsing
FIND_OPEN = re.compile(r"^\s*(\[).*$") FIND_OPEN = re.compile(r"^\s*(\[).*$")
FIND_STRING_CHARS = re.compile(r"['\"]") FIND_STRING_CHARS = re.compile(r"['\"]")
@ -44,169 +46,173 @@ class APIChangeSpec(object):
notifications) notifications)
* `function_reorders`: maps functions whose argument order has changed to the * `function_reorders`: maps functions whose argument order has changed to the
list of arguments in the new order list of arguments in the new order
* `function_handle`: maps function names to custom handlers for the function
* `function_warnings`: maps full names of functions to warnings that will be * `function_warnings`: maps full names of functions to warnings that will be
printed out if the function is used. (e.g. tf.nn.convolution()) printed out if the function is used. (e.g. tf.nn.convolution())
* `unrestricted_function_warnings`: maps names of functions to warnings that * `function_transformers`: maps function names to custom handlers
will be printed out when the function is used (e.g. foo.convolution()).
* `function_keyword_additions`: maps function names to a map of arg->value
names that should be passed to the function.
For an example, see `TFAPIChangeSpec`. For an example, see `TFAPIChangeSpec`.
""" """
class _FileEditTuple( class _PastaEditVisitor(ast.NodeVisitor):
collections.namedtuple("_FileEditTuple",
["comment", "line", "start", "old", "new"])):
"""Each edit that is recorded by a _FileEditRecorder.
Fields:
comment: A description of the edit and why it was made.
line: The line number in the file where the edit occurs (1-indexed).
start: The column number in the file where the edit occurs (0-indexed).
old: text string to remove (this must match what was in file).
new: text string to add in place of `old`.
"""
__slots__ = ()
class _FileEditRecorder(object):
"""Record changes that need to be done to the file."""
def __init__(self, filename):
# all edits are lists of chars
self._filename = filename
self._line_to_edit = collections.defaultdict(list)
self._errors = []
def process(self, text):
"""Process a list of strings, each corresponding to the recorded changes.
Args:
text: A list of lines of text (assumed to contain newlines)
Returns:
A tuple of the modified text and a textual description of what is done.
Raises:
ValueError: if substitution source location does not have expected text.
"""
change_report = ""
# Iterate of each line
for line, edits in self._line_to_edit.items():
offset = 0
# sort by column so that edits are processed in order in order to make
# indexing adjustments cumulative for changes that change the string
# length
edits.sort(key=lambda x: x.start)
# Extract each line to a list of characters, because mutable lists
# are editable, unlike immutable strings.
char_array = list(text[line - 1])
# Record a description of the change
change_report += "%r Line %d\n" % (self._filename, line)
change_report += "-" * 80 + "\n\n"
for e in edits:
change_report += "%s\n" % e.comment
change_report += "\n Old: %s" % (text[line - 1])
# Make underscore buffers for underlining where in the line the edit was
change_list = [" "] * len(text[line - 1])
change_list_new = [" "] * len(text[line - 1])
# Iterate for each edit
for e in edits:
# Create effective start, end by accounting for change in length due
# to previous edits
start_eff = e.start + offset
end_eff = start_eff + len(e.old)
# Make sure the edit is changing what it should be changing
old_actual = "".join(char_array[start_eff:end_eff])
if old_actual != e.old:
raise ValueError("Expected text %r but got %r" %
("".join(e.old), "".join(old_actual)))
# Make the edit
char_array[start_eff:end_eff] = list(e.new)
# Create the underline highlighting of the before and after
change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
change_list_new[start_eff:end_eff] = "~" * len(e.new)
# Keep track of how to generate effective ranges
offset += len(e.new) - len(e.old)
# Finish the report comment
change_report += " %s\n" % "".join(change_list)
text[line - 1] = "".join(char_array)
change_report += " New: %s" % (text[line - 1])
change_report += " %s\n\n" % "".join(change_list_new)
return "".join(text), change_report, self._errors
def add(self, comment, line, start, old, new, error=None):
"""Add a new change that is needed.
Args:
comment: A description of what was changed
line: Line number (1 indexed)
start: Column offset (0 indexed)
old: old text
new: new text
error: this "edit" is something that cannot be fixed automatically
Returns:
None
"""
self._line_to_edit[line].append(
_FileEditTuple(comment, line, start, old, new))
if error:
self._errors.append("%s:%d: %s" % (self._filename, line, error))
class _ASTCallVisitor(ast.NodeVisitor):
"""AST Visitor that processes function calls. """AST Visitor that processes function calls.
Updates function calls from old API version to new API version using a given Updates function calls from old API version to new API version using a given
change spec. change spec.
""" """
def __init__(self, filename, lines, api_change_spec): def __init__(self, api_change_spec):
self._filename = filename
self._file_edit = _FileEditRecorder(filename)
self._lines = lines
self._api_change_spec = api_change_spec self._api_change_spec = api_change_spec
self._log = [] # Holds 3-tuples: line, col, msg.
self._errors = [] # Same structure as _log.
self._stack = [] # Allow easy access to parents.
def process(self, lines): # Overridden to maintain a stack of nodes to allow for parent access
return self._file_edit.process(lines) def visit(self, node):
self._stack.append(node)
super(_PastaEditVisitor, self).visit(node)
self._stack.pop()
def generic_visit(self, node): @property
ast.NodeVisitor.generic_visit(self, node) def errors(self):
return self._errors
def _rename_functions(self, node, full_name): @property
symbol_renames = self._api_change_spec.symbol_renames def log(self):
try: return self._log
new_name = symbol_renames[full_name]
self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
node.lineno, node.col_offset, full_name, new_name)
except KeyError:
pass
def _print_warning_for_function(self, node, full_name): def _format_log(self, log):
text = ""
for log_entry in log:
text += "Line %d:%d: %s\n" % log_entry
return text
def log_text(self):
return self._format_log(self.log)
def add_log(self, lineno, col, msg):
self._log.append((lineno, col, msg))
print("Line %d:%d: %s" % (lineno, col, msg))
def add_error(self, lineno, col, msg):
# All errors are also added to the regular log.
self.add_log(lineno, col, msg)
self._errors.append((lineno, col, msg))
def add_logs(self, logs):
"""Record a log and print it.
The log should be a tuple (lineno, col_offset, msg), which will be printed
and then recorded. It is part of the log available in the self.log property.
Args:
logs: The log to add. Must be a tuple (lineno, col_offset, msg).
"""
self._log.extend(logs)
for log in logs:
print("Line %d:%d: %s" % log)
def add_errors(self, errors):
"""Record an error and print it.
The error must be a tuple (lineno, col_offset, msg), which will be printed
and then recorded as both a log and an error. It is therefore part of the
log available in the self.log as well as the self.errors property.
Args:
errors: The log to add. Must be a tuple (lineno, col_offset, msg).
"""
self.add_logs(errors)
self._errors.extend(errors)
def _get_applicable_entries(self, transformer_field, full_name, name):
"""Get all list entries indexed by name that apply to full_name or name."""
# Transformers are indexed to full name, name, or no name
# as a performance optimization.
function_transformers = getattr(self._api_change_spec,
transformer_field, {})
glob_name = "*." + name if name else None
transformers = []
if full_name in function_transformers:
transformers.append(function_transformers[full_name])
if glob_name in function_transformers:
transformers.append(function_transformers[glob_name])
if "*" in function_transformers:
transformers.append(function_transformers["*"])
return transformers
def _get_applicable_dict(self, transformer_field, full_name, name):
"""Get all dict entries indexed by name that apply to full_name or name."""
# Transformers are indexed to full name, name, or no name
# as a performance optimization.
function_transformers = getattr(self._api_change_spec,
transformer_field, {})
glob_name = "*." + name if name else None
transformers = function_transformers.get("*", {}).copy()
transformers.update(function_transformers.get(glob_name, {}))
transformers.update(function_transformers.get(full_name, {}))
return transformers
def _get_full_name(self, node):
"""Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".
This is the inverse of _full_name_node.
Args:
node: A Node of type Attribute.
Returns:
a '.'-delimited full-name or None if node was not Attribute or Name.
i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
"""
curr = node
items = []
while not isinstance(curr, ast.Name):
if not isinstance(curr, ast.Attribute):
return None
items.append(curr.attr)
curr = curr.value
items.append(curr.id)
return ".".join(reversed(items))
def _full_name_node(self, name, ctx=ast.Load()):
"""Make an Attribute or Name node for name.
Translate a qualified name into nested Attribute nodes (and a Name node).
Args:
name: The name to translate to a node.
ctx: What context this name is used in. Defaults to Load()
Returns:
A Name or Attribute node.
"""
names = name.split(".")
names.reverse()
node = ast.Name(id=names.pop(), ctx=ast.Load())
while names:
node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load())
# Change outermost ctx to the one given to us (inner ones should be Load).
node.ctx = ctx
return node
def _maybe_add_warning(self, node, full_name):
"""Adds an error to be printed about full_name at node."""
function_warnings = self._api_change_spec.function_warnings function_warnings = self._api_change_spec.function_warnings
try: if full_name in function_warnings:
warning_message = function_warnings[full_name] warning_message = function_warnings[full_name]
warning_message = warning_message.replace("<function name>", full_name) warning_message = warning_message.replace("<function name>", full_name)
self._file_edit.add(warning_message, self.add_error(node.lineno, node.col_offset,
node.lineno, node.col_offset, full_name, full_name, "%s requires manual check: %s." % (full_name,
error="%s requires manual check." % full_name) warning_message))
except KeyError: return True
pass else:
return False
def _print_warning_for_function_unrestricted(self, node): def _maybe_add_call_warning(self, node, full_name, name):
"""Print a warning when specific functions are called. """Print a warning when specific functions are called.
The function _print_warning_for_function matches the full name of the called The function _print_warning_for_function matches the full name of the called
@ -216,92 +222,118 @@ class _ASTCallVisitor(ast.NodeVisitor):
Args: Args:
node: ast.Call object node: ast.Call object
full_name: The precomputed full name of the callable, if one exists, None
otherwise.
name: The precomputed name of the callable, if one exists, None otherwise.
Returns:
Whether an error was recorded.
""" """
function_warnings = getattr( # Only look for *.-warnings here, the other will be handled by the Attribute
self._api_change_spec, "unrestricted_function_warnings", {}) # visitor. Also, do not warn for bare functions, only if the call func is
# an attribute.
warned = False
if isinstance(node.func, ast.Attribute): if isinstance(node.func, ast.Attribute):
function_name = node.func.attr warned = self._maybe_add_warning(node, "*." + name)
try:
warning_message = function_warnings[function_name]
self._file_edit.add(warning_message,
node.lineno, node.col_offset, "", "",
error="%s requires manual check." % function_name)
except KeyError:
pass
def _get_attribute_full_path(self, node): # All arg warnings are handled here, since only we have the args
"""Traverse an attribute to generate a full name e.g. tf.foo.bar. arg_warnings = self._get_applicable_dict("function_arg_warnings",
full_name, name)
Args: used_args = [kw.arg for kw in node.keywords]
node: A Node of type Attribute. for arg, warning in arg_warnings.items():
if arg in used_args:
warned = True
warning_message = warning.replace("<function name>", full_name or name)
self.add_error(node.lineno, node.col_offset,
"%s called with %s argument requires manual check: %s." %
(full_name or name, arg, warning_message))
Returns: return warned
a '.'-delimited full-name or None if the tree was not a simple form.
i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
"""
curr = node
items = []
while not isinstance(curr, ast.Name):
if not isinstance(curr, ast.Attribute):
return None, None
items.append(curr.attr)
curr = curr.value
items.append(curr.id)
return ".".join(reversed(items)), items[0]
def _find_true_position(self, node): def _maybe_rename(self, parent, node, full_name):
"""Return correct line number and column offset for a given node. """Replace node (Attribute or Name) with a node representing full_name."""
new_name = self._api_change_spec.symbol_renames.get(full_name, None)
if new_name:
self.add_log(node.lineno, node.col_offset,
"Renamed %r to %r" % (full_name, new_name))
new_node = self._full_name_node(new_name, node.ctx)
ast.copy_location(new_node, node)
pasta.ast_utils.replace_child(parent, node, new_node)
return True
else:
return False
This is necessary mainly because ListComp's location reporting reports def _maybe_change_to_function_call(self, parent, node, full_name):
the next token after the list comprehension list opening. """Wraps node (typically, an Attribute or Expr) in a Call."""
if full_name in self._api_change_spec.change_to_function:
Returns: if not isinstance(parent, ast.Call):
lineno, offset for the given node # ast.Call's constructor is really picky about how many arguments it
# wants, and also, it changed between Py2 and Py3.
Args: if six.PY2:
node: Node for which we wish to know the lineno and col_offset new_node = ast.Call(node, [], [], None, None)
"""
if isinstance(node, ast.ListComp):
# Strangely, ast.ListComp returns the col_offset of the first token
# after the '[' token which appears to be a bug. Workaround by
# explicitly finding the real start of the list comprehension.
line = node.lineno
col = node.col_offset
# loop over lines
while 1:
# Reverse the text to and regular expression search for whitespace
text = self._lines[line - 1]
reversed_preceding_text = text[:col][::-1]
# First find if a [ can be found with only whitespace between it and
# col.
m = FIND_OPEN.match(reversed_preceding_text)
if m:
new_col_offset = col - m.start(1) - 1
return line, new_col_offset
else: else:
if (reversed_preceding_text == "" or new_node = ast.Call(node, [], [])
reversed_preceding_text.isspace()): pasta.ast_utils.replace_child(parent, node, new_node)
line = line - 1 ast.copy_location(new_node, node)
prev_line = self._lines[line - 1] self.add_log(node.lineno, node.col_offset,
# TODO(aselle): "Changed %r to a function call" % full_name)
# this is poor comment detection, but it is good enough for return True
# cases where the comment does not contain string literal starting/ return False
# ending characters. If ast gave us start and end locations of the
# ast nodes rather than just start, we could use string literal def _maybe_add_arg_names(self, node, full_name):
# node ranges to filter out spurious #'s that appear in string """Make args into keyword args if function called full_name requires it."""
# literals. function_reorders = self._api_change_spec.function_reorders
comment_start = prev_line.find("#")
if comment_start == -1: if full_name in function_reorders:
col = len(prev_line) - 1 reordered = function_reorders[full_name]
elif FIND_STRING_CHARS.search(prev_line[comment_start:]) is None: new_keywords = []
col = comment_start for idx, arg in enumerate(node.args):
else: keyword_arg = reordered[idx]
return None, None new_keywords.append(ast.keyword(arg=keyword_arg, value=arg))
else:
return None, None if new_keywords:
# Most other nodes return proper locations (with notably does not), but self.add_log(node.lineno, node.col_offset,
# it is not possible to use that in an argument. "Added keywords to args of function %r" % full_name)
return node.lineno, node.col_offset node.args = []
node.keywords = new_keywords + (node.keywords or [])
return True
return False
def _maybe_modify_args(self, node, full_name, name):
"""Rename keyword args if the function called full_name requires it."""
renamed_keywords = self._get_applicable_dict("function_keyword_renames",
full_name, name)
if not renamed_keywords:
return False
modified = False
new_keywords = []
for keyword in node.keywords:
argkey = keyword.arg
if argkey in renamed_keywords:
modified = True
if renamed_keywords[argkey] is None:
lineno = getattr(keyword, "lineno", node.lineno)
col_offset = getattr(keyword, "col_offset", node.col_offset)
self.add_log(lineno, col_offset,
"Removed argument %s for function %s" % (
argkey, full_name or name))
else:
keyword.arg = renamed_keywords[argkey]
lineno = getattr(keyword, "lineno", node.lineno)
col_offset = getattr(keyword, "col_offset", node.col_offset)
self.add_log(lineno, col_offset,
"Renamed keyword argument for %s from %s to %s" % (
full_name, argkey, renamed_keywords[argkey]))
new_keywords.append(keyword)
else:
new_keywords.append(keyword)
if modified:
node.keywords = new_keywords
return modified
def visit_Call(self, node): # pylint: disable=invalid-name def visit_Call(self, node): # pylint: disable=invalid-name
"""Handle visiting a call node in the AST. """Handle visiting a call node in the AST.
@ -309,104 +341,74 @@ class _ASTCallVisitor(ast.NodeVisitor):
Args: Args:
node: Current Node node: Current Node
""" """
self._print_warning_for_function_unrestricted(node) assert self._stack[-1] is node
# Find a simple attribute name path e.g. "tf.foo.bar"
full_name, name = self._get_attribute_full_path(node.func)
# Make sure the func is marked as being part of a call
node.func.is_function_for_call = True
# Get the name for this call, so we can index stuff with it.
full_name = self._get_full_name(node.func)
if full_name: if full_name:
# Call special handlers name = full_name.split(".")[-1]
function_handles = self._api_change_spec.function_handle elif isinstance(node.func, ast.Name):
glob_name = "*.{}".format(name) name = node.func.id
if glob_name in function_handles: elif isinstance(node.func, ast.Attribute):
function_handles[glob_name](self._file_edit, node, self._lines) name = node.func.attr
if full_name in function_handles: else:
function_handles[full_name](self._file_edit, node, self._lines) name = None
# Examine any non-keyword argument and make it into a keyword argument # Call standard transformers for this node.
# if reordering required. # Make sure warnings come first, since args or names triggering warnings
function_reorders = self._api_change_spec.function_reorders # may be removed by the other transformations.
function_keyword_renames = ( self._maybe_add_call_warning(node, full_name, name)
self._api_change_spec.function_keyword_renames) # Make all args into kwargs
self._maybe_add_arg_names(node, full_name)
# Argument name changes or deletions
self._maybe_modify_args(node, full_name, name)
if full_name in function_reorders: # Call transformers. These have the ability to modify the node, and if they
reordered = function_reorders[full_name] # do, will return the new node they created (or the same node if they just
for idx, arg in enumerate(node.args): # changed it). The are given the parent, but we will take care of
lineno, col_offset = self._find_true_position(arg) # integrating their changes into the parent if they return a new node.
if lineno is None or col_offset is None: #
self._file_edit.add( # These are matched on the old name, since renaming is performed by the
"Failed to add keyword %r to reordered function %r" % # Attribute visitor, which happens later.
(reordered[idx], full_name), transformers = self._get_applicable_entries("function_transformers",
arg.lineno, full_name, name)
arg.col_offset,
"",
"",
error="A necessary keyword argument failed to be inserted.")
else:
keyword_arg = reordered[idx]
if (full_name in function_keyword_renames and
keyword_arg in function_keyword_renames[full_name]):
keyword_arg = function_keyword_renames[full_name][keyword_arg]
self._file_edit.add("Added keyword %r to reordered function %r" %
(reordered[idx], full_name), lineno, col_offset,
"", keyword_arg + "=")
# Examine each keyword argument and convert it to the final renamed form parent = self._stack[-2]
renamed_keywords = ({} if full_name not in function_keyword_renames else
function_keyword_renames[full_name])
for keyword in node.keywords:
argkey = keyword.arg
argval = keyword.value
if argkey in renamed_keywords: for transformer in transformers:
argval_lineno, argval_col_offset = self._find_true_position(argval) logs = []
if argval_lineno is not None and argval_col_offset is not None: errors = []
# TODO(aselle): We should scan backward to find the start of the new_node = transformer(parent, node, full_name, name, logs, errors)
# keyword key. Unfortunately ast does not give you the location of self.add_logs(logs)
# keyword keys, so we are forced to infer it from the keyword arg self.add_errors(errors)
# value. if new_node:
key_start = argval_col_offset - len(argkey) - 1 if new_node is not node:
key_end = key_start + len(argkey) + 1 pasta.ast_utils.replace_child(parent, node, new_node)
if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + node = new_node
"="): self._stack[-1] = node
self._file_edit.add("Renamed keyword argument from %r to %r" %
(argkey,
renamed_keywords[argkey]), argval_lineno,
argval_col_offset - len(argkey) - 1,
argkey + "=", renamed_keywords[argkey] + "=")
continue
self._file_edit.add(
"Failed to rename keyword argument from %r to %r" %
(argkey, renamed_keywords[argkey]),
argval.lineno,
argval.col_offset - len(argkey) - 1,
"",
"",
error="Failed to find keyword lexographically. Fix manually.")
ast.NodeVisitor.generic_visit(self, node) self.generic_visit(node)
def visit_Attribute(self, node): # pylint: disable=invalid-name def visit_Attribute(self, node): # pylint: disable=invalid-name
"""Handle bare Attributes i.e. [tf.foo, tf.bar]. """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
assert self._stack[-1] is node
Args: full_name = self._get_full_name(node)
node: Node that is of type ast.Attribute
"""
full_name, _ = self._get_attribute_full_path(node)
if full_name: if full_name:
# Make sure the warning comes first, otherwise the name may have changed parent = self._stack[-2]
self._print_warning_for_function(node, full_name)
self._rename_functions(node, full_name)
if full_name in self._api_change_spec.change_to_function:
if not hasattr(node, "is_function_for_call"):
new_text = full_name + "()"
self._file_edit.add("Changed %r to %r" % (full_name, new_text),
node.lineno, node.col_offset, full_name, new_text)
ast.NodeVisitor.generic_visit(self, node) # Make sure the warning comes first, otherwise the name may have changed
self._maybe_add_warning(node, full_name)
# Once we did a modification, node is invalid and not worth inspecting
# further. Also, we only perform modifications for simple nodes, so
# There'd be no point in descending further.
if self._maybe_rename(parent, node, full_name):
return
if self._maybe_change_to_function_call(parent, node, full_name):
return
self.generic_visit(node)
class ASTCodeUpgrader(object): class ASTCodeUpgrader(object):
@ -429,16 +431,42 @@ class ASTCodeUpgrader(object):
""" """
# Write to a temporary file, just in case we are doing an implace modify. # Write to a temporary file, just in case we are doing an implace modify.
# pylint: disable=g-backslash-continuation
with open(in_filename, "r") as in_file, \ with open(in_filename, "r") as in_file, \
tempfile.NamedTemporaryFile("w", delete=False) as temp_file: tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
ret = self.process_opened_file(in_filename, in_file, out_filename, ret = self.process_opened_file(in_filename, in_file, out_filename,
temp_file) temp_file)
# pylint: enable=g-backslash-continuation
shutil.move(temp_file.name, out_filename) shutil.move(temp_file.name, out_filename)
return ret return ret
# Broad exceptions are required here because ast throws whatever it wants. def _format_errors(self, errors, in_filename):
# pylint: disable=broad-except return ["%s:%d:%d: %s" % ((in_filename,) + error) for error in errors]
def update_string_pasta(self, text, in_filename):
"""Updates a file using pasta."""
try:
t = pasta.parse(text)
except (SyntaxError, ValueError, TypeError):
log = "Failed to parse.\n\n" + traceback.format_exc()
return 0, "", log, []
visitor = _PastaEditVisitor(self._api_change_spec)
visitor.visit(t)
errors = self._format_errors(visitor.errors, in_filename)
return 1, pasta.dump(t), visitor.log_text(), errors
def _format_log(self, log, in_filename, out_filename):
text = "-" * 80 + "\n"
text += "Processing file %r\n outputting to %r\n" % (in_filename,
out_filename)
text += "-" * 80 + "\n\n"
text += log
text += "-" * 80 + "\n\n"
return text
def process_opened_file(self, in_filename, in_file, out_filename, out_file): def process_opened_file(self, in_filename, in_file, out_filename, out_file):
"""Process the given python file for incompatible changes. """Process the given python file for incompatible changes.
@ -453,30 +481,16 @@ class ASTCodeUpgrader(object):
Returns: Returns:
A tuple representing number of files processed, log of actions, errors A tuple representing number of files processed, log of actions, errors
""" """
process_errors = []
text = "-" * 80 + "\n"
text += "Processing file %r\n outputting to %r\n" % (in_filename,
out_filename)
text += "-" * 80 + "\n\n"
parsed_ast = None
lines = in_file.readlines() lines = in_file.readlines()
try: processed_file, new_file_content, log, process_errors = (
parsed_ast = ast.parse("".join(lines)) self.update_string_pasta("".join(lines), in_filename))
except Exception:
text += "Failed to parse %r\n\n" % in_filename
text += traceback.format_exc()
if parsed_ast:
visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
visitor.visit(parsed_ast)
out_text, new_text, process_errors = visitor.process(lines)
text += new_text
if out_file:
out_file.write(out_text)
text += "\n"
return 1, text, process_errors
# pylint: enable=broad-except if out_file and processed_file:
out_file.write(new_file_content)
return (processed_file,
self._format_log(log, in_filename, out_filename),
process_errors)
def process_tree(self, root_directory, output_root_directory, def process_tree(self, root_directory, output_root_directory,
copy_other_files): copy_other_files):

View File

@ -39,6 +39,8 @@ following new APIs:
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import ast
import pasta
import six import six
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib from tensorflow.python.platform import test as test_lib
@ -54,7 +56,6 @@ class NoUpdateSpec(ast_edits.APIChangeSpec):
self.function_keyword_renames = {} self.function_keyword_renames = {}
self.symbol_renames = {} self.symbol_renames = {}
self.function_warnings = {} self.function_warnings = {}
self.unrestricted_function_warnings = {}
self.change_to_function = {} self.change_to_function = {}
@ -401,7 +402,8 @@ class TestAstEdits(test_util.TensorFlowTestCase):
def __init__(self): def __init__(self):
NoUpdateSpec.__init__(self) NoUpdateSpec.__init__(self)
self.unrestricted_function_warnings = {"foo": "not good"} self.function_warnings = {"*.foo": "not good"}
texts = ["object.foo()", "get_object().foo()", texts = ["object.foo()", "get_object().foo()",
"get_object().foo()", "object.foo().bar()"] "get_object().foo()", "object.foo().bar()"]
for text in texts: for text in texts:
@ -416,5 +418,26 @@ class TestAstEdits(test_util.TensorFlowTestCase):
self.assertNotIn("not good", report) self.assertNotIn("not good", report)
class ManualEditsTest(test_util.TensorFlowTestCase):
def disabled_test_arg_order(self):
"""Tests that generated arg order is sane."""
text = "f(a)"
t = pasta.parse(text)
node = pasta.ast_utils.find_nodes_by_type(t, (ast.Call,))[0]
arg = ast.keyword(arg="b", value=ast.Num(n=0))
node.keywords.append(arg)
# This is only needed in Python3, and I think it's a bug (but maybe in ast).
arg.value.lineno = 0
arg.value.col_offset = 0
# pasta.dump should never put kwargs before args, even if the col_offset is
# messed up.
# This fails if run with python3, but works find for python2.
# In python3, the dump yields "f(b=0, a)".
self.assertEqual(pasta.dump(t), "f(a, b=0)")
if __name__ == "__main__": if __name__ == "__main__":
test_lib.main() test_lib.main()

View File

@ -175,27 +175,13 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.op_scope": ["values", "name", "default_name"], "tf.op_scope": ["values", "name", "default_name"],
} }
# Specially handled functions.
self.function_handle = {"tf.reverse": self._reverse_handler}
# Warnings that should be printed if corresponding functions are used. # Warnings that should be printed if corresponding functions are used.
self.function_warnings = {} self.function_warnings = {
"tf.reverse":
@staticmethod "ERROR: tf.reverse has had its argument semantics changed "
def _reverse_handler(file_edit_recorder, node, lines): "significantly. The converter cannot detect this reliably, so "
del lines "you need to inspect this usage manually.\n",
# TODO(aselle): Could check for a literal list of bools and try to convert }
# them to indices.
comment = ("ERROR: tf.reverse has had its argument semantics changed "
"significantly the converter cannot detect this reliably, so "
"you need to inspect this usage manually.\n")
file_edit_recorder.add(
comment,
node.lineno,
node.col_offset,
"tf.reverse",
"tf.reverse",
error="tf.reverse requires manual check.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -112,7 +112,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
text = "tf.reverse(a, b)\n" text = "tf.reverse(a, b)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, new_text) self.assertEqual(new_text, new_text)
self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."]) self.assertIn("tf.reverse requires manual check", errors[0])
def testListComprehension(self): def testListComprehension(self):
def _test(input, output): # pylint: disable=redefined-builtin def _test(input, output): # pylint: disable=redefined-builtin

View File

@ -18,7 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import re import ast
import pasta
from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import renames_v2 from tensorflow.tools.compatibility import renames_v2
@ -31,7 +33,22 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
def __init__(self): def __init__(self):
# Maps from a function name to a dictionary that describes how to # Maps from a function name to a dictionary that describes how to
# map from an old argument keyword to the new argument keyword. # map from an old argument keyword to the new argument keyword.
# If the new argument is None, it will be removed.
# Only keyword args are handled, so make sure to also put any function in
# function_reorders to ensure that all args are made into keywords first.
self.function_keyword_renames = { self.function_keyword_renames = {
"tf.gradients": {
"colocate_gradients_with_ops": None,
},
"tf.hessians": {
"colocate_gradients_with_ops": None,
},
"*.minimize": {
"colocate_gradients_with_ops": None,
},
"*.compute_gradients": {
"colocate_gradients_with_ops": None,
},
"tf.argmin": { "tf.argmin": {
"dimension": "axis", "dimension": "axis",
}, },
@ -672,14 +689,31 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# positional arguments yourself, this could do the wrong thing. # positional arguments yourself, this could do the wrong thing.
self.function_reorders = reorders_v2.reorders self.function_reorders = reorders_v2.reorders
# Specially handled functions. # Specially handled functions (pasta version)
self.function_handle = { # Each transformer is a callable which will be called with the arguments
"tf.batch_gather": self._batch_gather_handler, # transformer(parent, node, full_name, name, logs, errors)
"tf.nn.dropout": self._dropout_handler, # Where logs and errors are lists to which (line, col, msg) tuples can be
"tf.gradients": self._colocate_handler("tf.gradients"), # appended, full_name is the FQN of the function called (or None if that is
"*.minimize": self._colocate_handler("Optimizer.minimize"), # unknown), name is the name of the function called (or None is that is
"*.compute_gradients": # unknown). node is an ast.Call node representing this function call, and
self._colocate_handler("Optimizer.compute_gradients"), # parent is its parent in the AST.
# The function may modify node (but not parent), and must return
# - none, if nothing was modified
# - node, if node was modified in place (make sure to use
# pasta.ast_utils.replace_child to swap out children, otherwise formatting
# may get messy)
# - a replacement for node, if the whole call node was replaced. The caller
# will take care of changing parent.
self.function_transformers = {
"tf.nn.dropout": self._dropout_transformer,
"tf.batch_gather": self._batch_gather_transformer,
"tf.to_bfloat16": self._cast_transformer,
"tf.to_complex128": self._cast_transformer,
"tf.to_complex64": self._cast_transformer,
"tf.to_double": self._cast_transformer,
"tf.to_float": self._cast_transformer,
"tf.to_int32": self._cast_transformer,
"tf.to_int64": self._cast_transformer,
} }
decay_function_comment = ( decay_function_comment = (
@ -748,9 +782,45 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"compat.v1 for backward compatibility. Please update these calls to " "compat.v1 for backward compatibility. Please update these calls to "
"the TF 2.0 versions.") "the TF 2.0 versions.")
export_saved_model_renamed = (
"(Manual edit required) Please rename the method export_savedmodel() "
"to export_saved_model(). Two things to note:\n\t(1) The argument "
"strip_default_attributes has been removed. The function will always "
"strip the default attributes from ops. If this breaks your code, "
"please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
"only effects core estimator. If you are using "
"tf.contrib.learn.Estimator, please switch to using core estimator.")
make_initializable_iterator_deprecation = (
"(Manual edit required) The "
"`tf.data.Dataset.make_initializable_iterator()` method has been "
"removed. If you are using the Estimator API, you can return a dataset "
"directly from your input functions without creating an iterator. "
"As a last resort, please replace calls to that method on `dataset` "
"with a call to "
"`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
make_one_shot_iterator_deprecation = (
"(Manual edit required) The "
"`tf.data.Dataset.make_one_shot_iterator()` method has been "
"removed. If you are using eager execution, you can iterate over "
"`dataset` using a Python `for` loop. If you are using the Estimator "
"API, you can return a dataset directly from your input functions "
"without creating an iterator. As a last resort, please replace calls "
"to that method on `dataset` with a call to "
"`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
# Function warnings. <function name> placeholder inside warnings will be # Function warnings. <function name> placeholder inside warnings will be
# replaced by function name. # replaced by function name.
# You can use *. to add items which do not check the FQN, and apply to e.g.,
# methods.
self.function_warnings = { self.function_warnings = {
"*.export_savedmodel":
export_saved_model_renamed,
"*.make_initializable_iterator":
make_initializable_iterator_deprecation,
"*.make_one_shot_iterator":
make_one_shot_iterator_deprecation,
"tf.assert_greater": "tf.assert_greater":
assert_return_type_comment, assert_return_type_comment,
"tf.assert_equal": "tf.assert_equal":
@ -834,11 +904,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
default_loss_reduction_changed, default_loss_reduction_changed,
"tf.estimator.BaselineRegressor": "tf.estimator.BaselineRegressor":
default_loss_reduction_changed, default_loss_reduction_changed,
"tf.hessians":
"tf.hessians no longer takes "
"'colocate_gradients_with_ops' argument. Also, "
"arguments have been reordered so that 'name' is the "
"last argument.",
"tf.nn.conv1d": "tf.nn.conv1d":
"WARNING: use_cudnn_on_gpu argument has been removed and \"value\"" "WARNING: use_cudnn_on_gpu argument has been removed and \"value\""
" was renamed to \"input\"", " was renamed to \"input\"",
@ -1064,111 +1129,118 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
metrics_comment, metrics_comment,
} }
# Warnings that are emitted only if a specific arg is found.
self.function_arg_warnings = {
"tf.gradients": {
"colocate_gradients_with_ops":
"tf.gradients no longer takes "
"'colocate_gradients_with_ops' argument, it behaves as if it "
"was set to True.",
},
"*.minimize": {
"colocate_gradients_with_ops":
"Optimizer.minimize no longer takes "
"'colocate_gradients_with_ops' argument, it behaves as if it "
"was set to True.",
},
"*.compute_gradients": {
"colocate_gradients_with_ops":
"Optimizer.compute_gradients no "
"longer takes 'colocate_gradients_with_ops' argument, it "
"behaves as if it was set to True.",
},
}
self.symbol_renames = { self.symbol_renames = {
name: new_name name: new_name
for name, new_name in self.symbol_renames.items() for name, new_name in self.symbol_renames.items()
} }
export_saved_model_renamed = (
"(Manual edit required) Please rename the method export_savedmodel() "
"to export_saved_model(). Two things to note:\n\t(1) The argument "
"strip_default_attributes has been removed. The function will always "
"strip the default attributes from ops. If this breaks your code, "
"please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
"only effects core estimator. If you are using "
"tf.contrib.learn.Estimator, please switch to using core estimator.")
make_initializable_iterator_deprecation = (
"(Manual edit required) The "
"`tf.data.Dataset.make_initializable_iterator()` method has been "
"removed. If you are using the Estimator API, you can return a dataset "
"directly from your input functions without creating an iterator. "
"As a last resort, please replace calls to that method on `dataset` "
"with a call to "
"`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
make_one_shot_iterator_deprecation = (
"(Manual edit required) The "
"`tf.data.Dataset.make_one_shot_iterator()` method has been "
"removed. If you are using eager execution, you can iterate over "
"`dataset` using a Python `for` loop. If you are using the Estimator "
"API, you can return a dataset directly from your input functions "
"without creating an iterator. As a last resort, please replace calls "
"to that method on `dataset` with a call to "
"`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
# Specify warnings for functions that aren't restricted to the tf.x.y.z
# format. This should only be used for methods with unique names, e.g.
# export_savedmodel, which is only defined in Estimator objects.
self.unrestricted_function_warnings = {
"export_savedmodel": export_saved_model_renamed,
"make_initializable_iterator": make_initializable_iterator_deprecation,
"make_one_shot_iterator": make_one_shot_iterator_deprecation,
}
@staticmethod @staticmethod
def _dropout_handler(file_edit_recorder, node, lines): def _dropout_transformer(parent, node, full_name, name, logs, errors):
del lines def _replace_keep_prob_node(parent, old_value):
"""Replaces old_value with 1-(old_value)."""
one = ast.Num(n=1)
one.lineno = 0
one.col_offset = 0
new_value = ast.BinOp(left=one, op=ast.Sub(),
right=old_value)
# This copies the prefix and suffix on old_value to new_value.
pasta.ast_utils.replace_child(parent, old_value, new_value)
ast.copy_location(new_value, old_value)
# Put parentheses around keep_prob.value (and remove the old prefix/
# suffix, they should only be around new_value).
pasta.base.formatting.set(old_value, "prefix", "(")
pasta.base.formatting.set(old_value, "suffix", ")")
# Check if we have a keep_prob keyword arg
for keep_prob in node.keywords:
if keep_prob.arg == "keep_prob":
logs.append((node.lineno, node.col_offset,
"Changing keep_prob arg of tf.nn.dropout to rate, and "
"recomputing value. Please check this transformation.\n"))
keep_prob.arg = "rate"
_replace_keep_prob_node(keep_prob, keep_prob.value)
return node
# Maybe it was a positional arg
if len(node.args) < 2: if len(node.args) < 2:
comment = ("ERROR: tf.nn.dropout did not take arguments, so automatic " errors.append((node.lineno, node.col_offset,
"transformation was disabled. tf.nn.dropout has changed " "ERROR: tf.nn.dropout called without arguments, so "
"the semantics of the second argument.") "automatic fix was disabled. tf.nn.dropout has changed "
file_edit_recorder.add( "the semantics of the second argument."))
comment,
node.lineno,
node.col_offset,
"tf.nn.dropout",
"tf.nn.dropout",
error="tf.nn.dropout requires manual check.")
else: else:
comment = ("WARNING: tf.nn.dropout has changed the semantics of the " _replace_keep_prob_node(node, node.args[1])
"second argument. Please check the transformation.\n") logs.append((node.lineno, node.col_offset,
file_edit_recorder.add( "Changing keep_prob arg of tf.nn.dropout to rate, and "
comment, "recomputing value.\n"))
node.args[1].lineno, errors.append((node.lineno, node.col_offset,
node.args[1].col_offset, "WARNING: tf.nn.dropout has changed the semantics of the "
"", "second argument. Please check the applied transformation."
"1 - ") ))
return node
@staticmethod @staticmethod
def _colocate_handler(name): def _cast_transformer(parent, node, full_name, name, logs, errors):
def _helper(file_edit_recorder, node, lines): """Transforms to_int and to_float to cast(..., dtype=...)."""
"""Handler for updating colocate arguments."""
del lines # Find out the dtype to cast to from the function name
for keyword in node.keywords: dtype_str = name[3:]
if keyword.arg == "colocate_gradients_with_ops": new_arg = ast.keyword(arg="dtype",
# TODO(jhseu): Since ast_edit.py does string replacement, there's no value=ast.Attribute(value=ast.Name(id="tf",
# straightforward way to remove the argument. Try to fix before 2.0 is ctx=ast.Load()),
# final. attr=dtype_str, ctx=ast.Load()))
comment = ("For tf.gradients and tf.Optimizer.minimize, "
"colocate_gradients_with_op has been removed and now " # Python3 ast requires the args for the Attribute, but codegen will mess up
"defaults to True.") # the arg order if we just set them to 0.
file_edit_recorder.add( new_arg.value.lineno = node.lineno
comment, new_arg.value.col_offset = node.col_offset+100
node.lineno,
node.col_offset, node.keywords.append(new_arg)
"", if isinstance(node.func, ast.Attribute):
"", node.func.attr = "cast"
error="{} requires manual check.".format(name)) else:
return _helper assert isinstance(node.func, ast.Name)
node.func.id = "cast"
logs.append((node.lineno, node.col_offset,
"Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name,
dtype_str)))
return node
@staticmethod @staticmethod
def _batch_gather_handler(file_edit_recorder, node, lines): def _batch_gather_transformer(parent, node, full_name, name, logs, errors):
lineno = node.lineno # Check if the call already has a batch_dims argument
column = node.col_offset if any([kw.arg == "batch_dims" for kw in node.keywords]):
logs.append((node.lineno, node.col_offset, "tf.batch_gather already has "
"batch_dims argument. Neat."))
return None
# Find the position to add the batch_dims argument. We add it as the minus_one = ast.Num(n=-1)
# first argument, since that's easiest. This is safe because we included minus_one.lineno = 0
# batch_gather in self.reordered_function_names, so it will have all minus_one.col_offset = 0
# of its arguments changed to keyword arguments. new_arg = ast.keyword("batch_dims", minus_one)
m = re.match(r"tf\s*\.\s*batch_gather\s*\(", lines[lineno - 1][column:]) node.keywords.append(new_arg)
if m is not None: logs.append((node.lineno, node.col_offset,
file_edit_recorder.add( "Added keyword argument batch_dims=-1 to tf.batch_gather."))
"Added keyword argument 'batch_dims=-1' to 'tf.batch_gather'", return node
lineno, column + m.end(), "", "batch_dims=-1, ")
else:
file_edit_recorder.add(
"Unable to add keyword argument 'batch_dims=-1' to 'tf.batch_gather'",
lineno, column, "", "",
error="Unable to add keyword argument batch_dims=-1 to "
"tf.batch_gather; please add it manually.")

View File

@ -239,8 +239,8 @@ class TestUpgrade(test_util.TensorFlowTestCase):
} }
function_warnings = ( function_warnings = (
tf_upgrade_v2.TFAPIChangeSpec().function_warnings) tf_upgrade_v2.TFAPIChangeSpec().function_warnings)
function_handles = ( function_transformers = (
tf_upgrade_v2.TFAPIChangeSpec().function_handle) tf_upgrade_v2.TFAPIChangeSpec().function_transformers)
keyword_renames = ( keyword_renames = (
tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames) tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
@ -255,7 +255,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
for name in names_v1: for name in names_v1:
tf_name = "tf.%s" % name tf_name = "tf.%s" % name
if tf_name in function_warnings or tf_name in function_handles: if tf_name in function_warnings or tf_name in function_transformers:
continue # These require manual change continue # These require manual change
if tf_name in v1_name_exceptions: if tf_name in v1_name_exceptions:
continue continue
@ -362,15 +362,14 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = "%s(a, b)\n" % decay text = "%s(a, b)\n" % decay
_, report, errors, _ = self._upgrade(text) _, report, errors, _ = self._upgrade(text)
self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay]) self.assertIn("%s requires manual check" % decay, errors[0])
self.assertIn("%s has been changed" % decay, report) self.assertIn("%s has been changed" % decay, report)
def testPiecewiseDecay(self): def testPiecewiseDecay(self):
text = "tf.train.piecewise_constant_decay(a, b)\n" text = "tf.train.piecewise_constant_decay(a, b)\n"
_, report, errors, _ = self._upgrade(text) _, report, errors, _ = self._upgrade(text)
self.assertEqual( self.assertIn("tf.train.piecewise_constant_decay requires manual check",
errors, errors[0])
["test.py:1: tf.train.piecewise_constant_decay requires manual check."])
self.assertIn("tf.train.piecewise_constant_decay has been changed", report) self.assertIn("tf.train.piecewise_constant_decay has been changed", report)
def testMetrics(self): def testMetrics(self):
@ -414,7 +413,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = ns + "(a, b)" text = ns + "(a, b)"
_, report, errors, new_text = self._upgrade(text) _, report, errors, new_text = self._upgrade(text)
self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text) self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text)
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns]) self.assertIn("test.py:1:0: %s requires manual check" % ns, errors[0])
self.assertIn( self.assertIn(
"WARNING: tf.metrics have been converted to object oriented" "WARNING: tf.metrics have been converted to object oriented"
" versions in TF 2.0 and after. The metric function calls have been " " versions in TF 2.0 and after. The metric function calls have been "
@ -445,7 +444,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = ns + "(a, b)" text = ns + "(a, b)"
_, report, errors, new_text = self._upgrade(text) _, report, errors, new_text = self._upgrade(text)
self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text) self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text)
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns]) self.assertIn("test.py:1:0: %s requires manual check" % ns, errors[0])
self.assertIn( self.assertIn(
"WARNING: tf.losses have been converted to object oriented" "WARNING: tf.losses have been converted to object oriented"
" versions in TF 2.0 and after. The loss function calls have been " " versions in TF 2.0 and after. The loss function calls have been "
@ -463,7 +462,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = ns + "(a, b)" text = ns + "(a, b)"
_, report, errors, new_text = self._upgrade(text) _, report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text) self.assertEqual(text, new_text)
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns]) self.assertIn("%s requires manual check" % ns, errors[0])
self.assertIn("loss_reduction has been changed", report) self.assertIn("loss_reduction has been changed", report)
def testDropout(self): def testDropout(self):
@ -471,15 +470,40 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
_, unused_report, unused_errors, new_text = self._upgrade(text) _, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual( self.assertEqual(
new_text, new_text,
"tf.nn.dropout(x, 1 - keep_prob, name=\"foo\")\n", "tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n",
)
text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(
new_text,
"tf.nn.dropout(x, rate=1 - (.4), name=\"foo\")\n",
)
text = (
"tf.nn.dropout(x, # Stuff before\n"
" keep_prob=.4, # Stuff after\n"
" name=\"foo\")\n"
)
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(
new_text,
"tf.nn.dropout(x, # Stuff before\n"
" rate=1 - (.4), # Stuff after\n"
" name=\"foo\")\n",
) )
text = "tf.nn.dropout(x)\n" text = "tf.nn.dropout(x)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, text) self.assertEqual(new_text, text)
self.assertIn("tf.nn.dropout called without arguments", errors[0])
def testDropoutExpr(self):
text = "tf.nn.dropout(x, 1 - func(3 + 4.), name=\"foo\")\n"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual( self.assertEqual(
errors, new_text,
["test.py:1: tf.nn.dropout requires manual check."] "tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n",
) )
def testCountNonZeroChanges(self): def testCountNonZeroChanges(self):
@ -543,9 +567,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = "tf.gradients(a, colocate_gradients_with_ops=False)\n" text = "tf.gradients(a, colocate_gradients_with_ops=False)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text) self.assertEqual("tf.gradients(a)\n", new_text)
self.assertEqual(errors, ["test.py:1: tf.gradients requires manual check."]) self.assertIn("tf.gradients", errors[0])
self.assertIn("requires manual check", errors[0])
def testColocateGradientsWithOpsMinimize(self):
text = "optimizer.minimize(a, foo=False)\n" text = "optimizer.minimize(a, foo=False)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text) self.assertEqual(text, new_text)
@ -553,10 +579,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n" text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text) self.assertEqual("optimizer.minimize(a)\n", new_text)
self.assertEqual(errors, self.assertIn("requires manual check", errors[0])
["test.py:1: Optimizer.minimize requires manual check."]) self.assertIn("minimize", errors[0])
def testColocateGradientsWithOpsComputeGradients(self):
text = "optimizer.compute_gradients(a, foo=False)\n" text = "optimizer.compute_gradients(a, foo=False)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text) self.assertEqual(text, new_text)
@ -564,10 +591,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n" text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
_, unused_report, errors, new_text = self._upgrade(text) _, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text) self.assertEqual("optimizer.compute_gradients(a)\n", new_text)
self.assertEqual(errors, self.assertIn("requires manual check", errors[0])
["test.py:1: Optimizer.compute_gradients " self.assertIn("compute_gradients", errors[0])
"requires manual check."])
def testExportSavedModelRename(self): def testExportSavedModelRename(self):
text = "self.est.export_savedmodel(path)" text = "self.est.export_savedmodel(path)"
@ -673,7 +699,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
_, report, errors, new_text = self._upgrade(text) _, report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text) self.assertEqual(new_text, expected_text)
self.assertIn( self.assertIn(
"tf.nn.softmax_cross_entropy_with_logits requires manual check.", "tf.nn.softmax_cross_entropy_with_logits requires manual check",
errors[0]) errors[0])
self.assertIn( self.assertIn(
"tf.nn.softmax_cross_entropy_with_logits behavior has changed. ", "tf.nn.softmax_cross_entropy_with_logits behavior has changed. ",
@ -817,27 +843,24 @@ tf.print('abc')
def testBatchGather(self): def testBatchGather(self):
text = "tf.batch_gather(foo, bar)" text = "tf.batch_gather(foo, bar)"
expected_text = "tf.gather(batch_dims=-1, params=foo, indices=bar)" expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
_, unused_report, unused_errors, new_text = self._upgrade(text) _, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text) self.assertIn(new_text, [expected_text1, expected_text2])
text = "tf.batch_gather(params=foo, indices=bar)" text = "tf.batch_gather(params=foo, indices=bar)"
expected_text = "tf.gather(batch_dims=-1, params=foo, indices=bar)" expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
_, unused_report, unused_errors, new_text = self._upgrade(text) _, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text) self.assertIn(new_text, [expected_text1, expected_text2])
text = "tf.batch_gather ( foo, bar)" def testCast(self):
expected_text = "tf.gather (batch_dims=-1, params=foo, indices=bar)" for dtype in ["int32", "int64", "float", "double",
_, unused_report, unused_errors, new_text = self._upgrade(text) "complex64", "complex128", "bfloat16"]:
self.assertEqual(new_text, expected_text) text = "tf.to_%s(x, name='test')" % dtype
expected_text = "tf.cast(x, name='test', dtype=tf.%s)" % dtype
text = "(tf.batch_gather\n(foo, bar))" _, unused_report, unused_errors, new_text = self._upgrade(text)
expected_text = "(tf.gather\n(params=foo, indices=bar))" self.assertEqual(expected_text, new_text)
expected_errors = ["test.py:1: Unable to add keyword argument batch_dims=-1"
" to tf.batch_gather; please add it manually."]
_, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(errors, expected_errors)
self.assertEqual(new_text, expected_text)
class TestUpgradeFiles(test_util.TensorFlowTestCase): class TestUpgradeFiles(test_util.TensorFlowTestCase):

View File

@ -169,6 +169,7 @@ filegroup(
"@local_config_sycl//sycl:LICENSE.text", "@local_config_sycl//sycl:LICENSE.text",
"@nasm//:LICENSE", "@nasm//:LICENSE",
"@nsync//:LICENSE", "@nsync//:LICENSE",
"@pasta//:LICENSE",
"@pcre//:LICENCE", "@pcre//:LICENCE",
"@png_archive//:LICENSE", "@png_archive//:LICENSE",
"@protobuf_archive//:LICENSE", "@protobuf_archive//:LICENSE",

View File

@ -29,6 +29,7 @@ load("//third_party/jpeg:workspace.bzl", jpeg = "repo")
load("//third_party/nasm:workspace.bzl", nasm = "repo") load("//third_party/nasm:workspace.bzl", nasm = "repo")
load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/kissfft:workspace.bzl", kissfft = "repo")
load("//third_party/keras_applications_archive:workspace.bzl", keras_applications = "repo") load("//third_party/keras_applications_archive:workspace.bzl", keras_applications = "repo")
load("//third_party/pasta:workspace.bzl", pasta = "repo")
def initialize_third_party(): def initialize_third_party():
""" Load third party repositories. See above load() statements. """ """ Load third party repositories. See above load() statements. """
@ -41,6 +42,7 @@ def initialize_third_party():
kissfft() kissfft()
jpeg() jpeg()
nasm() nasm()
pasta()
# Sanitize a dependency so that it works correctly from code that includes # Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule. # TensorFlow as a submodule.

1
third_party/pasta/BUILD vendored Normal file
View File

@ -0,0 +1 @@
# Empty BUILD file to force build system to see this directory at all.

29
third_party/pasta/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,29 @@
# Description:
# AST-based python refactoring.
licenses(["notice"]) # Apache2
exports_files(["LICENSE"])
py_library(
name = "pasta",
srcs = [
"__init__.py",
"augment/__init__.py",
"augment/errors.py",
"augment/import_utils.py",
"augment/inline.py",
"augment/rename.py",
"base/__init__.py",
"base/annotate.py",
"base/ast_constants.py",
"base/ast_utils.py",
"base/codegen.py",
"base/formatting.py",
"base/scope.py",
"base/test_utils.py",
"base/token_generator.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)

13
third_party/pasta/BUILD.system vendored Normal file
View File

@ -0,0 +1,13 @@
# Description: Pasta, AST based python refactoring.
licenses(["notice"]) # Apache2
filegroup(
name = "LICENSE",
visibility = ["//visibility:public"],
)
py_library(
name = "pasta",
visibility = ["//visibility:public"],
)

16
third_party/pasta/workspace.bzl vendored Normal file
View File

@ -0,0 +1,16 @@
"""Loads pasta python package."""
load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "pasta",
urls = [
"https://mirror.bazel.build/github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
"https://github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
],
strip_prefix = "pasta-c3d72cdee6fc806251949e912510444d58d7413c/pasta",
sha256 = "b5905f9cecc4b28363c563f3c4cb0545288bd35f7cc72c55066e97e53befc084",
build_file = "//third_party/pasta:BUILD.bazel",
system_build_file = "//third_party/pasta:BUILD.system",
)