This CL allows the tf_upgrade_v2 script to avoid changing code when import tensorflow.compat.v1/v2 as tf
is used, and to raise an error when import tensorflow
is directly used without any tf
alias.
To do this, the CL adds support for a code analysis step before upgrading code. In the future this would allow correctly handling non-`tf` aliases of tensorflow, as well as detecting when additional imports need to be added. PiperOrigin-RevId: 249264134
This commit is contained in:
parent
269efd353a
commit
9d46f5599f
@ -146,6 +146,31 @@ class APIChangeSpec(object):
|
|||||||
For an example, see `TFAPIChangeSpec`.
|
For an example, see `TFAPIChangeSpec`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def preprocess(self, root_node): # pylint: disable=unused-argument
|
||||||
|
"""Preprocess a parse tree. Return any produced logs and errors."""
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
def clear_preprocessing(self):
|
||||||
|
"""Restore this APIChangeSpec to before it preprocessed a file.
|
||||||
|
|
||||||
|
This is needed if preprocessing a file changed any rewriting rules.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoUpdateSpec(APIChangeSpec):
|
||||||
|
"""A specification of an API change which doesn't change anything."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.function_handle = {}
|
||||||
|
self.function_reorders = {}
|
||||||
|
self.function_keyword_renames = {}
|
||||||
|
self.symbol_renames = {}
|
||||||
|
self.function_warnings = {}
|
||||||
|
self.change_to_function = {}
|
||||||
|
self.module_deprecations = {}
|
||||||
|
self.import_renames = {}
|
||||||
|
|
||||||
|
|
||||||
class _PastaEditVisitor(ast.NodeVisitor):
|
class _PastaEditVisitor(ast.NodeVisitor):
|
||||||
"""AST Visitor that processes function calls.
|
"""AST Visitor that processes function calls.
|
||||||
@ -618,6 +643,112 @@ class _PastaEditVisitor(ast.NodeVisitor):
|
|||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisResult(object):
|
||||||
|
"""This class represents an analysis result and how it should be logged.
|
||||||
|
|
||||||
|
This class must provide the following fields:
|
||||||
|
|
||||||
|
* `log_level`: The log level to which this detection should be logged
|
||||||
|
* `log_message`: The message that should be logged for this detection
|
||||||
|
|
||||||
|
For an example, see `VersionedTFImport`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class APIAnalysisSpec(object):
|
||||||
|
"""This class defines how `AnalysisResult`s should be generated.
|
||||||
|
|
||||||
|
It specifies how to map imports and symbols to `AnalysisResult`s.
|
||||||
|
|
||||||
|
This class must provide the following fields:
|
||||||
|
|
||||||
|
* `symbols_to_detect`: maps function names to `AnalysisResult`s
|
||||||
|
* `imports_to_detect`: maps imports represented as (full module name, alias)
|
||||||
|
tuples to `AnalysisResult`s
|
||||||
|
notifications)
|
||||||
|
|
||||||
|
For an example, see `TFAPIImportAnalysisSpec`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class PastaAnalyzeVisitor(_PastaEditVisitor):
|
||||||
|
"""AST Visitor that looks for specific API usage without editing anything.
|
||||||
|
|
||||||
|
This is used before any rewriting is done to detect if any symbols are used
|
||||||
|
that require changing imports or disabling rewriting altogether.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, api_analysis_spec):
|
||||||
|
super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
|
||||||
|
self._api_analysis_spec = api_analysis_spec
|
||||||
|
self._results = [] # Holds AnalysisResult objects
|
||||||
|
|
||||||
|
@property
|
||||||
|
def results(self):
|
||||||
|
return self._results
|
||||||
|
|
||||||
|
def add_result(self, analysis_result):
|
||||||
|
self._results.append(analysis_result)
|
||||||
|
|
||||||
|
def visit_Attribute(self, node): # pylint: disable=invalid-name
|
||||||
|
"""Handle bare Attributes i.e. [tf.foo, tf.bar]."""
|
||||||
|
full_name = self._get_full_name(node)
|
||||||
|
if full_name:
|
||||||
|
detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None)
|
||||||
|
if detection:
|
||||||
|
self.add_result(detection)
|
||||||
|
self.add_log(
|
||||||
|
detection.log_level, node.lineno, node.col_offset,
|
||||||
|
detection.log_message)
|
||||||
|
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_Import(self, node): # pylint: disable=invalid-name
|
||||||
|
"""Handle visiting an import node in the AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Current Node
|
||||||
|
"""
|
||||||
|
for import_alias in node.names:
|
||||||
|
# Detect based on full import name and alias)
|
||||||
|
full_import = (import_alias.name, import_alias.asname)
|
||||||
|
detection = (self._api_analysis_spec
|
||||||
|
.imports_to_detect.get(full_import, None))
|
||||||
|
if detection:
|
||||||
|
self.add_result(detection)
|
||||||
|
self.add_log(
|
||||||
|
detection.log_level, node.lineno, node.col_offset,
|
||||||
|
detection.log_message)
|
||||||
|
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
|
||||||
|
"""Handle visiting an import-from node in the AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Current Node
|
||||||
|
"""
|
||||||
|
if not node.module:
|
||||||
|
self.generic_visit(node)
|
||||||
|
return
|
||||||
|
|
||||||
|
from_import = node.module
|
||||||
|
|
||||||
|
for import_alias in node.names:
|
||||||
|
# Detect based on full import name(to & as)
|
||||||
|
full_module_name = "%s.%s" % (from_import, import_alias.name)
|
||||||
|
full_import = (full_module_name, import_alias.asname)
|
||||||
|
detection = (self._api_analysis_spec
|
||||||
|
.imports_to_detect.get(full_import, None))
|
||||||
|
if detection:
|
||||||
|
self.add_result(detection)
|
||||||
|
self.add_log(
|
||||||
|
detection.log_level, node.lineno, node.col_offset,
|
||||||
|
detection.log_message)
|
||||||
|
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
|
||||||
class ASTCodeUpgrader(object):
|
class ASTCodeUpgrader(object):
|
||||||
"""Handles upgrading a set of Python files using a given API change spec."""
|
"""Handles upgrading a set of Python files using a given API change spec."""
|
||||||
|
|
||||||
@ -663,12 +794,18 @@ class ASTCodeUpgrader(object):
|
|||||||
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
|
log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
|
||||||
return 0, "", log, []
|
return 0, "", log, []
|
||||||
|
|
||||||
|
preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)
|
||||||
|
|
||||||
visitor = _PastaEditVisitor(self._api_change_spec)
|
visitor = _PastaEditVisitor(self._api_change_spec)
|
||||||
visitor.visit(t)
|
visitor.visit(t)
|
||||||
|
|
||||||
logs = [self.format_log(log, None) for log in visitor.log]
|
self._api_change_spec.clear_preprocessing()
|
||||||
|
|
||||||
|
logs = [self.format_log(log, None) for log in (preprocess_logs +
|
||||||
|
visitor.log)]
|
||||||
errors = [self.format_log(error, in_filename)
|
errors = [self.format_log(error, in_filename)
|
||||||
for error in visitor.warnings_and_errors]
|
for error in (preprocess_errors +
|
||||||
|
visitor.warnings_and_errors)]
|
||||||
return 1, pasta.dump(t), logs, errors
|
return 1, pasta.dump(t), logs, errors
|
||||||
|
|
||||||
def _format_log(self, log, in_filename, out_filename):
|
def _format_log(self, log, in_filename, out_filename):
|
||||||
|
@ -52,29 +52,15 @@ from tensorflow.python.platform import test as test_lib
|
|||||||
from tensorflow.tools.compatibility import ast_edits
|
from tensorflow.tools.compatibility import ast_edits
|
||||||
|
|
||||||
|
|
||||||
class NoUpdateSpec(ast_edits.APIChangeSpec):
|
class ModuleDeprecationSpec(ast_edits.NoUpdateSpec):
|
||||||
"""A specification of an API change which doesn't change anything."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.function_handle = {}
|
|
||||||
self.function_reorders = {}
|
|
||||||
self.function_keyword_renames = {}
|
|
||||||
self.symbol_renames = {}
|
|
||||||
self.function_warnings = {}
|
|
||||||
self.change_to_function = {}
|
|
||||||
self.module_deprecations = {}
|
|
||||||
self.import_renames = {}
|
|
||||||
|
|
||||||
|
|
||||||
class ModuleDeprecationSpec(NoUpdateSpec):
|
|
||||||
"""A specification which deprecates 'a.b'."""
|
"""A specification which deprecates 'a.b'."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.module_deprecations.update({"a.b": (ast_edits.ERROR, "a.b is evil.")})
|
self.module_deprecations.update({"a.b": (ast_edits.ERROR, "a.b is evil.")})
|
||||||
|
|
||||||
|
|
||||||
class RenameKeywordSpec(NoUpdateSpec):
|
class RenameKeywordSpec(ast_edits.NoUpdateSpec):
|
||||||
"""A specification where kw2 gets renamed to kw3.
|
"""A specification where kw2 gets renamed to kw3.
|
||||||
|
|
||||||
The new API is
|
The new API is
|
||||||
@ -84,14 +70,14 @@ class RenameKeywordSpec(NoUpdateSpec):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.update_renames()
|
self.update_renames()
|
||||||
|
|
||||||
def update_renames(self):
|
def update_renames(self):
|
||||||
self.function_keyword_renames["f"] = {"kw2": "kw3"}
|
self.function_keyword_renames["f"] = {"kw2": "kw3"}
|
||||||
|
|
||||||
|
|
||||||
class ReorderKeywordSpec(NoUpdateSpec):
|
class ReorderKeywordSpec(ast_edits.NoUpdateSpec):
|
||||||
"""A specification where kw2 gets moved in front of kw1.
|
"""A specification where kw2 gets moved in front of kw1.
|
||||||
|
|
||||||
The new API is
|
The new API is
|
||||||
@ -101,7 +87,7 @@ class ReorderKeywordSpec(NoUpdateSpec):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.update_reorders()
|
self.update_reorders()
|
||||||
|
|
||||||
def update_reorders(self):
|
def update_reorders(self):
|
||||||
@ -125,7 +111,7 @@ class ReorderAndRenameKeywordSpec(ReorderKeywordSpec, RenameKeywordSpec):
|
|||||||
self.update_reorders()
|
self.update_reorders()
|
||||||
|
|
||||||
|
|
||||||
class RemoveDeprecatedAliasKeyword(NoUpdateSpec):
|
class RemoveDeprecatedAliasKeyword(ast_edits.NoUpdateSpec):
|
||||||
"""A specification where kw1_alias is removed in g.
|
"""A specification where kw1_alias is removed in g.
|
||||||
|
|
||||||
The new API is
|
The new API is
|
||||||
@ -136,7 +122,7 @@ class RemoveDeprecatedAliasKeyword(NoUpdateSpec):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.function_keyword_renames["g"] = {"kw1_alias": "kw1"}
|
self.function_keyword_renames["g"] = {"kw1_alias": "kw1"}
|
||||||
self.function_keyword_renames["g2"] = {"kw1_alias": "kw1"}
|
self.function_keyword_renames["g2"] = {"kw1_alias": "kw1"}
|
||||||
|
|
||||||
@ -158,7 +144,7 @@ class RemoveDeprecatedAliasAndReorderRest(RemoveDeprecatedAliasKeyword):
|
|||||||
self.function_reorders["g2"] = ["a", "b", "kw1", "c", "d"]
|
self.function_reorders["g2"] = ["a", "b", "kw1", "c", "d"]
|
||||||
|
|
||||||
|
|
||||||
class RemoveMultipleKeywordArguments(NoUpdateSpec):
|
class RemoveMultipleKeywordArguments(ast_edits.NoUpdateSpec):
|
||||||
"""A specification where both keyword aliases are removed from h.
|
"""A specification where both keyword aliases are removed from h.
|
||||||
|
|
||||||
The new API is
|
The new API is
|
||||||
@ -168,18 +154,18 @@ class RemoveMultipleKeywordArguments(NoUpdateSpec):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.function_keyword_renames["h"] = {
|
self.function_keyword_renames["h"] = {
|
||||||
"kw1_alias": "kw1",
|
"kw1_alias": "kw1",
|
||||||
"kw2_alias": "kw2",
|
"kw2_alias": "kw2",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RenameImports(NoUpdateSpec):
|
class RenameImports(ast_edits.NoUpdateSpec):
|
||||||
"""Specification for renaming imports."""
|
"""Specification for renaming imports."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.import_renames = {
|
self.import_renames = {
|
||||||
"foo": ast_edits.ImportRename(
|
"foo": ast_edits.ImportRename(
|
||||||
"bar",
|
"bar",
|
||||||
@ -209,11 +195,11 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testNoTransformIfNothingIsSupplied(self):
|
def testNoTransformIfNothingIsSupplied(self):
|
||||||
text = "f(a, b, kw1=c, kw2=d)\n"
|
text = "f(a, b, kw1=c, kw2=d)\n"
|
||||||
_, new_text = self._upgrade(NoUpdateSpec(), text)
|
_, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
|
||||||
self.assertEqual(new_text, text)
|
self.assertEqual(new_text, text)
|
||||||
|
|
||||||
text = "f(a, b, c, d)\n"
|
text = "f(a, b, c, d)\n"
|
||||||
_, new_text = self._upgrade(NoUpdateSpec(), text)
|
_, new_text = self._upgrade(ast_edits.NoUpdateSpec(), text)
|
||||||
self.assertEqual(new_text, text)
|
self.assertEqual(new_text, text)
|
||||||
|
|
||||||
def testKeywordRename(self):
|
def testKeywordRename(self):
|
||||||
@ -447,11 +433,11 @@ class TestAstEdits(test_util.TensorFlowTestCase):
|
|||||||
self.assertIn(new_text, acceptable_outputs)
|
self.assertIn(new_text, acceptable_outputs)
|
||||||
|
|
||||||
def testUnrestrictedFunctionWarnings(self):
|
def testUnrestrictedFunctionWarnings(self):
|
||||||
class FooWarningSpec(NoUpdateSpec):
|
class FooWarningSpec(ast_edits.NoUpdateSpec):
|
||||||
"""Usages of function attribute foo() prints out a warning."""
|
"""Usages of function attribute foo() prints out a warning."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
NoUpdateSpec.__init__(self)
|
ast_edits.NoUpdateSpec.__init__(self)
|
||||||
self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}
|
self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}
|
||||||
|
|
||||||
texts = ["object.foo()", "get_object().foo()",
|
texts = ["object.foo()", "get_object().foo()",
|
||||||
|
@ -34,7 +34,35 @@ from tensorflow.tools.compatibility import reorders_v2
|
|||||||
# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison
|
# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison
|
||||||
|
|
||||||
|
|
||||||
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
class UnaliasedTFImport(ast_edits.AnalysisResult):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.log_level = ast_edits.ERROR
|
||||||
|
self.log_message = ("The tf_upgrade_v2 script detected an unaliased "
|
||||||
|
"`import tensorflow`. The script can only run when "
|
||||||
|
"importing with `import tensorflow as tf`.")
|
||||||
|
|
||||||
|
|
||||||
|
class VersionedTFImport(ast_edits.AnalysisResult):
|
||||||
|
|
||||||
|
def __init__(self, version):
|
||||||
|
self.log_level = ast_edits.INFO
|
||||||
|
self.log_message = ("Not upgrading symbols because `tensorflow." + version
|
||||||
|
+ "` was directly imported as `tf`.")
|
||||||
|
|
||||||
|
|
||||||
|
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.symbols_to_detect = {}
|
||||||
|
self.imports_to_detect = {
|
||||||
|
("tensorflow", None): UnaliasedTFImport(),
|
||||||
|
("tensorflow.compat.v1", "tf"): VersionedTFImport("compat.v1"),
|
||||||
|
("tensorflow.compat.v2", "tf"): VersionedTFImport("compat.v2"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TFAPIChangeSpec(ast_edits.NoUpdateSpec):
|
||||||
"""List of maps that describe what changed in the API."""
|
"""List of maps that describe what changed in the API."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -481,6 +509,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
# Add additional renames not in renames_v2.py to all_renames_v2.py.
|
# Add additional renames not in renames_v2.py to all_renames_v2.py.
|
||||||
self.symbol_renames = all_renames_v2.symbol_renames
|
self.symbol_renames = all_renames_v2.symbol_renames
|
||||||
|
|
||||||
|
self.import_renames = {}
|
||||||
|
|
||||||
# Variables that should be changed to functions.
|
# Variables that should be changed to functions.
|
||||||
self.change_to_function = {}
|
self.change_to_function = {}
|
||||||
|
|
||||||
@ -771,7 +801,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
"extended.call_for_each_replica->experimental_run_v2, "
|
"extended.call_for_each_replica->experimental_run_v2, "
|
||||||
"reduce requires an axis argument, "
|
"reduce requires an axis argument, "
|
||||||
"unwrap->experimental_local_results "
|
"unwrap->experimental_local_results "
|
||||||
"experimental_initialize and experimenta_finalize no longer needed ")
|
"experimental_initialize and experimental_finalize no longer needed ")
|
||||||
|
|
||||||
contrib_mirrored_strategy_warning = (
|
contrib_mirrored_strategy_warning = (
|
||||||
ast_edits.ERROR,
|
ast_edits.ERROR,
|
||||||
@ -1474,6 +1504,27 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
|
|
||||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||||
|
|
||||||
|
def preprocess(self, root_node):
|
||||||
|
visitor = ast_edits.PastaAnalyzeVisitor(TFAPIImportAnalysisSpec())
|
||||||
|
visitor.visit(root_node)
|
||||||
|
detections = set(visitor.results)
|
||||||
|
# If we have detected the presence of imports of specific TF versions,
|
||||||
|
# We want to modify the update spec to check only module deprecations
|
||||||
|
# and skip all other conversions.
|
||||||
|
if detections:
|
||||||
|
self.function_handle = {}
|
||||||
|
self.function_reorders = {}
|
||||||
|
self.function_keyword_renames = {}
|
||||||
|
self.symbol_renames = {}
|
||||||
|
self.function_warnings = {}
|
||||||
|
self.change_to_function = {}
|
||||||
|
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||||
|
self.import_renames = {}
|
||||||
|
return visitor.log, visitor.warnings_and_errors
|
||||||
|
|
||||||
|
def clear_preprocessing(self):
|
||||||
|
self.__init__()
|
||||||
|
|
||||||
|
|
||||||
def _is_ast_str(node):
|
def _is_ast_str(node):
|
||||||
"""Determine whether this node represents a string."""
|
"""Determine whether this node represents a string."""
|
||||||
|
@ -123,6 +123,18 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
"test_out.py", out_file))
|
"test_out.py", out_file))
|
||||||
return count, report, errors, out_file.getvalue()
|
return count, report, errors, out_file.getvalue()
|
||||||
|
|
||||||
|
def _upgrade_multiple(self, old_file_texts):
|
||||||
|
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
|
||||||
|
results = []
|
||||||
|
for old_file_text in old_file_texts:
|
||||||
|
in_file = six.StringIO(old_file_text)
|
||||||
|
out_file = six.StringIO()
|
||||||
|
count, report, errors = (
|
||||||
|
upgrader.process_opened_file("test.py", in_file,
|
||||||
|
"test_out.py", out_file))
|
||||||
|
results.append([count, report, errors, out_file.getvalue()])
|
||||||
|
return results
|
||||||
|
|
||||||
def testParseError(self):
|
def testParseError(self):
|
||||||
_, report, unused_errors, unused_new_text = self._upgrade(
|
_, report, unused_errors, unused_new_text = self._upgrade(
|
||||||
"import tensorflow as tf\na + \n")
|
"import tensorflow as tf\na + \n")
|
||||||
@ -1984,6 +1996,96 @@ def _log_prob(self, x):
|
|||||||
_, _, _, new_text = self._upgrade(text)
|
_, _, _, new_text = self._upgrade(text)
|
||||||
self.assertEqual(expected_text, new_text)
|
self.assertEqual(expected_text, new_text)
|
||||||
|
|
||||||
|
def test_import_analysis(self):
|
||||||
|
old_symbol = "tf.conj(a)"
|
||||||
|
new_symbol = "tf.math.conj(a)"
|
||||||
|
|
||||||
|
# We upgrade the base un-versioned tensorflow aliased as tf
|
||||||
|
import_header = "import tensorflow as tf\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + new_symbol
|
||||||
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
|
||||||
|
import_header = ("import tensorflow as tf\n"
|
||||||
|
"import tensorflow.compat.v1 as tf_v1\n"
|
||||||
|
"import tensorflow.compat.v2 as tf_v2\n")
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + new_symbol
|
||||||
|
_, _, _, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
|
||||||
|
# We don't handle unaliased tensorflow imports currently,
|
||||||
|
# So the upgrade script show log errors
|
||||||
|
import_header = "import tensorflow\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + old_symbol
|
||||||
|
_, _, errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
self.assertIn("unaliased `import tensorflow`", "\n".join(errors))
|
||||||
|
|
||||||
|
# Upgrading explicitly-versioned tf code is unsafe, but we don't
|
||||||
|
# need to throw errors when we detect explicitly-versioned tf.
|
||||||
|
import_header = "import tensorflow.compat.v1 as tf\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + old_symbol
|
||||||
|
_, report, errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
|
||||||
|
report)
|
||||||
|
self.assertEmpty(errors)
|
||||||
|
|
||||||
|
import_header = "from tensorflow.compat import v1 as tf\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + old_symbol
|
||||||
|
_, report, errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
|
||||||
|
report)
|
||||||
|
self.assertEmpty(errors)
|
||||||
|
|
||||||
|
import_header = "from tensorflow.compat import v1 as tf, v2 as tf2\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + old_symbol
|
||||||
|
_, report, errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
self.assertIn("`tensorflow.compat.v1` was directly imported as `tf`",
|
||||||
|
report)
|
||||||
|
self.assertEmpty(errors)
|
||||||
|
|
||||||
|
import_header = "import tensorflow.compat.v2 as tf\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + old_symbol
|
||||||
|
_, report, errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
|
||||||
|
report)
|
||||||
|
self.assertEmpty(errors)
|
||||||
|
|
||||||
|
import_header = "from tensorflow.compat import v1 as tf1, v2 as tf\n"
|
||||||
|
text = import_header + old_symbol
|
||||||
|
expected_text = import_header + old_symbol
|
||||||
|
_, report, errors, new_text = self._upgrade(text)
|
||||||
|
self.assertEqual(new_text, expected_text)
|
||||||
|
self.assertIn("`tensorflow.compat.v2` was directly imported as `tf`",
|
||||||
|
report)
|
||||||
|
self.assertEmpty(errors)
|
||||||
|
|
||||||
|
def test_api_spec_reset_between_files(self):
|
||||||
|
old_symbol = "tf.conj(a)"
|
||||||
|
new_symbol = "tf.math.conj(a)"
|
||||||
|
|
||||||
|
## Test that the api spec is reset in between files:
|
||||||
|
import_header = "import tensorflow.compat.v2 as tf\n"
|
||||||
|
text_a = import_header + old_symbol
|
||||||
|
expected_text_a = import_header + old_symbol
|
||||||
|
text_b = old_symbol
|
||||||
|
expected_text_b = new_symbol
|
||||||
|
results = self._upgrade_multiple([text_a, text_b])
|
||||||
|
result_a, result_b = results[0], results[1]
|
||||||
|
self.assertEqual(result_a[3], expected_text_a)
|
||||||
|
self.assertEqual(result_b[3], expected_text_b)
|
||||||
|
|
||||||
|
|
||||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user