This CL allows the tf_upgrade_v2 script to avoid changing code when import tensorflow.compat.v1/v2 as tf
is used, and to raise an error when import tensorflow
is directly used without any tf
alias.
To do this, the CL support for a code analysis step before upgrading code. In the future this would allow correctly handling non-`tf` aliases of tensorflow, as well as detecting when additional imports need to be added. PiperOrigin-RevId: 249264134
This commit is contained in:
parent
269efd353a
commit
9d46f5599f
@ -146,6 +146,31 @@ class APIChangeSpec(object):
|
||||
For an example, see `TFAPIChangeSpec`.
|
||||
"""
|
||||
|
||||
def preprocess(self, root_node): # pylint: disable=unused-argument
|
||||
"""Preprocess a parse tree. Return any produced logs and errors."""
|
||||
return [], []
|
||||
|
||||
def clear_preprocessing(self):
|
||||
"""Restore this APIChangeSpec to before it preprocessed a file.
|
||||
|
||||
This is needed if preprocessing a file changed any rewriting rules.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NoUpdateSpec(APIChangeSpec):
|
||||
"""A specification of an API change which doesn't change anything."""
|
||||
|
||||
def __init__(self):
|
||||
self.function_handle = {}
|
||||
self.function_reorders = {}
|
||||
self.function_keyword_renames = {}
|
||||
self.symbol_renames = {}
|
||||
self.function_warnings = {}
|
||||
self.change_to_function = {}
|
||||
self.module_deprecations = {}
|
||||
self.import_renames = {}
|
||||
|
||||
|
||||
class _PastaEditVisitor(ast.NodeVisitor):
|
||||
"""AST Visitor that processes function calls.
|
||||
@ -618,6 +643,112 @@ class _PastaEditVisitor(ast.NodeVisitor):
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
class AnalysisResult(object):
|
||||
"""This class represents an analysis result and how it should be logged.
|
||||
|
||||
This class must provide the following fields:
|
||||
|
||||
* `log_level`: The log level to which this detection should be logged
|
||||
* `log_message`: The message that should be logged for this detection
|
||||
|
||||
For an example, see `VersionedTFImport`.
|
||||
"""
|
||||
|
||||
|
||||
class APIAnalysisSpec(object):
|
||||
"""This class defines how `AnalysisResult`s should be generated.
|
||||
|
||||
It specifies how to map imports and symbols to `AnalysisResult`s.
|
||||
|
||||
This class must provide the following fields:
|
||||
|
||||
* `symbols_to_detect`: maps function names to `AnalysisResult`s
|
||||
* `imports_to_detect`: maps imports represented as (full module name, alias)
|
||||
tuples to `AnalysisResult`s
|
||||
notifications)
|
||||
|
||||
For an example, see `TFAPIImportAnalysisSpec`.
|
||||
"""
|
||||
|
||||
|
||||
class PastaAnalyzeVisitor(_PastaEditVisitor):
|
||||
"""AST Visitor that looks for specific API usage without editing anything.
|
||||
|
||||
This is used before any rewriting is done to detect if any symbols are used
|
||||
that require changing imports or disabling rewriting altogether.
|
||||
"""
|
||||
|
||||
def __init__(self, api_analysis_spec):
|
||||
super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
|
||||
self._api_analysis_spec = api_analysis_spec
|
||||
self._results = [] # Holds AnalysisResult objects
|
||||
|
||||
@property
|
||||
def results(self):
|
||||
return self._results
|
||||
|
||||
def add_result(self, analysis_result):
|
||||
self._results.append(analysis_result)
|
||||
|
||||
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
||||
"""Handle bare Attributes i.e. [tf.foo, tf.bar]."""
|
||||
full_name = self._get_full_name(node)
|
||||
if full_name:
|
||||
detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None)
|
||||
if detection:
|
||||
self.add_result(detection)
|
||||
self.add_log(
|
||||
detection.log_level, node.lineno, node.col_offset,
|
||||
detection.log_message)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Import(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting an import node in the AST.
|
||||
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
for import_alias in node.names:
|
||||
# Detect based on full import name and alias)
|
||||
full_import = (import_alias.name, import_alias.asname)
|
||||
detection = (self._api_analysis_spec
|
||||
.imports_to_detect.get(full_import, None))
|
||||
if detection:
|
||||
self.add_result(detection)
|
||||
self.add_log(
|
||||
detection.log_level, node.lineno, node.col_offset,
|
||||
detection.log_message)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
|
||||
"""Handle visiting an import-from node in the AST.
|
||||
|
||||
Args:
|
||||
node: Current Node
|
||||
"""
|
||||
if not node.module:
|
||||
self.generic_visit(node)
|
||||
return
|
||||
|
||||
from_import = node.module
|
||||
|
||||
for import_alias in node.names:
|
||||
# Detect based on full import name(to & as)
|
||||
full_module_name = "%s.%s" % (from_import, import_alias.name)
|
||||
full_import = (full_module_name, import_alias.asname)
|
||||
detection = (self._api_analysis_spec
|
||||
.imports_to_detect.get(full_import, None))
|
||||
if detection:
|
||||
self.add_result(detection)
|
||||
self.add_log(
|
||||
detection.log_level, node.lineno, node.col_offset,
|
||||
detection.log_message)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
class ASTCodeUpgrader(object):
|
||||
"""Handles upgrading a set of Python files using a given API change spec."""
|
||||
|
||||
@ -663,12 +794,18 @@ class ASTCodeUpgrader(object):
|
||||
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
|
||||
return 0, "", log, []
|
||||
|
||||
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
|
||||
|
||||
visitor = _PastaEditVisitor(self._api_change_spec)
|
||||
visitor.visit(t)
|
||||
|
||||
logs = [self.format_log(log, None) for log in visitor.log]
|
||||
self._api_change_spec.clear_preprocessing()
|
||||
|
||||
logs = [self.format_log(log, None) for log in (preprocess_logs +
|
||||
visitor.log)]
|
||||
errors = [self.format_log(error, in_filename)
|
||||
for error in visitor.warnings_and_errors]
|
||||
for error in (preprocess_errors +
|
||||
visitor.warnings_and_errors)]
|
||||
return 1, pasta.dump(t), logs, errors
|
||||
|
||||
def _format_log(self, log, in_filename, out_filename):
|
||||
|
@ -52,29 +52,15 @@ from tensorflow.python.platform import test as test_lib
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
|
||||
|
||||
class NoUpdateSpec(ast_edits.APIChangeSpec):
|
||||
"""A specification of an API change which doesn't change anything."""
|
||||
|
||||
def __init__(self):
|
||||
self.function_handle = {}
|
||||
self.function_reorders = {}
|
||||
self.function_keyword_renames = {}
|
||||
self.symbol_renames = {}
|
||||
self.function_warnings = {}
|
||||
self.change_to_function = {}
|
||||
self.module_deprecations = {}
|
||||
self.import_renames = {}
|
||||
|
||||
|
||||
class ModuleDeprecationSpec(NoUpdateSpec):
|
||||
class ModuleDeprecationSpec(ast_edits.NoUpdateSpec):
|
||||
"""A specification which deprecates 'a.b'."""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.module_deprecations.update({"a.b": (ast_edits.ERROR, "a.b is evil.")})
|
||||
|
||||
|
||||
class RenameKeywordSpec(NoUpdateSpec):
|
||||
class RenameKeywordSpec(ast_edits.NoUpdateSpec):
|
||||
"""A specification where kw2 gets renamed to kw3.
|
||||
|
||||
The new API is
|
||||
@ -84,14 +70,14 @@ class RenameKeywordSpec(NoUpdateSpec):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.update_renames()
|
||||
|
||||
def update_renames(self):
|
||||
self.function_keyword_renames["f"] = {"kw2": "kw3"}
|
||||
|
||||
|
||||
class ReorderKeywordSpec(NoUpdateSpec):
|
||||
class ReorderKeywordSpec(ast_edits.NoUpdateSpec):
|
||||
"""A specification where kw2 gets moved in front of kw1.
|
||||
|
||||
The new API is
|
||||
@ -101,7 +87,7 @@ class ReorderKeywordSpec(NoUpdateSpec):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.update_reorders()
|
||||
|
||||
def update_reorders(self):
|
||||
@ -125,7 +111,7 @@ class ReorderAndRenameKeywordSpec(ReorderKeywordSpec, RenameKeywordSpec):
|
||||
self.update_reorders()
|
||||
|
||||
|
||||
class RemoveDeprecatedAliasKeyword(NoUpdateSpec):
|
||||
class RemoveDeprecatedAliasKeyword(ast_edits.NoUpdateSpec):
|
||||
"""A specification where kw1_alias is removed in g.
|
||||
|
||||
The new API is
|
||||
@ -136,7 +122,7 @@ class RemoveDeprecatedAliasKeyword(NoUpdateSpec):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.function_keyword_renames["g"] = {"kw1_alias": "kw1"}
|
||||
self.function_keyword_renames["g2"] = {"kw1_alias": "kw1"}
|
||||
|
||||
@ -158,7 +144,7 @@ class RemoveDeprecatedAliasAndReorderRest(RemoveDeprecatedAliasKeyword):
|
||||
self.function_reorders["g2"] = ["a", "b", "kw1", "c", "d"]
|
||||
|
||||
|
||||
class RemoveMultipleKeywordArguments(NoUpdateSpec):
|
||||
class RemoveMultipleKeywordArguments(ast_edits.NoUpdateSpec):
|
||||
"""A specification where both keyword aliases are removed from h.
|
||||
|
||||
The new API is
|
||||
@ -168,18 +154,18 @@ class RemoveMultipleKeywordArguments(NoUpdateSpec):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.function_keyword_renames["h"] = {
|
||||
"kw1_alias": "kw1",
|
||||
"kw2_alias": "kw2",
|
||||
}
|
||||
|
||||
|
||||
class RenameImports(NoUpdateSpec):
|
||||
class RenameImports(ast_edits.NoUpdateSpec):
|
||||
"""Specification for renaming imports."""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.import_renames = {
|
||||
"foo": ast_edits.ImportRename(
|
||||
"bar",
|
||||
@ -209,11 +195,11 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
||||
|
||||
def testNoTransformIfNothingIsSupplied(self):
|
||||
text = "f(a, b, kw1=c, kw2=d)\n"
|
||||
_, new_text = self._upgrade(NoUpdateSpec(), text)
|
||||
_, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
|
||||
self.assertEqual(new_text, text)
|
||||
|
||||
text = "f(a, b, c, d)\n"
|
||||
_, new_text = self._upgrade(NoUpdateSpec(), text)
|
||||
_, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
|
||||
self.assertEqual(new_text, text)
|
||||
|
||||
def testKeywordRename(self):
|
||||
@ -447,11 +433,11 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
||||
self.assertIn(new_text, acceptable_outputs)
|
||||
|
||||
def testUnrestrictedFunctionWarnings(self):
|
||||
class FooWarningSpec(NoUpdateSpec):
|
||||
class FooWarningSpec(ast_edits.NoUpdateSpec):
|
||||
"""Usages of function attribute foo() prints out a warning."""
|
||||
|
||||
def __init__(self):
|
||||
NoUpdateSpec.__init__(self)
|
||||
ast_edits.NoUpdateSpec.__init__(self)
|
||||
self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}
|
||||
|
||||
texts = ["object.foo()", "get_object().foo()",
|
||||
|
@ -34,7 +34,35 @@ from tensorflow.tools.compatibility import reorders_v2
|
||||
# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison
|
||||
|
||||
|
||||
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
class UnaliasedTFImport(ast_edits.AnalysisResult):
|
||||
|
||||
def __init__(self):
|
||||
self.log_level = ast_edits.ERROR
|
||||
self.log_message = ("The tf_upgrade_v2 script detected an unaliased "
|
||||
"`import tensorflow`. The script can only run when "
|
||||
"importing with `import tensorflow as tf`.")
|
||||
|
||||
|
||||
class VersionedTFImport(ast_edits.AnalysisResult):
|
||||
|
||||
def __init__(self, version):
|
||||
self.log_level = ast_edits.INFO
|
||||
self.log_message = ("Not upgrading symbols because `tensorflow." + version
|
||||
+ "` was directly imported as `tf`.")
|
||||
|
||||
|
||||
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.symbols_to_detect = {}
|
||||
self.imports_to_detect = {
|
||||
("tensorflow", None): UnaliasedTFImport(),
|
||||
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
|
||||
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
|
||||
}
|
||||
|
||||
|
||||
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||
"""List of maps that describe what changed in the API."""
|
||||
|
||||
def __init__(self):
|
||||
@ -481,6 +509,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
# Add additional renames not in renames_v2.py to all_renames_v2.py.
|
||||
self.symbol_renames = all_renames_v2.symbol_renames
|
||||
|
||||
self.import_renames = {}
|
||||
|
||||
# Variables that should be changed to functions.
|
||||
self.change_to_function = {}
|
||||
|
||||
@ -771,7 +801,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"extended.call_for_each_replica->experimental_run_v2, "
|
||||
"reduce requires an axis argument, "
|
||||
"unwrap->experimental_local_results "
|
||||
"experimental_initialize and experimenta_finalize no longer needed ")
|
||||
"experimental_initialize and experimental_finalize no longer needed ")
|
||||
|
||||
contrib_mirrored_strategy_warning = (
|
||||
ast_edits.ERROR,
|
||||
@ -1474,6 +1504,27 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
|
||||
def preprocess(self, root_node):
|
||||
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
|
||||
visitor.visit(root_node)
|
||||
detections = set(visitor.results)
|
||||
# If we have detected the presence of imports of specific TF versions,
|
||||
# We want to modify the update spec to check only module deprecations
|
||||
# and skip all other conversions.
|
||||
if detections:
|
||||
self.function_handle = {}
|
||||
self.function_reorders = {}
|
||||
self.function_keyword_renames = {}
|
||||
self.symbol_renames = {}
|
||||
self.function_warnings = {}
|
||||
self.change_to_function = {}
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
self.import_renames = {}
|
||||
return visitor.log, visitor.warnings_and_errors
|
||||
|
||||
def clear_preprocessing(self):
|
||||
self.__init__()
|
||||
|
||||
|
||||
def _is_ast_str(node):
|
||||
"""Determine whether this node represents a string."""
|
||||
|
@ -123,6 +123,18 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
"test_out.py", out_file))
|
||||
return count, report, errors, out_file.getvalue()
|
||||
|
||||
def _upgrade_multiple(self, old_file_texts):
|
||||
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
|
||||
results = []
|
||||
for old_file_text in old_file_texts:
|
||||
in_file = six.StringIO(old_file_text)
|
||||
out_file = six.StringIO()
|
||||
count, report, errors = (
|
||||
upgrader.process_opened_file("test.py", in_file,
|
||||
"test_out.py", out_file))
|
||||
results.append([count, report, errors, out_file.getvalue()])
|
||||
return results
|
||||
|
||||
def testParseError(self):
|
||||
_, report, unused_errors, unused_new_text = self._upgrade(
|
||||
"import tensorflow as tf\na + \n")
|
||||
@ -1984,6 +1996,96 @@ def _log_prob(self, x):
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def test_import_analysis(self):
|
||||
old_symbol = "tf.conj(a)"
|
||||
new_symbol = "tf.math.conj(a)"
|
||||
|
||||
# We upgrade the base un-versioned tensorflow aliased as tf
|
||||
import_header = "import tensorflow as tf\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + new_symbol
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
import_header = ("import tensorflow as tf\n"
|
||||
"import tensorflow.compat.v1 as tf_v1\n"
|
||||
"import tensorflow.compat.v2 as tf_v2\n")
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + new_symbol
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
# We don't handle unaliased tensorflow imports currently,
|
||||
# So the upgrade script show log errors
|
||||
import_header = "import tensorflow\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + old_symbol
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn("unaliased `import tensorflow`", "\n".join(errors))
|
||||
|
||||
# Upgrading explicitly-versioned tf code is unsafe, but we don't
|
||||
# need to throw errors when we detect explicitly-versioned tf.
|
||||
import_header = "import tensorflow.compat.v1 as tf\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + old_symbol
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
|
||||
report)
|
||||
self.assertEmpty(errors)
|
||||
|
||||
import_header = "from tensorflow.compat import v1 as tf\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + old_symbol
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
|
||||
report)
|
||||
self.assertEmpty(errors)
|
||||
|
||||
import_header = "from tensorflow.compat import v1 as tf, v2 as tf2\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + old_symbol
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
|
||||
report)
|
||||
self.assertEmpty(errors)
|
||||
|
||||
import_header = "import tensorflow.compat.v2 as tf\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + old_symbol
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
|
||||
report)
|
||||
self.assertEmpty(errors)
|
||||
|
||||
import_header = "from tensorflow.compat import v1 as tf1, v2 as tf\n"
|
||||
text = import_header + old_symbol
|
||||
expected_text = import_header + old_symbol
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
|
||||
report)
|
||||
self.assertEmpty(errors)
|
||||
|
||||
def test_api_spec_reset_between_files(self):
|
||||
old_symbol = "tf.conj(a)"
|
||||
new_symbol = "tf.math.conj(a)"
|
||||
|
||||
## Test that the api spec is reset in between files:
|
||||
import_header = "import tensorflow.compat.v2 as tf\n"
|
||||
text_a = import_header + old_symbol
|
||||
expected_text_a = import_header + old_symbol
|
||||
text_b = old_symbol
|
||||
expected_text_b = new_symbol
|
||||
results = self._upgrade_multiple([text_a, text_b])
|
||||
result_a, result_b = results[0], results[1]
|
||||
self.assertEqual(result_a[3], expected_text_a)
|
||||
self.assertEqual(result_b[3], expected_text_b)
|
||||
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user