Added an --upgrade_compat_v1_import flag to the upgrade script that allows it to upgrade import tensorflow.compat.v1 as tfimports to import tensorflow as tf imports. Note that this flag does not upgrade tf.compat.v1 when imported under other aliases, such as import tensorflow.compat.v1 as tfv1

PiperOrigin-RevId: 304323241
Change-Id: I0bbad98cb8642969a7ad9bb0bca7e1ebbe2620c2
This commit is contained in:
A. Unique TensorFlower 2020-04-01 21:28:01 -07:00 committed by TensorFlower Gardener
parent b3b6d2ab7b
commit cd84bc9820
4 changed files with 89 additions and 13 deletions

View File

@ -213,8 +213,8 @@ class APIChangeSpec(object):
"""
def preprocess(self, root_node): # pylint: disable=unused-argument
"""Preprocess a parse tree. Return any produced logs and errors."""
return [], []
"""Preprocess a parse tree. Return a preprocessed node, logs and errors."""
return root_node, [], []
def clear_preprocessing(self):
"""Restore this APIChangeSpec to before it preprocessed a file.
@ -942,7 +942,7 @@ class ASTCodeUpgrader(object):
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
return 0, "", log, []
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
visitor = _PastaEditVisitor(self._api_change_spec)
visitor.visit(t)

View File

@ -54,21 +54,47 @@ class VersionedTFImport(ast_edits.AnalysisResult):
"` was directly imported as `tf`.")
compat_v1_import = VersionedTFImport("compat.v1")
compat_v2_import = VersionedTFImport("compat.v2")
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
def __init__(self):
self.symbols_to_detect = {}
self.imports_to_detect = {
("tensorflow", None): UnaliasedTFImport(),
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
("tensorflow.compat.v1", "tf"): compat_v1_import,
("tensorflow.compat.v2", "tf"): compat_v2_import,
}
class CompatV1ImportReplacer(ast.NodeVisitor):
"""AST Visitor that replaces `import tensorflow.compat.v1 as tf`.
Converts `import tensorflow.compat.v1 as tf` to `import tensorflow as tf`
"""
def visit_Import(self, node): # pylint: disable=invalid-name
"""Handle visiting an import node in the AST.
Args:
node: Current Node
"""
for import_alias in node.names:
# Detect based on full import name and alias
if (import_alias.name == "tensorflow.compat.v1" and
import_alias.asname == "tf"):
import_alias.name = "tensorflow"
self.generic_visit(node)
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
"""List of maps that describe what changed in the API."""
def __init__(self, import_rename=False):
def __init__(self, import_rename=False, upgrade_compat_v1_import=False):
self.upgrade_compat_v1_import = upgrade_compat_v1_import
# Maps from a function name to a dictionary that describes how to
# map from an old argument keyword to the new argument keyword.
# If the new argument is None, it will be removed.
@ -1612,10 +1638,21 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
def preprocess(self, root_node):
def preprocess(self, root_node, after_compat_v1_upgrade=False):
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
visitor.visit(root_node)
detections = set(visitor.results)
# Upgrade explicit compat v1 imports if `upgrade_compat_v1_import` is
# enabled. Then preprocess the updated root node.
# We only do this upgrading once, because some forms of the import may
# still cause errors but aren't trivially upgradeable, and we don't want
# to enter an infinite loop. E.g. `from tensorflow.compat import v1, v2`.
if (compat_v1_import in detections and self.upgrade_compat_v1_import and
not after_compat_v1_upgrade):
CompatV1ImportReplacer().visit(root_node)
return self.preprocess(root_node, after_compat_v1_upgrade=True)
# If we have detected the presence of imports of specific TF versions,
# We want to modify the update spec to check only module deprecations
# and skip all other conversions.
@ -1629,7 +1666,7 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
self.function_transformers = {}
self.import_renames = {}
return visitor.log, visitor.warnings_and_errors
return root_node, visitor.log, visitor.warnings_and_errors
def clear_preprocessing(self):
self.__init__()

View File

@ -101,7 +101,15 @@ Simple usage:
parser.add_argument(
"--no_import_rename",
dest="no_import_rename",
help=("Not to rename import to compact.v2 explicitly."),
help=("Not to rename import to compat.v2 explicitly."),
action="store_true")
parser.add_argument(
"--no_upgrade_compat_v1_import",
dest="no_upgrade_compat_v1_import",
help=("If specified, don't upgrade explicit imports of "
"`tensorflow.compat.v1 as tf` to the v2 apis. Otherwise, "
"explicit imports of the form `tensorflow.compat.v1 as tf` will "
"be upgraded."),
action="store_true")
parser.add_argument(
"--reportfile",
@ -132,10 +140,13 @@ Simple usage:
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
else:
if args.no_import_rename:
change_spec = tf_upgrade_v2.TFAPIChangeSpec(import_rename=False)
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
import_rename=False,
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
else:
change_spec = tf_upgrade_v2.TFAPIChangeSpec(
import_rename=_IMPORT_RENAME_DEFAULT)
import_rename=_IMPORT_RENAME_DEFAULT,
upgrade_compat_v1_import=not args.no_upgrade_compat_v1_import)
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
report_text = None

View File

@ -117,11 +117,15 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
visitor.private_map["tf.compat"] = ["v1", "v2"]
traverse.traverse(tf.compat.v1, visitor)
def _upgrade(self, old_file_text, import_rename=False):
def _upgrade(self,
old_file_text,
import_rename=False,
upgrade_compat_v1_import=False):
in_file = six.StringIO(old_file_text)
out_file = six.StringIO()
upgrader = ast_edits.ASTCodeUpgrader(
tf_upgrade_v2.TFAPIChangeSpec(import_rename))
tf_upgrade_v2.TFAPIChangeSpec(
import_rename, upgrade_compat_v1_import=upgrade_compat_v1_import))
count, report, errors = (
upgrader.process_opened_file("test.py", in_file,
"test_out.py", out_file))
@ -2215,6 +2219,30 @@ def _log_prob(self, x):
_, _, _, new_text = self._upgrade(text, import_rename=True)
self.assertEqual(new_text, expected_text)
import_header = ("import tensorflow.compat.v1 as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
text = import_header + old_symbol
expected_header = ("import tensorflow.compat.v2 as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
expected_text = expected_header + new_symbol
_, _, _, new_text = self._upgrade(
text, import_rename=True, upgrade_compat_v1_import=True)
self.assertEqual(new_text, expected_text)
import_header = ("import tensorflow.compat.v1 as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
text = import_header + old_symbol
expected_header = ("import tensorflow as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
expected_text = expected_header + new_symbol
_, _, _, new_text = self._upgrade(
text, import_rename=False, upgrade_compat_v1_import=True)
self.assertEqual(new_text, expected_text)
import_header = "from tensorflow import foo\n"
text = import_header + old_symbol
expected_text = "from tensorflow.compat.v2 import foo\n" + new_symbol