Added an `--upgrade_compat_v1_import` flag to the upgrade script that allows it to upgrade `import tensorflow.compat.v1 as tf`imports to `import tensorflow as tf` imports. Note that this flag does not upgrade tf.compat.v1 when imported under other aliases, such as `import tensorflow.compat.v1 as tfv1`

PiperOrigin-RevId: 304323241
Change-Id: I0bbad98cb8642969a7ad9bb0bca7e1ebbe2620c2
This commit is contained in:
A. Unique TensorFlower 2020-04-01 21:28:01 -07:00 committed by TensorFlower Gardener
parent b3b6d2ab7b
commit cd84bc9820
4 changed files with 89 additions and 13 deletions

View File

@ -213,8 +213,8 @@ class APIChangeSpec(object):
"""
def preprocess(self, root_node): # pylint: disable=unused-argument
"""Preprocess a parse tree. Return any produced logs and errors."""
return [], []
"""Preprocess a parse tree. Return a preprocessed node, logs and errors."""
return root_node, [], []
def clear_preprocessing(self):
"""Restore this APIChangeSpec to before it preprocessed a file.
@ -942,7 +942,7 @@ class ASTCodeUpgrader(object):
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
return 0, "", log, []
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
visitor = _PastaEditVisitor(self._api_change_spec)
visitor.visit(t)

View File

@ -54,21 +54,47 @@ class VersionedTFImport(ast_edits.AnalysisResult):
"` was directly imported as `tf`.")
compat_v1_import = VersionedTFImport("compat.v1")
compat_v2_import = VersionedTFImport("compat.v2")
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
def __init__(self):
self.symbols_to_detect = {}
self.imports_to_detect = {
("tensorflow", None): UnaliasedTFImport(),
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
("tensorflow.compat.v1", "tf"): compat_v1_import,
("tensorflow.compat.v2", "tf"): compat_v2_import,
}
class CompatV1ImportReplacer(ast.NodeVisitor):
"""AST Visitor that replaces `import tensorflow.compat.v1 as tf`.
Converts `import tensorflow.compat.v1 as tf` to `import tensorflow as tf`
"""
def visit_Import(self, node): # pylint: disable=invalid-name
"""Handle visiting an import node in the AST.
Args:
node: Current Node
"""
for import_alias in node.names:
# Detect based on full import name and alias
if (import_alias.name == "tensorflow.compat.v1" and
import_alias.asname == "tf"):
import_alias.name = "tensorflow"
self.generic_visit(node)
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
"""List of maps that describe what changed in the API."""
def __init__(self, import_rename=False):
def __init__(self, import_rename=False, upgrade_compat_v1_import=False):
self.upgrade_compat_v1_import = upgrade_compat_v1_import
# Maps from a function name to a dictionary that describes how to
# map from an old argument keyword to the new argument keyword.
# If the new argument is None, it will be removed.
@ -1612,10 +1638,21 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
def preprocess(self, root_node):
def preprocess(self, root_node, after_compat_v1_upgrade=False):
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
visitor.visit(root_node)
detections = set(visitor.results)
# Upgrade explicit compat v1 imports if `upgrade_compat_v1_import` is
# enabled. Then preprocess the updated root node.
# We only do this upgrading once, because some forms of the import may
# still cause errors but aren't trivially upgradeable, and we don't want
# to enter an infinite loop. E.g. `from tensorflow.compat import v1, v2`.
if (compat_v1_import in detections and self.upgrade_compat_v1_import and
not after_compat_v1_upgrade):
CompatV1ImportReplacer().visit(root_node)
return self.preprocess(root_node, after_compat_v1_upgrade=True)
# If we have detected the presence of imports of specific TF versions,
# We want to modify the update spec to check only module deprecations
# and skip all other conversions.
@ -1629,7 +1666,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
self.function_transformers = {}
self.import_renames = {}
return visitor.log, visitor.warnings_and_errors
return root_node, visitor.log, visitor.warnings_and_errors
def clear_preprocessing(self):
self.__init__()

View File

@ -101,7 +101,15 @@ Simple usage:
parser.add_argument(
"--no_import_rename",
dest="no_import_rename",
help=("Not to rename import to compact.v2 explicitly."),
help=("Not to rename import to compat.v2 explicitly."),
action="store_true")
parser.add_argument(
"--no_upgrade_compat_v1_import",
dest="no_upgrade_compat_v1_import",
help=("If specified, don't upgrade explicit imports of "
"`tensorflow.compat.v1 as tf` to the v2 apis. Otherwise, "
"explicit imports of the form `tensorflow.compat.v1 as tf` will "
"be upgraded."),
action="store_true")
parser.add_argument(
"--reportfile",
@ -132,10 +140,13 @@ Simple usage:
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
else:
if args.no_import_rename:
change_spec = tf_upgrade_v2.TFAPIChangeSpec(import_rename=False)
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
import_rename=False,
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
else:
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
import_rename=_IMPORT_RENAME_DEFAULT)
import_rename=_IMPORT_RENAME_DEFAULT,
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
report_text = None

View File

@ -117,11 +117,15 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
visitor.private_map["tf.compat"] = ["v1", "v2"]
traverse.traverse(tf.compat.v1, visitor)
def _upgrade(self, old_file_text, import_rename=False):
def _upgrade(self,
old_file_text,
import_rename=False,
upgrade_compat_v1_import=False):
in_file = six.StringIO(old_file_text)
out_file = six.StringIO()
upgrader = ast_edits.ASTCodeUpgrader(
tf_upgrade_v2.TFAPIChangeSpec(import_rename))
tf_upgrade_v2.TFAPIChangeSpec(
import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import))
count, report, errors = (
upgrader.process_opened_file("test.py", in_file,
"test_out.py", out_file))
@ -2215,6 +2219,30 @@ def _log_prob(self, x):
_, _, _, new_text = self._upgrade(text, import_rename=True)
self.assertEqual(new_text, expected_text)
import_header = ("import tensorflow.compat.v1 as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
text = import_header + old_symbol
expected_header = ("import tensorflow.compat.v2 as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
expected_text = expected_header + new_symbol
_, _, _, new_text = self._upgrade(
text, import_rename=True, upgrade_compat_v1_import=True)
self.assertEqual(new_text, expected_text)
import_header = ("import tensorflow.compat.v1 as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
text = import_header + old_symbol
expected_header = ("import tensorflow as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
expected_text = expected_header + new_symbol
_, _, _, new_text = self._upgrade(
text, import_rename=False, upgrade_compat_v1_import=True)
self.assertEqual(new_text, expected_text)
import_header = "from tensorflow import foo\n"
text = import_header + old_symbol
expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol