From 9d46f5599fc9256daa8e7ff63965289c56457019 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 21 May 2019 09:22:08 -0700 Subject: [PATCH] This CL allows the tf_upgrade_v2 script to avoid changing code when `import tensorflow.compat.v1/v2 as tf` is used, and to raise an error when `import tensorflow` is directly used without any `tf` alias. To do this, the CL support for a code analysis step before upgrading code. In the future this would allow correctly handling non-`tf` aliases of tensorflow, as well as detecting when additional imports need to be added. PiperOrigin-RevId: 249264134 --- tensorflow/tools/compatibility/ast_edits.py | 141 +++++++++++++++++- .../tools/compatibility/ast_edits_test.py | 46 ++---- .../tools/compatibility/tf_upgrade_v2.py | 55 ++++++- .../tools/compatibility/tf_upgrade_v2_test.py | 102 +++++++++++++ 4 files changed, 310 insertions(+), 34 deletions(-) diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py index 05b26d88b7d..d4dce186e0a 100644 --- a/tensorflow/tools/compatibility/ast_edits.py +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -146,6 +146,31 @@ class APIChangeSpec(object): For an example, see `TFAPIChangeSpec`. """ + def preprocess(self, root_node): # pylint: disable=unused-argument + """Preprocess a parse tree. Return any produced logs and errors.""" + return [], [] + + def clear_preprocessing(self): + """Restore this APIChangeSpec to before it preprocessed a file. + + This is needed if preprocessing a file changed any rewriting rules. + """ + pass + + +class NoUpdateSpec(APIChangeSpec): + """A specification of an API change which doesn't change anything.""" + + def __init__(self): + self.function_handle = {} + self.function_reorders = {} + self.function_keyword_renames = {} + self.symbol_renames = {} + self.function_warnings = {} + self.change_to_function = {} + self.module_deprecations = {} + self.import_renames = {} + class _PastaEditVisitor(ast.NodeVisitor): """AST Visitor that processes function calls. @@ -618,6 +643,112 @@ class _PastaEditVisitor(ast.NodeVisitor): self.generic_visit(node) +class AnalysisResult(object): + """This class represents an analysis result and how it should be logged. + + This class must provide the following fields: + + * `log_level`: The log level to which this detection should be logged + * `log_message`: The message that should be logged for this detection + + For an example, see `VersionedTFImport`. + """ + + +class APIAnalysisSpec(object): + """This class defines how `AnalysisResult`s should be generated. + + It specifies how to map imports and symbols to `AnalysisResult`s. + + This class must provide the following fields: + + * `symbols_to_detect`: maps function names to `AnalysisResult`s + * `imports_to_detect`: maps imports represented as (full module name, alias) + tuples to `AnalysisResult`s + notifications) + + For an example, see `TFAPIImportAnalysisSpec`. + """ + + +class PastaAnalyzeVisitor(_PastaEditVisitor): + """AST Visitor that looks for specific API usage without editing anything. + + This is used before any rewriting is done to detect if any symbols are used + that require changing imports or disabling rewriting altogether. + """ + + def __init__(self, api_analysis_spec): + super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec()) + self._api_analysis_spec = api_analysis_spec + self._results = [] # Holds AnalysisResult objects + + @property + def results(self): + return self._results + + def add_result(self, analysis_result): + self._results.append(analysis_result) + + def visit_Attribute(self, node): # pylint: disable=invalid-name + """Handle bare Attributes i.e. [tf.foo, tf.bar].""" + full_name = self._get_full_name(node) + if full_name: + detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None) + if detection: + self.add_result(detection) + self.add_log( + detection.log_level, node.lineno, node.col_offset, + detection.log_message) + + self.generic_visit(node) + + def visit_Import(self, node): # pylint: disable=invalid-name + """Handle visiting an import node in the AST. + + Args: + node: Current Node + """ + for import_alias in node.names: + # Detect based on full import name and alias) + full_import = (import_alias.name, import_alias.asname) + detection = (self._api_analysis_spec + .imports_to_detect.get(full_import, None)) + if detection: + self.add_result(detection) + self.add_log( + detection.log_level, node.lineno, node.col_offset, + detection.log_message) + + self.generic_visit(node) + + def visit_ImportFrom(self, node): # pylint: disable=invalid-name + """Handle visiting an import-from node in the AST. + + Args: + node: Current Node + """ + if not node.module: + self.generic_visit(node) + return + + from_import = node.module + + for import_alias in node.names: + # Detect based on full import name(to & as) + full_module_name = "%s.%s" % (from_import, import_alias.name) + full_import = (full_module_name, import_alias.asname) + detection = (self._api_analysis_spec + .imports_to_detect.get(full_import, None)) + if detection: + self.add_result(detection) + self.add_log( + detection.log_level, node.lineno, node.col_offset, + detection.log_message) + + self.generic_visit(node) + + class ASTCodeUpgrader(object): """Handles upgrading a set of Python files using a given API change spec.""" @@ -663,12 +794,18 @@ class ASTCodeUpgrader(object): log = ["ERROR: Failed to parse.\n" + traceback.format_exc()] return 0, "", log, [] + preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t) + visitor = _PastaEditVisitor(self._api_change_spec) visitor.visit(t) - logs = [self.format_log(log, None) for log in visitor.log] + self._api_change_spec.clear_preprocessing() + + logs = [self.format_log(log, None) for log in (preprocess_logs + + visitor.log)] errors = [self.format_log(error, in_filename) - for error in visitor.warnings_and_errors] + for error in (preprocess_errors + + visitor.warnings_and_errors)] return 1, pasta.dump(t), logs, errors def _format_log(self, log, in_filename, out_filename): diff --git a/tensorflow/tools/compatibility/ast_edits_test.py b/tensorflow/tools/compatibility/ast_edits_test.py index dc2d9298f85..4571d30118b 100644 --- a/tensorflow/tools/compatibility/ast_edits_test.py +++ b/tensorflow/tools/compatibility/ast_edits_test.py @@ -52,29 +52,15 @@ from tensorflow.python.platform import test as test_lib from tensorflow.tools.compatibility import ast_edits -class NoUpdateSpec(ast_edits.APIChangeSpec): - """A specification of an API change which doesn't change anything.""" - - def __init__(self): - self.function_handle = {} - self.function_reorders = {} - self.function_keyword_renames = {} - self.symbol_renames = {} - self.function_warnings = {} - self.change_to_function = {} - self.module_deprecations = {} - self.import_renames = {} - - -class ModuleDeprecationSpec(NoUpdateSpec): +class ModuleDeprecationSpec(ast_edits.NoUpdateSpec): """A specification which deprecates 'a.b'.""" def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.module_deprecations.update({"a.b": (ast_edits.ERROR, "a.b is evil.")}) -class RenameKeywordSpec(NoUpdateSpec): +class RenameKeywordSpec(ast_edits.NoUpdateSpec): """A specification where kw2 gets renamed to kw3. The new API is @@ -84,14 +70,14 @@ class RenameKeywordSpec(NoUpdateSpec): """ def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.update_renames() def update_renames(self): self.function_keyword_renames["f"] = {"kw2": "kw3"} -class ReorderKeywordSpec(NoUpdateSpec): +class ReorderKeywordSpec(ast_edits.NoUpdateSpec): """A specification where kw2 gets moved in front of kw1. The new API is @@ -101,7 +87,7 @@ class ReorderKeywordSpec(NoUpdateSpec): """ def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.update_reorders() def update_reorders(self): @@ -125,7 +111,7 @@ class ReorderAndRenameKeywordSpec(ReorderKeywordSpec, RenameKeywordSpec): self.update_reorders() -class RemoveDeprecatedAliasKeyword(NoUpdateSpec): +class RemoveDeprecatedAliasKeyword(ast_edits.NoUpdateSpec): """A specification where kw1_alias is removed in g. The new API is @@ -136,7 +122,7 @@ class RemoveDeprecatedAliasKeyword(NoUpdateSpec): """ def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.function_keyword_renames["g"] = {"kw1_alias": "kw1"} self.function_keyword_renames["g2"] = {"kw1_alias": "kw1"} @@ -158,7 +144,7 @@ class RemoveDeprecatedAliasAndReorderRest(RemoveDeprecatedAliasKeyword): self.function_reorders["g2"] = ["a", "b", "kw1", "c", "d"] -class RemoveMultipleKeywordArguments(NoUpdateSpec): +class RemoveMultipleKeywordArguments(ast_edits.NoUpdateSpec): """A specification where both keyword aliases are removed from h. The new API is @@ -168,18 +154,18 @@ class RemoveMultipleKeywordArguments(NoUpdateSpec): """ def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.function_keyword_renames["h"] = { "kw1_alias": "kw1", "kw2_alias": "kw2", } -class RenameImports(NoUpdateSpec): +class RenameImports(ast_edits.NoUpdateSpec): """Specification for renaming imports.""" def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.import_renames = { "foo": ast_edits.ImportRename( "bar", @@ -209,11 +195,11 @@ class TestAstEdits(test_util.TensorFlowTestCase): def testNoTransformIfNothingIsSupplied(self): text = "f(a, b, kw1=c, kw2=d)\n" - _, new_text = self._upgrade(NoUpdateSpec(), text) + _, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text) self.assertEqual(new_text, text) text = "f(a, b, c, d)\n" - _, new_text = self._upgrade(NoUpdateSpec(), text) + _, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text) self.assertEqual(new_text, text) def testKeywordRename(self): @@ -447,11 +433,11 @@ class TestAstEdits(test_util.TensorFlowTestCase): self.assertIn(new_text, acceptable_outputs) def testUnrestrictedFunctionWarnings(self): - class FooWarningSpec(NoUpdateSpec): + class FooWarningSpec(ast_edits.NoUpdateSpec): """Usages of function attribute foo() prints out a warning.""" def __init__(self): - NoUpdateSpec.__init__(self) + ast_edits.NoUpdateSpec.__init__(self) self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")} texts = ["object.foo()", "get_object().foo()", diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index e55ad592bff..5cf31930aad 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -34,7 +34,35 @@ from tensorflow.tools.compatibility import reorders_v2 # pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison -class TFAPIChangeSpec(ast_edits.APIChangeSpec): +class UnaliasedTFImport(ast_edits.AnalysisResult): + + def __init__(self): + self.log_level = ast_edits.ERROR + self.log_message = ("The tf_upgrade_v2 script detected an unaliased " + "`import tensorflow`. The script can only run when " + "importing with `import tensorflow as tf`.") + + +class VersionedTFImport(ast_edits.AnalysisResult): + + def __init__(self, version): + self.log_level = ast_edits.INFO + self.log_message = ("Not upgrading symbols because `tensorflow." + version + + "` was directly imported as `tf`.") + + +class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec): + + def __init__(self): + self.symbols_to_detect = {} + self.imports_to_detect = { + ("tensorflow", None): UnaliasedTFImport(), + ("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"), + ("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"), + } + + +class TFAPIChangeSpec(ast_edits.NoUpdateSpec): """List of maps that describe what changed in the API.""" def __init__(self): @@ -481,6 +509,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Add additional renames not in renames_v2.py to all_renames_v2.py. self.symbol_renames = all_renames_v2.symbol_renames + self.import_renames = {} + # Variables that should be changed to functions. self.change_to_function = {} @@ -771,7 +801,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "extended.call_for_each_replica->experimental_run_v2, " "reduce requires an axis argument, " "unwrap->experimental_local_results " - "experimental_initialize and experimenta_finalize no longer needed ") + "experimental_initialize and experimental_finalize no longer needed ") contrib_mirrored_strategy_warning = ( ast_edits.ERROR, @@ -1474,6 +1504,27 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS + def preprocess(self, root_node): + visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec()) + visitor.visit(root_node) + detections = set(visitor.results) + # If we have detected the presence of imports of specific TF versions, + # We want to modify the update spec to check only module deprecations + # and skip all other conversions. + if detections: + self.function_handle = {} + self.function_reorders = {} + self.function_keyword_renames = {} + self.symbol_renames = {} + self.function_warnings = {} + self.change_to_function = {} + self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS + self.import_renames = {} + return visitor.log, visitor.warnings_and_errors + + def clear_preprocessing(self): + self.__init__() + def _is_ast_str(node): """Determine whether this node represents a string.""" diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index f02482a9eb1..98e2a0ec021 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -123,6 +123,18 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): "test_out.py", out_file)) return count, report, errors, out_file.getvalue() + def _upgrade_multiple(self, old_file_texts): + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec()) + results = [] + for old_file_text in old_file_texts: + in_file = six.StringIO(old_file_text) + out_file = six.StringIO() + count, report, errors = ( + upgrader.process_opened_file("test.py", in_file, + "test_out.py", out_file)) + results.append([count, report, errors, out_file.getvalue()]) + return results + def testParseError(self): _, report, unused_errors, unused_new_text = self._upgrade( "import tensorflow as tf\na + \n") @@ -1984,6 +1996,96 @@ def _log_prob(self, x): _, _, _, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) + def test_import_analysis(self): + old_symbol = "tf.conj(a)" + new_symbol = "tf.math.conj(a)" + + # We upgrade the base un-versioned tensorflow aliased as tf + import_header = "import tensorflow as tf\n" + text = import_header + old_symbol + expected_text = import_header + new_symbol + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + + import_header = ("import tensorflow as tf\n" + "import tensorflow.compat.v1 as tf_v1\n" + "import tensorflow.compat.v2 as tf_v2\n") + text = import_header + old_symbol + expected_text = import_header + new_symbol + _, _, _, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + + # We don't handle unaliased tensorflow imports currently, + # So the upgrade script show log errors + import_header = "import tensorflow\n" + text = import_header + old_symbol + expected_text = import_header + old_symbol + _, _, errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + self.assertIn("unaliased `import tensorflow`", "\n".join(errors)) + + # Upgrading explicitly-versioned tf code is unsafe, but we don't + # need to throw errors when we detect explicitly-versioned tf. + import_header = "import tensorflow.compat.v1 as tf\n" + text = import_header + old_symbol + expected_text = import_header + old_symbol + _, report, errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`", + report) + self.assertEmpty(errors) + + import_header = "from tensorflow.compat import v1 as tf\n" + text = import_header + old_symbol + expected_text = import_header + old_symbol + _, report, errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`", + report) + self.assertEmpty(errors) + + import_header = "from tensorflow.compat import v1 as tf, v2 as tf2\n" + text = import_header + old_symbol + expected_text = import_header + old_symbol + _, report, errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`", + report) + self.assertEmpty(errors) + + import_header = "import tensorflow.compat.v2 as tf\n" + text = import_header + old_symbol + expected_text = import_header + old_symbol + _, report, errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`", + report) + self.assertEmpty(errors) + + import_header = "from tensorflow.compat import v1 as tf1, v2 as tf\n" + text = import_header + old_symbol + expected_text = import_header + old_symbol + _, report, errors, new_text = self._upgrade(text) + self.assertEqual(new_text, expected_text) + self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`", + report) + self.assertEmpty(errors) + + def test_api_spec_reset_between_files(self): + old_symbol = "tf.conj(a)" + new_symbol = "tf.math.conj(a)" + + ## Test that the api spec is reset in between files: + import_header = "import tensorflow.compat.v2 as tf\n" + text_a = import_header + old_symbol + expected_text_a = import_header + old_symbol + text_b = old_symbol + expected_text_b = new_symbol + results = self._upgrade_multiple([text_a, text_b]) + result_a, result_b = results[0], results[1] + self.assertEqual(result_a[3], expected_text_a) + self.assertEqual(result_b[3], expected_text_b) + class TestUpgradeFiles(test_util.TensorFlowTestCase):