Added an `--upgrade_compat_v1_import` flag to the upgrade script that allows it to upgrade `import tensorflow.compat.v1 as tf`imports to `import tensorflow as tf` imports. Note that this flag does not upgrade tf.compat.v1 when imported under other aliases, such as `import tensorflow.compat.v1 as tfv1`
PiperOrigin-RevId: 304323241 Change-Id: I0bbad98cb8642969a7ad9bb0bca7e1ebbe2620c2
This commit is contained in:
parent
b3b6d2ab7b
commit
cd84bc9820
|
@ -213,8 +213,8 @@ class APIChangeSpec(object):
|
|||
"""
|
||||
|
||||
def preprocess(self, root_node): # pylint: disable=unused-argument
|
||||
"""Preprocess a parse tree. Return any produced logs and errors."""
|
||||
return [], []
|
||||
"""Preprocess a parse tree. Return a preprocessed node, logs and errors."""
|
||||
return root_node, [], []
|
||||
|
||||
def clear_preprocessing(self):
|
||||
"""Restore this APIChangeSpec to before it preprocessed a file.
|
||||
|
@ -942,7 +942,7 @@ class ASTCodeUpgrader(object):
|
|||
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
|
||||
return 0, "", log, []
|
||||
|
||||
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
|
||||
t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
|
||||
|
||||
visitor = _PastaEditVisitor(self._api_change_spec)
|
||||
visitor.visit(t)
|
||||
|
|
|
@ -54,21 +54,47 @@ class VersionedTFImport(ast_edits.AnalysisResult):
|
|||
"` was directly imported as `tf`.")
|
||||
|
||||
|
||||
compat_v1_import = VersionedTFImport("compat.v1")
|
||||
compat_v2_import = VersionedTFImport("compat.v2")
|
||||
|
||||
|
||||
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.symbols_to_detect = {}
|
||||
self.imports_to_detect = {
|
||||
("tensorflow", None): UnaliasedTFImport(),
|
||||
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
|
||||
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
|
||||
("tensorflow.compat.v1", "tf"): compat_v1_import,
|
||||
("tensorflow.compat.v2", "tf"): compat_v2_import,
|
||||
}
|
||||
|
||||
|
||||
class CompatV1ImportReplacer(ast.NodeVisitor):
|
||||
"""AST Visitor that replaces `import tensorflow.compat.v1 as tf`.
|
||||
|
||||
Converts `import tensorflow.compat.v1 as tf` to `import tensorflow as tf`
|
||||
"""
|
||||
|
||||
def visit_Import(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting an import node in the AST.
|
||||
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
for import_alias in node.names:
|
||||
# Detect based on full import name and alias
|
||||
if (import_alias.name == "tensorflow.compat.v1" and
|
||||
import_alias.asname == "tf"):
|
||||
import_alias.name = "tensorflow"
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||
"""List of maps that describe what changed in the API."""
|
||||
|
||||
def __init__(self, import_rename=False):
|
||||
def __init__(self, import_rename=False, upgrade_compat_v1_import=False):
|
||||
self.upgrade_compat_v1_import = upgrade_compat_v1_import
|
||||
|
||||
# Maps from a function name to a dictionary that describes how to
|
||||
# map from an old argument keyword to the new argument keyword.
|
||||
# If the new argument is None, it will be removed.
|
||||
|
@ -1612,10 +1638,21 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
|||
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
|
||||
def preprocess(self, root_node):
|
||||
def preprocess(self, root_node, after_compat_v1_upgrade=False):
|
||||
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
|
||||
visitor.visit(root_node)
|
||||
detections = set(visitor.results)
|
||||
|
||||
# Upgrade explicit compat v1 imports if `upgrade_compat_v1_import` is
|
||||
# enabled. Then preprocess the updated root node.
|
||||
# We only do this upgrading once, because some forms of the import may
|
||||
# still cause errors but aren't trivially upgradeable, and we don't want
|
||||
# to enter an infinite loop. E.g. `from tensorflow.compat import v1, v2`.
|
||||
if (compat_v1_import in detections and self.upgrade_compat_v1_import and
|
||||
not after_compat_v1_upgrade):
|
||||
CompatV1ImportReplacer().visit(root_node)
|
||||
return self.preprocess(root_node, after_compat_v1_upgrade=True)
|
||||
|
||||
# If we have detected the presence of imports of specific TF versions,
|
||||
# We want to modify the update spec to check only module deprecations
|
||||
# and skip all other conversions.
|
||||
|
@ -1629,7 +1666,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
|||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
self.function_transformers = {}
|
||||
self.import_renames = {}
|
||||
return visitor.log, visitor.warnings_and_errors
|
||||
return root_node, visitor.log, visitor.warnings_and_errors
|
||||
|
||||
def clear_preprocessing(self):
|
||||
self.__init__()
|
||||
|
|
|
@ -101,7 +101,15 @@ Simple usage:
|
|||
parser.add_argument(
|
||||
"--no_import_rename",
|
||||
dest="no_import_rename",
|
||||
help=("Not to rename import to compact.v2 explicitly."),
|
||||
help=("Not to rename import to compat.v2 explicitly."),
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"--no_upgrade_compat_v1_import",
|
||||
dest="no_upgrade_compat_v1_import",
|
||||
help=("If specified, don't upgrade explicit imports of "
|
||||
"`tensorflow.compat.v1 as tf` to the v2 apis. Otherwise, "
|
||||
"explicit imports of the form `tensorflow.compat.v1 as tf` will "
|
||||
"be upgraded."),
|
||||
action="store_true")
|
||||
parser.add_argument(
|
||||
"--reportfile",
|
||||
|
@ -132,10 +140,13 @@ Simple usage:
|
|||
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
|
||||
else:
|
||||
if args.no_import_rename:
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec(import_rename=False)
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
|
||||
import_rename=False,
|
||||
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
|
||||
else:
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
|
||||
import_rename=_IMPORT_RENAME_DEFAULT)
|
||||
import_rename=_IMPORT_RENAME_DEFAULT,
|
||||
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
|
||||
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
|
||||
|
||||
report_text = None
|
||||
|
|
|
@ -117,11 +117,15 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
visitor.private_map["tf.compat"] = ["v1", "v2"]
|
||||
traverse.traverse(tf.compat.v1, visitor)
|
||||
|
||||
def _upgrade(self, old_file_text, import_rename=False):
|
||||
def _upgrade(self,
|
||||
old_file_text,
|
||||
import_rename=False,
|
||||
upgrade_compat_v1_import=False):
|
||||
in_file = six.StringIO(old_file_text)
|
||||
out_file = six.StringIO()
|
||||
upgrader = ast_edits.ASTCodeUpgrader(
|
||||
tf_upgrade_v2.TFAPIChangeSpec(import_rename))
|
||||
tf_upgrade_v2.TFAPIChangeSpec(
|
||||
import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import))
|
||||
count, report, errors = (
|
||||
upgrader.process_opened_file("test.py", in_file,
|
||||
"test_out.py", out_file))
|
||||
|
@ -2215,6 +2219,30 @@ def _log_prob(self, x):
|
|||
_, _, _, new_text = self._upgrade(text, import_rename=True)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = ("import tensorflow.compat.v1 as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
text = import_header + old_symbol
|
||||
expected_header = ("import tensorflow.compat.v2 as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
expected_text = expected_header + new_symbol
|
||||
_, _, _, new_text = self._upgrade(
|
||||
text, import_rename=True, upgrade_compat_v1_import=True)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = ("import tensorflow.compat.v1 as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
text = import_header + old_symbol
|
||||
expected_header = ("import tensorflow as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
expected_text = expected_header + new_symbol
|
||||
_, _, _, new_text = self._upgrade(
|
||||
text, import_rename=False, upgrade_compat_v1_import=True)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = "from tensorflow import foo\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol
|
||||
|
|
Loading…
Reference in New Issue