This CL allows the tf_upgrade_v2 script to avoid changing code when import tensorflow.compat.v1/v2 as tf is used, and to raise an error when import tensorflow is directly used without any tf alias.

To do this, the CL support for a code analysis step before upgrading code.

In the future this would allow correctly handling non-`tf` aliases of tensorflow, as well as detecting when additional imports need to be added.

PiperOrigin-RevId: 249264134
This commit is contained in:
A. Unique TensorFlower 2019-05-21 09:22:08 -07:00 committed by TensorFlower Gardener
parent 269efd353a
commit 9d46f5599f
4 changed files with 310 additions and 34 deletions

View File

@ -146,6 +146,31 @@ class APIChangeSpec(object):
For an example, see `TFAPIChangeSpec`.
"""
def preprocess(self, root_node): # pylint: disable=unused-argument
"""Preprocess a parse tree. Return any produced logs and errors."""
return [], []
def clear_preprocessing(self):
"""Restore this APIChangeSpec to before it preprocessed a file.
This is needed if preprocessing a file changed any rewriting rules.
"""
pass
class NoUpdateSpec(APIChangeSpec):
"""A specification of an API change which doesn't change anything."""
def __init__(self):
self.function_handle = {}
self.function_reorders = {}
self.function_keyword_renames = {}
self.symbol_renames = {}
self.function_warnings = {}
self.change_to_function = {}
self.module_deprecations = {}
self.import_renames = {}
class _PastaEditVisitor(ast.NodeVisitor):
"""AST Visitor that processes function calls.
@ -618,6 +643,112 @@ class _PastaEditVisitor(ast.NodeVisitor):
self.generic_visit(node)
class AnalysisResult(object):
"""This class represents an analysis result and how it should be logged.
This class must provide the following fields:
* `log_level`: The log level to which this detection should be logged
* `log_message`: The message that should be logged for this detection
For an example, see `VersionedTFImport`.
"""
class APIAnalysisSpec(object):
"""This class defines how `AnalysisResult`s should be generated.
It specifies how to map imports and symbols to `AnalysisResult`s.
This class must provide the following fields:
* `symbols_to_detect`: maps function names to `AnalysisResult`s
* `imports_to_detect`: maps imports represented as (full module name, alias)
tuples to `AnalysisResult`s
notifications)
For an example, see `TFAPIImportAnalysisSpec`.
"""
class PastaAnalyzeVisitor(_PastaEditVisitor):
"""AST Visitor that looks for specific API usage without editing anything.
This is used before any rewriting is done to detect if any symbols are used
that require changing imports or disabling rewriting altogether.
"""
def __init__(self, api_analysis_spec):
super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
self._api_analysis_spec = api_analysis_spec
self._results = [] # Holds AnalysisResult objects
@property
def results(self):
return self._results
def add_result(self, analysis_result):
self._results.append(analysis_result)
def visit_Attribute(self, node): # pylint: disable=invalid-name
"""Handle bare Attributes i.e. [tf.foo, tf.bar]."""
full_name = self._get_full_name(node)
if full_name:
detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None)
if detection:
self.add_result(detection)
self.add_log(
detection.log_level, node.lineno, node.col_offset,
detection.log_message)
self.generic_visit(node)
def visit_Import(self, node): # pylint: disable=invalid-name
"""Handle visiting an import node in the AST.
Args:
node: Current Node
"""
for import_alias in node.names:
# Detect based on full import name and alias)
full_import = (import_alias.name, import_alias.asname)
detection = (self._api_analysis_spec
.imports_to_detect.get(full_import, None))
if detection:
self.add_result(detection)
self.add_log(
detection.log_level, node.lineno, node.col_offset,
detection.log_message)
self.generic_visit(node)
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
"""Handle visiting an import-from node in the AST.
Args:
node: Current Node
"""
if not node.module:
self.generic_visit(node)
return
from_import = node.module
for import_alias in node.names:
# Detect based on full import name(to & as)
full_module_name = "%s.%s" % (from_import, import_alias.name)
full_import = (full_module_name, import_alias.asname)
detection = (self._api_analysis_spec
.imports_to_detect.get(full_import, None))
if detection:
self.add_result(detection)
self.add_log(
detection.log_level, node.lineno, node.col_offset,
detection.log_message)
self.generic_visit(node)
class ASTCodeUpgrader(object):
"""Handles upgrading a set of Python files using a given API change spec."""
@ -663,12 +794,18 @@ class ASTCodeUpgrader(object):
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
return 0, "", log, []
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
visitor = _PastaEditVisitor(self._api_change_spec)
visitor.visit(t)
logs = [self.format_log(log, None) for log in visitor.log]
self._api_change_spec.clear_preprocessing()
logs = [self.format_log(log, None) for log in (preprocess_logs +
visitor.log)]
errors = [self.format_log(error, in_filename)
for error in visitor.warnings_and_errors]
for error in (preprocess_errors +
visitor.warnings_and_errors)]
return 1, pasta.dump(t), logs, errors
def _format_log(self, log, in_filename, out_filename):

View File

@ -52,29 +52,15 @@ from tensorflow.python.platform import test as test_lib
from tensorflow.tools.compatibility import ast_edits
class NoUpdateSpec(ast_edits.APIChangeSpec):
"""A specification of an API change which doesn't change anything."""
def __init__(self):
self.function_handle = {}
self.function_reorders = {}
self.function_keyword_renames = {}
self.symbol_renames = {}
self.function_warnings = {}
self.change_to_function = {}
self.module_deprecations = {}
self.import_renames = {}
class ModuleDeprecationSpec(NoUpdateSpec):
class ModuleDeprecationSpec(ast_edits.NoUpdateSpec):
"""A specification which deprecates 'a.b'."""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.module_deprecations.update({"a.b": (ast_edits.ERROR, "a.b is evil.")})
class RenameKeywordSpec(NoUpdateSpec):
class RenameKeywordSpec(ast_edits.NoUpdateSpec):
"""A specification where kw2 gets renamed to kw3.
The new API is
@ -84,14 +70,14 @@ class RenameKeywordSpec(NoUpdateSpec):
"""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.update_renames()
def update_renames(self):
self.function_keyword_renames["f"] = {"kw2": "kw3"}
class ReorderKeywordSpec(NoUpdateSpec):
class ReorderKeywordSpec(ast_edits.NoUpdateSpec):
"""A specification where kw2 gets moved in front of kw1.
The new API is
@ -101,7 +87,7 @@ class ReorderKeywordSpec(NoUpdateSpec):
"""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.update_reorders()
def update_reorders(self):
@ -125,7 +111,7 @@ class ReorderAndRenameKeywordSpec(ReorderKeywordSpec, RenameKeywordSpec):
self.update_reorders()
class RemoveDeprecatedAliasKeyword(NoUpdateSpec):
class RemoveDeprecatedAliasKeyword(ast_edits.NoUpdateSpec):
"""A specification where kw1_alias is removed in g.
The new API is
@ -136,7 +122,7 @@ class RemoveDeprecatedAliasKeyword(NoUpdateSpec):
"""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.function_keyword_renames["g"] = {"kw1_alias": "kw1"}
self.function_keyword_renames["g2"] = {"kw1_alias": "kw1"}
@ -158,7 +144,7 @@ class RemoveDeprecatedAliasAndReorderRest(RemoveDeprecatedAliasKeyword):
self.function_reorders["g2"] = ["a", "b", "kw1", "c", "d"]
class RemoveMultipleKeywordArguments(NoUpdateSpec):
class RemoveMultipleKeywordArguments(ast_edits.NoUpdateSpec):
"""A specification where both keyword aliases are removed from h.
The new API is
@ -168,18 +154,18 @@ class RemoveMultipleKeywordArguments(NoUpdateSpec):
"""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.function_keyword_renames["h"] = {
"kw1_alias": "kw1",
"kw2_alias": "kw2",
}
class RenameImports(NoUpdateSpec):
class RenameImports(ast_edits.NoUpdateSpec):
"""Specification for renaming imports."""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.import_renames = {
"foo": ast_edits.ImportRename(
"bar",
@ -209,11 +195,11 @@ class TestAstEdits(test_util.TensorFlowTestCase):
def testNoTransformIfNothingIsSupplied(self):
text = "f(a, b, kw1=c, kw2=d)\n"
_, new_text = self._upgrade(NoUpdateSpec(), text)
_, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
self.assertEqual(new_text, text)
text = "f(a, b, c, d)\n"
_, new_text = self._upgrade(NoUpdateSpec(), text)
_, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
self.assertEqual(new_text, text)
def testKeywordRename(self):
@ -447,11 +433,11 @@ class TestAstEdits(test_util.TensorFlowTestCase):
self.assertIn(new_text, acceptable_outputs)
def testUnrestrictedFunctionWarnings(self):
class FooWarningSpec(NoUpdateSpec):
class FooWarningSpec(ast_edits.NoUpdateSpec):
"""Usages of function attribute foo() prints out a warning."""
def __init__(self):
NoUpdateSpec.__init__(self)
ast_edits.NoUpdateSpec.__init__(self)
self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}
texts = ["object.foo()", "get_object().foo()",

View File

@ -34,7 +34,35 @@ from tensorflow.tools.compatibility import reorders_v2
# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
class UnaliasedTFImport(ast_edits.AnalysisResult):
def __init__(self):
self.log_level = ast_edits.ERROR
self.log_message = ("The tf_upgrade_v2 script detected an unaliased "
"`import tensorflow`. The script can only run when "
"importing with `import tensorflow as tf`.")
class VersionedTFImport(ast_edits.AnalysisResult):
def __init__(self, version):
self.log_level = ast_edits.INFO
self.log_message = ("Not upgrading symbols because `tensorflow." + version
+ "` was directly imported as `tf`.")
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
def __init__(self):
self.symbols_to_detect = {}
self.imports_to_detect = {
("tensorflow", None): UnaliasedTFImport(),
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
}
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
"""List of maps that describe what changed in the API."""
def __init__(self):
@ -481,6 +509,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Add additional renames not in renames_v2.py to all_renames_v2.py.
self.symbol_renames = all_renames_v2.symbol_renames
self.import_renames = {}
# Variables that should be changed to functions.
self.change_to_function = {}
@ -771,7 +801,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"extended.call_for_each_replica->experimental_run_v2, "
"reduce requires an axis argument, "
"unwrap->experimental_local_results "
"experimental_initialize and experimenta_finalize no longer needed ")
"experimental_initialize and experimental_finalize no longer needed ")
contrib_mirrored_strategy_warning = (
ast_edits.ERROR,
@ -1474,6 +1504,27 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
def preprocess(self, root_node):
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
visitor.visit(root_node)
detections = set(visitor.results)
# If we have detected the presence of imports of specific TF versions,
# We want to modify the update spec to check only module deprecations
# and skip all other conversions.
if detections:
self.function_handle = {}
self.function_reorders = {}
self.function_keyword_renames = {}
self.symbol_renames = {}
self.function_warnings = {}
self.change_to_function = {}
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
self.import_renames = {}
return visitor.log, visitor.warnings_and_errors
def clear_preprocessing(self):
self.__init__()
def _is_ast_str(node):
"""Determine whether this node represents a string."""

View File

@ -123,6 +123,18 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
"test_out.py", out_file))
return count, report, errors, out_file.getvalue()
def _upgrade_multiple(self, old_file_texts):
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
results = []
for old_file_text in old_file_texts:
in_file = six.StringIO(old_file_text)
out_file = six.StringIO()
count, report, errors = (
upgrader.process_opened_file("test.py", in_file,
"test_out.py", out_file))
results.append([count, report, errors, out_file.getvalue()])
return results
def testParseError(self):
_, report, unused_errors, unused_new_text = self._upgrade(
"import tensorflow as tf\na + \n")
@ -1984,6 +1996,96 @@ def _log_prob(self, x):
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
def test_import_analysis(self):
old_symbol = "tf.conj(a)"
new_symbol = "tf.math.conj(a)"
# We upgrade the base un-versioned tensorflow aliased as tf
import_header = "import tensorflow as tf\n"
text = import_header + old_symbol
expected_text = import_header + new_symbol
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
import_header = ("import tensorflow as tf\n"
"import tensorflow.compat.v1 as tf_v1\n"
"import tensorflow.compat.v2 as tf_v2\n")
text = import_header + old_symbol
expected_text = import_header + new_symbol
_, _, _, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
# We don't handle unaliased tensorflow imports currently,
# So the upgrade script show log errors
import_header = "import tensorflow\n"
text = import_header + old_symbol
expected_text = import_header + old_symbol
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
self.assertIn("unaliased `import tensorflow`", "\n".join(errors))
# Upgrading explicitly-versioned tf code is unsafe, but we don't
# need to throw errors when we detect explicitly-versioned tf.
import_header = "import tensorflow.compat.v1 as tf\n"
text = import_header + old_symbol
expected_text = import_header + old_symbol
_, report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
report)
self.assertEmpty(errors)
import_header = "from tensorflow.compat import v1 as tf\n"
text = import_header + old_symbol
expected_text = import_header + old_symbol
_, report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
report)
self.assertEmpty(errors)
import_header = "from tensorflow.compat import v1 as tf, v2 as tf2\n"
text = import_header + old_symbol
expected_text = import_header + old_symbol
_, report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
report)
self.assertEmpty(errors)
import_header = "import tensorflow.compat.v2 as tf\n"
text = import_header + old_symbol
expected_text = import_header + old_symbol
_, report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
report)
self.assertEmpty(errors)
import_header = "from tensorflow.compat import v1 as tf1, v2 as tf\n"
text = import_header + old_symbol
expected_text = import_header + old_symbol
_, report, errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
report)
self.assertEmpty(errors)
def test_api_spec_reset_between_files(self):
old_symbol = "tf.conj(a)"
new_symbol = "tf.math.conj(a)"
## Test that the api spec is reset in between files:
import_header = "import tensorflow.compat.v2 as tf\n"
text_a = import_header + old_symbol
expected_text_a = import_header + old_symbol
text_b = old_symbol
expected_text_b = new_symbol
results = self._upgrade_multiple([text_a, text_b])
result_a, result_b = results[0], results[1]
self.assertEqual(result_a[3], expected_text_a)
self.assertEqual(result_b[3], expected_text_b)
class TestUpgradeFiles(test_util.TensorFlowTestCase):