Added an --upgrade_compat_v1_import
flag to the upgrade script that allows it to upgrade import tensorflow.compat.v1 as tf
imports to import tensorflow as tf
imports. Note that this flag does not upgrade tf.compat.v1 when imported under other aliases, such as import tensorflow.compat.v1 as tfv1
PiperOrigin-RevId: 304323241 Change-Id: I0bbad98cb8642969a7ad9bb0bca7e1ebbe2620c2
This commit is contained in:
parent
b3b6d2ab7b
commit
cd84bc9820
@ -213,8 +213,8 @@ class APIChangeSpec(object):
|
||||
"""
|
||||
|
||||
def preprocess(self, root_node): # pylint: disable=unused-argument
|
||||
"""Preprocess a parse tree. Return any produced logs and errors."""
|
||||
return [], []
|
||||
"""Preprocess a parse tree. Return a preprocessed node, logs and errors."""
|
||||
return root_node, [], []
|
||||
|
||||
def clear_preprocessing(self):
|
||||
"""Restore this APIChangeSpec to before it preprocessed a file.
|
||||
@ -942,7 +942,7 @@ class ASTCodeUpgrader(object):
|
||||
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
|
||||
return 0, "", log, []
|
||||
|
||||
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
|
||||
t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
|
||||
|
||||
visitor = _PastaEditVisitor(self._api_change_spec)
|
||||
visitor.visit(t)
|
||||
|
@ -54,21 +54,47 @@ class VersionedTFImport(ast_edits.AnalysisResult):
|
||||
"` was directly imported as `tf`.")
|
||||
|
||||
|
||||
compat_v1_import = VersionedTFImport("compat.v1")
|
||||
compat_v2_import = VersionedTFImport("compat.v2")
|
||||
|
||||
|
||||
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.symbols_to_detect = {}
|
||||
self.imports_to_detect = {
|
||||
("tensorflow", None): UnaliasedTFImport(),
|
||||
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
|
||||
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
|
||||
("tensorflow.compat.v1", "tf"): compat_v1_import,
|
||||
("tensorflow.compat.v2", "tf"): compat_v2_import,
|
||||
}
|
||||
|
||||
|
||||
class CompatV1ImportReplacer(ast.NodeVisitor):
|
||||
"""AST Visitor that replaces `import tensorflow.compat.v1 as tf`.
|
||||
|
||||
Converts `import tensorflow.compat.v1 as tf` to `import tensorflow as tf`
|
||||
"""
|
||||
|
||||
def visit_Import(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting an import node in the AST.
|
||||
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
for import_alias in node.names:
|
||||
# Detect based on full import name and alias
|
||||
if (import_alias.name == "tensorflow.compat.v1" and
|
||||
import_alias.asname == "tf"):
|
||||
import_alias.name = "tensorflow"
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||
"""List of maps that describe what changed in the API."""
|
||||
|
||||
def __init__(self, import_rename=False):
|
||||
def __init__(self, import_rename=False, upgrade_compat_v1_import=False):
|
||||
self.upgrade_compat_v1_import = upgrade_compat_v1_import
|
||||
|
||||
# Maps from a function name to a dictionary that describes how to
|
||||
# map from an old argument keyword to the new argument keyword.
|
||||
# If the new argument is None, it will be removed.
|
||||
@ -1612,10 +1638,21 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
|
||||
def preprocess(self, root_node):
|
||||
def preprocess(self, root_node, after_compat_v1_upgrade=False):
|
||||
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
|
||||
visitor.visit(root_node)
|
||||
detections = set(visitor.results)
|
||||
|
||||
# Upgrade explicit compat v1 imports if `upgrade_compat_v1_import` is
|
||||
# enabled. Then preprocess the updated root node.
|
||||
# We only do this upgrading once, because some forms of the import may
|
||||
# still cause errors but aren't trivially upgradeable, and we don't want
|
||||
# to enter an infinite loop. E.g. `from tensorflow.compat import v1, v2`.
|
||||
if (compat_v1_import in detections and self.upgrade_compat_v1_import and
|
||||
not after_compat_v1_upgrade):
|
||||
CompatV1ImportReplacer().visit(root_node)
|
||||
return self.preprocess(root_node, after_compat_v1_upgrade=True)
|
||||
|
||||
# If we have detected the presence of imports of specific TF versions,
|
||||
# We want to modify the update spec to check only module deprecations
|
||||
# and skip all other conversions.
|
||||
@ -1629,7 +1666,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
self.function_transformers = {}
|
||||
self.import_renames = {}
|
||||
return visitor.log, visitor.warnings_and_errors
|
||||
return root_node, visitor.log, visitor.warnings_and_errors
|
||||
|
||||
def clear_preprocessing(self):
|
||||
self.__init__()
|
||||
|
@ -101,7 +101,15 @@ Simple usage:
|
||||
parser.add_argument(
|
||||
"--no_import_rename",
|
||||
dest="no_import_rename",
|
||||
help=("Not to rename import to compact.v2 explicitly."),
|
||||
help=("Not to rename import to compat.v2 explicitly."),
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"--no_upgrade_compat_v1_import",
|
||||
dest="no_upgrade_compat_v1_import",
|
||||
help=("If specified, don't upgrade explicit imports of "
|
||||
"`tensorflow.compat.v1 as tf` to the v2 apis. Otherwise, "
|
||||
"explicit imports of the form `tensorflow.compat.v1 as tf` will "
|
||||
"be upgraded."),
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"--reportfile",
|
||||
@ -132,10 +140,13 @@ Simple usage:
|
||||
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
|
||||
else:
|
||||
if args.no_import_rename:
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec(import_rename=False)
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
|
||||
import_rename=False,
|
||||
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
|
||||
else:
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
|
||||
import_rename=_IMPORT_RENAME_DEFAULT)
|
||||
import_rename=_IMPORT_RENAME_DEFAULT,
|
||||
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
|
||||
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
|
||||
|
||||
report_text = None
|
||||
|
@ -117,11 +117,15 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
visitor.private_map["tf.compat"] = ["v1", "v2"]
|
||||
traverse.traverse(tf.compat.v1, visitor)
|
||||
|
||||
def _upgrade(self, old_file_text, import_rename=False):
|
||||
def _upgrade(self,
|
||||
old_file_text,
|
||||
import_rename=False,
|
||||
upgrade_compat_v1_import=False):
|
||||
in_file = six.StringIO(old_file_text)
|
||||
out_file = six.StringIO()
|
||||
upgrader = ast_edits.ASTCodeUpgrader(
|
||||
tf_upgrade_v2.TFAPIChangeSpec(import_rename))
|
||||
tf_upgrade_v2.TFAPIChangeSpec(
|
||||
import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import))
|
||||
count, report, errors = (
|
||||
upgrader.process_opened_file("test.py", in_file,
|
||||
"test_out.py", out_file))
|
||||
@ -2215,6 +2219,30 @@ def _log_prob(self, x):
|
||||
_, _, _, new_text = self._upgrade(text, import_rename=True)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = ("import tensorflow.compat.v1 as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
text = import_header + old_symbol
|
||||
expected_header = ("import tensorflow.compat.v2 as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
expected_text = expected_header + new_symbol
|
||||
_, _, _, new_text = self._upgrade(
|
||||
text, import_rename=True, upgrade_compat_v1_import=True)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = ("import tensorflow.compat.v1 as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
text = import_header + old_symbol
|
||||
expected_header = ("import tensorflow as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
expected_text = expected_header + new_symbol
|
||||
_, _, _, new_text = self._upgrade(
|
||||
text, import_rename=False, upgrade_compat_v1_import=True)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = "from tensorflow import foo\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol
|
||||
|
Loading…
Reference in New Issue
Block a user