Improve upgrade script to handle list comprehensions as arguments. (#7229)
Python's ast module does not report the correct location for a list comprehension, so we do our best to scan backwards and find the position of the `[` token that truly starts it.
This commit is contained in:
parent
ee770d990f
commit
114a4627cb
@ -347,6 +347,62 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
|
||||
items.append(curr.id)
|
||||
return ".".join(reversed(items))
|
||||
|
||||
def _find_true_position(self, node):
    """Return the (lineno, col_offset) at which `node` really starts.

    Interpreters older than Python 3.8 report a list comprehension's position
    as that of the first token *after* its opening '[', so for those the
    source text in `self._lines` is scanned backwards to locate the bracket.
    From Python 3.8 on, `ast` already points at the '[' and the reported
    position is returned unchanged, as it is for every other node type.

    Args:
      node: Node for which we wish to know the lineno and col_offset.

    Returns:
      A `(lineno, col_offset)` tuple (1-based line, 0-based column), or
      `(None, None)` when the opening '[' could not be located.
    """
    import sys

    # Most nodes report a usable location (`with` notably does not, but it
    # cannot appear as an argument). Since Python 3.8 that includes ListComp;
    # scanning backwards from a correct offset would hit the text preceding
    # the '[' and wrongly give up.
    if not isinstance(node, ast.ListComp) or sys.version_info >= (3, 8):
        return node.lineno, node.col_offset

    line = node.lineno
    col = node.col_offset
    # Walk backwards, one source line per iteration, until the '[' is found.
    while True:
        preceding = self._lines[line - 1][:col].rstrip()
        if preceding.endswith("["):
            # Only whitespace separates the '[' from the reported column.
            return line, len(preceding) - 1
        if preceding:
            # Something other than '[' precedes the reported column.
            return None, None

        # Nothing but whitespace on this line: continue on the previous one.
        line -= 1
        if line < 1:
            # Ran off the top of the file; never wrap to the last line via a
            # negative index.
            return None, None
        prev_line = self._lines[line - 1]

        # TODO(aselle): this is poor comment detection, but it is good enough
        # for cases where the comment does not contain string literal
        # starting/ending characters. If ast gave us start and end locations
        # of nodes rather than just the start, string literal ranges could be
        # used to filter out spurious #'s that appear inside string literals.
        comment_start = prev_line.find("#")
        if comment_start == -1:
            # Resume from the end of the line, excluding the line terminator
            # if one is stored with the line.
            col = len(prev_line.rstrip("\r\n"))
        else:
            tail = prev_line[comment_start:]
            if "'" in tail or '"' in tail:
                # The '#' may sit inside a string literal; do not guess.
                return None, None
            col = comment_start
|
||||
|
||||
|
||||
def visit_Call(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting a call node in the AST.
|
||||
|
||||
@ -376,13 +432,21 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
|
||||
if full_name in function_reorders:
|
||||
reordered = function_reorders[full_name]
|
||||
for idx, arg in enumerate(node.args):
|
||||
keyword_arg = reordered[idx]
|
||||
if (full_name in function_keyword_renames and
|
||||
keyword_arg in function_keyword_renames[full_name]):
|
||||
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||
self._file_edit.add("Added keyword %r to reordered function %r"
|
||||
% (reordered[idx], full_name), arg.lineno,
|
||||
arg.col_offset, "", keyword_arg + "=")
|
||||
lineno, col_offset = self._find_true_position(arg)
|
||||
if lineno is None or col_offset is None:
|
||||
self._file_edit.add(
|
||||
"Failed to add keyword %r to reordered function %r"
|
||||
% (reordered[idx], full_name), arg.lineno, arg.col_offset,
|
||||
"", "",
|
||||
error="A necessary keyword argument failed to be inserted.")
|
||||
else:
|
||||
keyword_arg = reordered[idx]
|
||||
if (full_name in function_keyword_renames and
|
||||
keyword_arg in function_keyword_renames[full_name]):
|
||||
keyword_arg = function_keyword_renames[full_name][keyword_arg]
|
||||
self._file_edit.add("Added keyword %r to reordered function %r"
|
||||
% (reordered[idx], full_name), lineno,
|
||||
col_offset, "", keyword_arg + "=")
|
||||
|
||||
# Examine each keyword argument and convert it to the final renamed form
|
||||
renamed_keywords = ({} if full_name not in function_keyword_renames else
|
||||
@ -390,12 +454,31 @@ class TensorFlowCallVisitor(ast.NodeVisitor):
|
||||
for keyword in node.keywords:
|
||||
argkey = keyword.arg
|
||||
argval = keyword.value
|
||||
|
||||
if argkey in renamed_keywords:
|
||||
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||
argval_lineno, argval_col_offset = self._find_true_position(argval)
|
||||
if (argval_lineno is not None and argval_col_offset is not None):
|
||||
# TODO(aselle): We should scan backward to find the start of the
|
||||
# keyword key. Unfortunately ast does not give you the location of
|
||||
# keyword keys, so we are forced to infer it from the keyword arg
|
||||
# value.
|
||||
key_start = argval_col_offset - len(argkey) - 1
|
||||
key_end = key_start + len(argkey) + 1
|
||||
if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
|
||||
self._file_edit.add("Renamed keyword argument from %r to %r" %
|
||||
(argkey, renamed_keywords[argkey]),
|
||||
argval.lineno,
|
||||
argval.col_offset - len(argkey) - 1,
|
||||
argval_lineno,
|
||||
argval_col_offset - len(argkey) - 1,
|
||||
argkey + "=", renamed_keywords[argkey] + "=")
|
||||
continue
|
||||
self._file_edit.add(
|
||||
"Failed to rename keyword argument from %r to %r" %
|
||||
(argkey, renamed_keywords[argkey]),
|
||||
argval.lineno,
|
||||
argval.col_offset - len(argkey) - 1,
|
||||
"", "",
|
||||
error="Failed to find keyword lexographically. Fix manually.")
|
||||
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
||||
|
@ -113,6 +113,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(new_text, new_text)
|
||||
self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
|
||||
|
||||
def testListComprehension(self):
    """Check keyword insertion in front of list-comprehension arguments.

    Each case feeds a `tf.concat` call whose second positional argument is a
    list comprehension and expects `values=` to be inserted directly before
    the opening '[', whatever whitespace or newline follows it.
    """
    def check_upgrade(source, expected):
        # Only the rewritten text (fourth result of `_upgrade`) is compared.
        _, _, _, upgraded = self._upgrade(source)
        self.assertEqual(upgraded, expected)

    cases = (
        ("tf.concat(0, \t[x for x in y])\n",
         "tf.concat(axis=0, \tvalues=[x for x in y])\n"),
        ("tf.concat(0,[x for x in y])\n",
         "tf.concat(axis=0,values=[x for x in y])\n"),
        ("tf.concat(0,[\nx for x in y])\n",
         "tf.concat(axis=0,values=[\nx for x in y])\n"),
        ("tf.concat(0,[\n \tx for x in y])\n",
         "tf.concat(axis=0,values=[\n \tx for x in y])\n"),
    )
    for source, expected in cases:
        check_upgrade(source, expected)
|
||||
|
||||
# TODO(aselle): Explicitly not testing command line interface and process_tree
|
||||
# for now, since this is a one off utility.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user