Change "tensorflow" imports to "tensorflow.compat.v1" when running
tf_upgrade_v2 script in SAFETY mode. PiperOrigin-RevId: 248596358
This commit is contained in:
parent
dc563ad0fc
commit
13a5d637bd
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@ -39,6 +40,10 @@ WARNING = "WARNING"
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
ImportRename = collections.namedtuple(
|
||||
"ImportRename", ["new_name", "excluded_prefixes"])
|
||||
|
||||
|
||||
def full_name_node(name, ctx=ast.Load()):
|
||||
"""Make an Attribute or Name node for name.
|
||||
|
||||
@ -101,6 +106,23 @@ def get_arg_value(node, arg_name, arg_pos=None):
|
||||
return (False, None)
|
||||
|
||||
|
||||
def excluded_from_module_rename(module, import_rename_spec):
|
||||
"""Check if this module import should not be renamed.
|
||||
|
||||
Args:
|
||||
module: (string) module name.
|
||||
import_rename_spec: ImportRename instance.
|
||||
|
||||
Returns:
|
||||
True if this import should not be renamed according to the
|
||||
import_rename_spec.
|
||||
"""
|
||||
for excluded_prefix in import_rename_spec.excluded_prefixes:
|
||||
if module.startswith(excluded_prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class APIChangeSpec(object):
|
||||
"""This class defines the transformations that need to happen.
|
||||
|
||||
@ -118,6 +140,8 @@ class APIChangeSpec(object):
|
||||
* `function_transformers`: maps function names to custom handlers
|
||||
* `module_deprecations`: maps module names to warnings that will be printed
|
||||
if the module is still used after all other transformations have run
|
||||
* `import_renames`: maps import name (must be a short name without '.')
|
||||
to ImportRename instance.
|
||||
|
||||
For an example, see `TFAPIChangeSpec`.
|
||||
"""
|
||||
@ -466,6 +490,133 @@ class _PastaEditVisitor(ast.NodeVisitor):
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Import(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting an import node in the AST.
|
||||
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
new_aliases = []
|
||||
import_updated = False
|
||||
import_renames = getattr(self._api_change_spec, "import_renames", {})
|
||||
|
||||
# This loop processes imports in the format
|
||||
# import foo as f, bar as b
|
||||
for import_alias in node.names:
|
||||
# Look for rename based on first component of from-import.
|
||||
# i.e. based on foo in foo.bar.
|
||||
import_first_component = import_alias.name.split(".")[0]
|
||||
import_rename_spec = import_renames.get(import_first_component, None)
|
||||
|
||||
if not import_rename_spec or excluded_from_module_rename(
|
||||
import_alias.name, import_rename_spec):
|
||||
new_aliases.append(import_alias) # no change needed
|
||||
continue
|
||||
|
||||
new_name = (
|
||||
import_rename_spec.new_name +
|
||||
import_alias.name[len(import_first_component):])
|
||||
|
||||
# If current import is
|
||||
# import foo
|
||||
# then new import should preserve imported name:
|
||||
# import new_foo as foo
|
||||
# This happens when module has just one component.
|
||||
new_asname = import_alias.asname
|
||||
if not new_asname and "." not in import_alias.name:
|
||||
new_asname = import_alias.name
|
||||
|
||||
new_alias = ast.alias(name=new_name, asname=new_asname)
|
||||
new_aliases.append(new_alias)
|
||||
import_updated = True
|
||||
|
||||
# Replace the node if at least one import needs to be updated.
|
||||
if import_updated:
|
||||
assert self._stack[-1] is node
|
||||
parent = self._stack[-2]
|
||||
|
||||
new_node = ast.Import(new_aliases)
|
||||
ast.copy_location(new_node, node)
|
||||
pasta.ast_utils.replace_child(parent, node, new_node)
|
||||
self.add_log(
|
||||
INFO, node.lineno, node.col_offset,
|
||||
"Changed import from %r to %r." %
|
||||
(pasta.dump(node), pasta.dump(new_node)))
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting an import-from node in the AST.
|
||||
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
if not node.module:
|
||||
self.generic_visit(node)
|
||||
return
|
||||
|
||||
from_import = node.module
|
||||
|
||||
# Look for rename based on first component of from-import.
|
||||
# i.e. based on foo in foo.bar.
|
||||
from_import_first_component = from_import.split(".")[0]
|
||||
import_renames = getattr(self._api_change_spec, "import_renames", {})
|
||||
import_rename_spec = import_renames.get(from_import_first_component, None)
|
||||
if not import_rename_spec:
|
||||
self.generic_visit(node)
|
||||
return
|
||||
|
||||
# Split module aliases into the ones that require import update
|
||||
# and those that don't. For e.g. if we want to rename "a" to "b"
|
||||
# unless we import "a.c" in the following:
|
||||
# from a import c, d
|
||||
# we want to update import for "d" but not for "c".
|
||||
updated_aliases = []
|
||||
same_aliases = []
|
||||
for import_alias in node.names:
|
||||
full_module_name = "%s.%s" % (from_import, import_alias.name)
|
||||
if excluded_from_module_rename(full_module_name, import_rename_spec):
|
||||
same_aliases.append(import_alias)
|
||||
else:
|
||||
updated_aliases.append(import_alias)
|
||||
|
||||
if not updated_aliases:
|
||||
self.generic_visit(node)
|
||||
return
|
||||
|
||||
assert self._stack[-1] is node
|
||||
parent = self._stack[-2]
|
||||
|
||||
# Replace first component of from-import with new name.
|
||||
new_from_import = (
|
||||
import_rename_spec.new_name +
|
||||
from_import[len(from_import_first_component):])
|
||||
updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
|
||||
ast.copy_location(updated_node, node)
|
||||
pasta.ast_utils.replace_child(parent, node, updated_node)
|
||||
|
||||
# If some imports had to stay the same, add another import for them.
|
||||
additional_import_log = ""
|
||||
if same_aliases:
|
||||
same_node = ast.ImportFrom(from_import, same_aliases, node.level,
|
||||
col_offset=node.col_offset, lineno=node.lineno)
|
||||
ast.copy_location(same_node, node)
|
||||
parent.body.insert(parent.body.index(updated_node), same_node)
|
||||
# Apply indentation to new node.
|
||||
pasta.base.formatting.set(
|
||||
same_node, "prefix",
|
||||
pasta.base.formatting.get(updated_node, "prefix"))
|
||||
additional_import_log = " and %r" % pasta.dump(same_node)
|
||||
|
||||
self.add_log(
|
||||
INFO, node.lineno, node.col_offset,
|
||||
"Changed import from %r to %r%s." %
|
||||
(pasta.dump(node),
|
||||
pasta.dump(updated_node),
|
||||
additional_import_log))
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
class ASTCodeUpgrader(object):
|
||||
"""Handles upgrading a set of Python files using a given API change spec."""
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
All of the tests assume that we want to change from an API containing
|
||||
|
||||
import foo as f
|
||||
|
||||
def f(a, b, kw1, kw2): ...
|
||||
def g(a, b, kw1, c, kw1_alias): ...
|
||||
def g2(a, b, kw1, c, d, kw1_alias): ...
|
||||
@ -25,6 +27,8 @@ and the changes to the API consist of renaming, reordering, and/or removing
|
||||
arguments. Thus, we want to be able to generate changes to produce each of the
|
||||
following new APIs:
|
||||
|
||||
import bar as f
|
||||
|
||||
def f(a, b, kw1, kw3): ...
|
||||
def f(a, b, kw2, kw1): ...
|
||||
def f(a, b, kw3, kw1): ...
|
||||
@ -59,6 +63,7 @@ class NoUpdateSpec(ast_edits.APIChangeSpec):
|
||||
self.function_warnings = {}
|
||||
self.change_to_function = {}
|
||||
self.module_deprecations = {}
|
||||
self.import_renames = {}
|
||||
|
||||
|
||||
class ModuleDeprecationSpec(NoUpdateSpec):
|
||||
@ -170,6 +175,18 @@ class RemoveMultipleKeywordArguments(NoUpdateSpec):
|
||||
}
|
||||
|
||||
|
||||
class RenameImports(NoUpdateSpec):
|
||||
"""Specification for renaming imports."""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
self.import_renames = {
|
||||
"foo": ast_edits.ImportRename(
|
||||
"bar",
|
||||
excluded_prefixes=["foo.baz"])
|
||||
}
|
||||
|
||||
|
||||
class TestAstEdits(test_util.TensorFlowTestCase):
|
||||
|
||||
def _upgrade(self, spec, old_file_text):
|
||||
@ -458,5 +475,112 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
||||
"ctx=Load()), attr='c', ctx=Load())"
|
||||
)
|
||||
|
||||
def testImport(self):
|
||||
# foo should be renamed to bar.
|
||||
text = "import foo as f"
|
||||
expected_text = "import bar as f"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import foo"
|
||||
expected_text = "import bar as foo"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import foo.test"
|
||||
expected_text = "import bar.test"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import foo.test as t"
|
||||
expected_text = "import bar.test as t"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import foo as f, a as b"
|
||||
expected_text = "import bar as f, a as b"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testFromImport(self):
|
||||
# foo should be renamed to bar.
|
||||
text = "from foo import a"
|
||||
expected_text = "from bar import a"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "from foo.a import b"
|
||||
expected_text = "from bar.a import b"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "from foo import *"
|
||||
expected_text = "from bar import *"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "from foo import a, b"
|
||||
expected_text = "from bar import a, b"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testImport_NoChangeNeeded(self):
|
||||
text = "import bar as b"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
def testFromImport_NoChangeNeeded(self):
|
||||
text = "from bar import a as b"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
def testExcludedImport(self):
|
||||
# foo.baz module is excluded from changes.
|
||||
text = "import foo.baz"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
text = "import foo.baz as a"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
text = "from foo import baz as a"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
text = "from foo.baz import a"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
def testMultipleImports(self):
|
||||
text = "import foo.bar as a, foo.baz as b, foo.baz.c, foo.d"
|
||||
expected_text = "import bar.bar as a, foo.baz as b, foo.baz.c, bar.d"
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "from foo import baz, a, c"
|
||||
expected_text = """from foo import baz
|
||||
from bar import a, c"""
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testImportInsideFunction(self):
|
||||
text = """
|
||||
def t():
|
||||
from c import d
|
||||
from foo import baz, a
|
||||
from e import y
|
||||
"""
|
||||
expected_text = """
|
||||
def t():
|
||||
from c import d
|
||||
from foo import baz
|
||||
from bar import a
|
||||
from e import y
|
||||
"""
|
||||
_, new_text = self._upgrade(RenameImports(), text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
@ -34,5 +34,16 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
self.function_transformers = {}
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
|
||||
# List module renames. Right now, we just support renames from a module
|
||||
# names that don't contain '.'.
|
||||
self.import_renames = {
|
||||
"tensorflow": ast_edits.ImportRename(
|
||||
"tensorflow.compat.v1",
|
||||
excluded_prefixes=["tensorflow.contrib",
|
||||
"tensorflow.flags",
|
||||
"tensorflow.compat.v1",
|
||||
"tensorflow.compat.v2"])
|
||||
}
|
||||
|
||||
# TODO(kaftan,annarev): specify replacement from TensorFlow import to
|
||||
# compat.v1 import.
|
||||
|
@ -42,6 +42,65 @@ class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
|
||||
expected_info = "tf.contrib will not be distributed"
|
||||
self.assertIn(expected_info, report)
|
||||
|
||||
def testTensorFlowImport(self):
|
||||
text = "import tensorflow as tf"
|
||||
expected_text = "import tensorflow.compat.v1 as tf"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import tensorflow"
|
||||
expected_text = "import tensorflow.compat.v1 as tensorflow"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import tensorflow.foo"
|
||||
expected_text = "import tensorflow.compat.v1.foo"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "import tensorflow.foo as bar"
|
||||
expected_text = "import tensorflow.compat.v1.foo as bar"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testTensorFlowFromImport(self):
|
||||
text = "from tensorflow import foo"
|
||||
expected_text = "from tensorflow.compat.v1 import foo"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "from tensorflow.foo import bar"
|
||||
expected_text = "from tensorflow.compat.v1.foo import bar"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "from tensorflow import *"
|
||||
expected_text = "from tensorflow.compat.v1 import *"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testTensorFlowImportAlreadyHasCompat(self):
|
||||
text = "import tensorflow.compat.v1 as tf"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
text = "import tensorflow.compat.v2 as tf"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
text = "from tensorflow.compat import v2 as tf"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
def testTensorFlowDontChangeContrib(self):
|
||||
text = "import tensorflow.contrib as foo"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
text = "from tensorflow import contrib"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(text, new_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
Loading…
Reference in New Issue
Block a user