Change "tensorflow" imports to "tensorflow.compat.v1" when running

tf_upgrade_v2 script in SAFETY mode.

PiperOrigin-RevId: 248596358
This commit is contained in:
Anna R 2019-05-16 14:07:43 -07:00 committed by TensorFlower Gardener
parent dc563ad0fc
commit 13a5d637bd
4 changed files with 345 additions and 0 deletions

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import ast
import collections
import os
import re
import shutil
@ -39,6 +40,10 @@ WARNING = "WARNING"
ERROR = "ERROR"
ImportRename = collections.namedtuple(
"ImportRename", ["new_name", "excluded_prefixes"])
def full_name_node(name, ctx=ast.Load()):
"""Make an Attribute or Name node for name.
@ -101,6 +106,23 @@ def get_arg_value(node, arg_name, arg_pos=None):
return (False, None)
def excluded_from_module_rename(module, import_rename_spec):
"""Check if this module import should not be renamed.
Args:
module: (string) module name.
import_rename_spec: ImportRename instance.
Returns:
True if this import should not be renamed according to the
import_rename_spec.
"""
for excluded_prefix in import_rename_spec.excluded_prefixes:
if module.startswith(excluded_prefix):
return True
return False
class APIChangeSpec(object):
"""This class defines the transformations that need to happen.
@ -118,6 +140,8 @@ class APIChangeSpec(object):
* `function_transformers`: maps function names to custom handlers
* `module_deprecations`: maps module names to warnings that will be printed
if the module is still used after all other transformations have run
* `import_renames`: maps import name (must be a short name without '.')
to ImportRename instance.
For an example, see `TFAPIChangeSpec`.
"""
@ -466,6 +490,133 @@ class _PastaEditVisitor(ast.NodeVisitor):
self.generic_visit(node)
def visit_Import(self, node): # pylint: disable=invalid-name
"""Handle visiting an import node in the AST.
Args:
node: Current Node
"""
new_aliases = []
import_updated = False
import_renames = getattr(self._api_change_spec, "import_renames", {})
# This loop processes imports in the format
# import foo as f, bar as b
for import_alias in node.names:
# Look for rename based on first component of from-import.
# i.e. based on foo in foo.bar.
import_first_component = import_alias.name.split(".")[0]
import_rename_spec = import_renames.get(import_first_component, None)
if not import_rename_spec or excluded_from_module_rename(
import_alias.name, import_rename_spec):
new_aliases.append(import_alias) # no change needed
continue
new_name = (
import_rename_spec.new_name +
import_alias.name[len(import_first_component):])
# If current import is
# import foo
# then new import should preserve imported name:
# import new_foo as foo
# This happens when module has just one component.
new_asname = import_alias.asname
if not new_asname and "." not in import_alias.name:
new_asname = import_alias.name
new_alias = ast.alias(name=new_name, asname=new_asname)
new_aliases.append(new_alias)
import_updated = True
# Replace the node if at least one import needs to be updated.
if import_updated:
assert self._stack[-1] is node
parent = self._stack[-2]
new_node = ast.Import(new_aliases)
ast.copy_location(new_node, node)
pasta.ast_utils.replace_child(parent, node, new_node)
self.add_log(
INFO, node.lineno, node.col_offset,
"Changed import from %r to %r." %
(pasta.dump(node), pasta.dump(new_node)))
self.generic_visit(node)
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
"""Handle visiting an import-from node in the AST.
Args:
node: Current Node
"""
if not node.module:
self.generic_visit(node)
return
from_import = node.module
# Look for rename based on first component of from-import.
# i.e. based on foo in foo.bar.
from_import_first_component = from_import.split(".")[0]
import_renames = getattr(self._api_change_spec, "import_renames", {})
import_rename_spec = import_renames.get(from_import_first_component, None)
if not import_rename_spec:
self.generic_visit(node)
return
# Split module aliases into the ones that require import update
# and those that don't. For e.g. if we want to rename "a" to "b"
# unless we import "a.c" in the following:
# from a import c, d
# we want to update import for "d" but not for "c".
updated_aliases = []
same_aliases = []
for import_alias in node.names:
full_module_name = "%s.%s" % (from_import, import_alias.name)
if excluded_from_module_rename(full_module_name, import_rename_spec):
same_aliases.append(import_alias)
else:
updated_aliases.append(import_alias)
if not updated_aliases:
self.generic_visit(node)
return
assert self._stack[-1] is node
parent = self._stack[-2]
# Replace first component of from-import with new name.
new_from_import = (
import_rename_spec.new_name +
from_import[len(from_import_first_component):])
updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
ast.copy_location(updated_node, node)
pasta.ast_utils.replace_child(parent, node, updated_node)
# If some imports had to stay the same, add another import for them.
additional_import_log = ""
if same_aliases:
same_node = ast.ImportFrom(from_import, same_aliases, node.level,
col_offset=node.col_offset, lineno=node.lineno)
ast.copy_location(same_node, node)
parent.body.insert(parent.body.index(updated_node), same_node)
# Apply indentation to new node.
pasta.base.formatting.set(
same_node, "prefix",
pasta.base.formatting.get(updated_node, "prefix"))
additional_import_log = " and %r" % pasta.dump(same_node)
self.add_log(
INFO, node.lineno, node.col_offset,
"Changed import from %r to %r%s." %
(pasta.dump(node),
pasta.dump(updated_node),
additional_import_log))
self.generic_visit(node)
class ASTCodeUpgrader(object):
"""Handles upgrading a set of Python files using a given API change spec."""

View File

@ -16,6 +16,8 @@
All of the tests assume that we want to change from an API containing
import foo as f
def f(a, b, kw1, kw2): ...
def g(a, b, kw1, c, kw1_alias): ...
def g2(a, b, kw1, c, d, kw1_alias): ...
@ -25,6 +27,8 @@ and the changes to the API consist of renaming, reordering, and/or removing
arguments. Thus, we want to be able to generate changes to produce each of the
following new APIs:
import bar as f
def f(a, b, kw1, kw3): ...
def f(a, b, kw2, kw1): ...
def f(a, b, kw3, kw1): ...
@ -59,6 +63,7 @@ class NoUpdateSpec(ast_edits.APIChangeSpec):
self.function_warnings = {}
self.change_to_function = {}
self.module_deprecations = {}
self.import_renames = {}
class ModuleDeprecationSpec(NoUpdateSpec):
@ -170,6 +175,18 @@ class RemoveMultipleKeywordArguments(NoUpdateSpec):
}
class RenameImports(NoUpdateSpec):
"""Specification for renaming imports."""
def __init__(self):
NoUpdateSpec.__init__(self)
self.import_renames = {
"foo": ast_edits.ImportRename(
"bar",
excluded_prefixes=["foo.baz"])
}
class TestAstEdits(test_util.TensorFlowTestCase):
def _upgrade(self, spec, old_file_text):
@ -458,5 +475,112 @@ class TestAstEdits(test_util.TensorFlowTestCase):
"ctx=Load()), attr='c', ctx=Load())"
)
def testImport(self):
# foo should be renamed to bar.
text = "import foo as f"
expected_text = "import bar as f"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "import foo"
expected_text = "import bar as foo"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "import foo.test"
expected_text = "import bar.test"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "import foo.test as t"
expected_text = "import bar.test as t"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "import foo as f, a as b"
expected_text = "import bar as f, a as b"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
def testFromImport(self):
# foo should be renamed to bar.
text = "from foo import a"
expected_text = "from bar import a"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "from foo.a import b"
expected_text = "from bar.a import b"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "from foo import *"
expected_text = "from bar import *"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "from foo import a, b"
expected_text = "from bar import a, b"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
def testImport_NoChangeNeeded(self):
text = "import bar as b"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(text, new_text)
def testFromImport_NoChangeNeeded(self):
text = "from bar import a as b"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(text, new_text)
def testExcludedImport(self):
# foo.baz module is excluded from changes.
text = "import foo.baz"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(text, new_text)
text = "import foo.baz as a"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(text, new_text)
text = "from foo import baz as a"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(text, new_text)
text = "from foo.baz import a"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(text, new_text)
def testMultipleImports(self):
text = "import foo.bar as a, foo.baz as b, foo.baz.c, foo.d"
expected_text = "import bar.bar as a, foo.baz as b, foo.baz.c, bar.d"
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
text = "from foo import baz, a, c"
expected_text = """from foo import baz
from bar import a, c"""
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
def testImportInsideFunction(self):
text = """
def t():
from c import d
from foo import baz, a
from e import y
"""
expected_text = """
def t():
from c import d
from foo import baz
from bar import a
from e import y
"""
_, new_text = self._upgrade(RenameImports(), text)
self.assertEqual(expected_text, new_text)
if __name__ == "__main__":
test_lib.main()

View File

@ -34,5 +34,16 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
self.function_transformers = {}
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
# List module renames. Right now, we just support renames from a module
# names that don't contain '.'.
self.import_renames = {
"tensorflow": ast_edits.ImportRename(
"tensorflow.compat.v1",
excluded_prefixes=["tensorflow.contrib",
"tensorflow.flags",
"tensorflow.compat.v1",
"tensorflow.compat.v2"])
}
# TODO(kaftan,annarev): specify replacement from TensorFlow import to
# compat.v1 import.

View File

@ -42,6 +42,65 @@ class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
expected_info = "tf.contrib will not be distributed"
self.assertIn(expected_info, report)
def testTensorFlowImport(self):
text = "import tensorflow as tf"
expected_text = "import tensorflow.compat.v1 as tf"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "import tensorflow"
expected_text = "import tensorflow.compat.v1 as tensorflow"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "import tensorflow.foo"
expected_text = "import tensorflow.compat.v1.foo"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "import tensorflow.foo as bar"
expected_text = "import tensorflow.compat.v1.foo as bar"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
def testTensorFlowFromImport(self):
text = "from tensorflow import foo"
expected_text = "from tensorflow.compat.v1 import foo"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "from tensorflow.foo import bar"
expected_text = "from tensorflow.compat.v1.foo import bar"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "from tensorflow import *"
expected_text = "from tensorflow.compat.v1 import *"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
def testTensorFlowImportAlreadyHasCompat(self):
text = "import tensorflow.compat.v1 as tf"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(text, new_text)
text = "import tensorflow.compat.v2 as tf"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(text, new_text)
text = "from tensorflow.compat import v2 as tf"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(text, new_text)
def testTensorFlowDontChangeContrib(self):
text = "import tensorflow.contrib as foo"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(text, new_text)
text = "from tensorflow import contrib"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(text, new_text)
if __name__ == "__main__":
test_lib.main()