Setup skeleton for adding SAFETY mode to tf_upgrade_v2.py script.
PiperOrigin-RevId: 247692270
This commit is contained in:
parent
a5a28d5778
commit
188791fc0b
tensorflow/tools/compatibility
@ -91,6 +91,12 @@ py_library(
|
||||
deps = [":renames_v2"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "module_deprecations_v2",
|
||||
srcs = ["module_deprecations_v2.py"],
|
||||
deps = [":ast_edits"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tf_upgrade_v2_lib",
|
||||
srcs = ["tf_upgrade_v2.py"],
|
||||
@ -98,11 +104,22 @@ py_library(
|
||||
deps = [
|
||||
":all_renames_v2",
|
||||
":ast_edits",
|
||||
":module_deprecations_v2",
|
||||
":reorders_v2",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tf_upgrade_v2_safety_lib",
|
||||
srcs = ["tf_upgrade_v2_safety.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ast_edits",
|
||||
":module_deprecations_v2",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "tf_upgrade_v2",
|
||||
srcs = ["tf_upgrade_v2_main.py"],
|
||||
@ -113,6 +130,7 @@ py_binary(
|
||||
":ast_edits",
|
||||
":ipynb",
|
||||
":tf_upgrade_v2_lib",
|
||||
":tf_upgrade_v2_safety_lib",
|
||||
],
|
||||
)
|
||||
|
||||
@ -133,6 +151,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tf_upgrade_v2_safety_test",
|
||||
srcs = ["tf_upgrade_v2_safety_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tf_upgrade_v2_safety_lib",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
# Keep for reference, this test will succeed in 0.11 but fail in 1.0
|
||||
# py_test(
|
||||
# name = "test_file_v0_11",
|
||||
|
64
tensorflow/tools/compatibility/module_deprecations_v2.py
Normal file
64
tensorflow/tools/compatibility/module_deprecations_v2.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Module deprecation warnings for TensorFlow 2.0."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
|
||||
|
||||
_CONTRIB_WARNING = (
|
||||
ast_edits.ERROR,
|
||||
"<function name> cannot be converted automatically. tf.contrib will not"
|
||||
" be distributed with TensorFlow 2.0, please consider an alternative in"
|
||||
" non-contrib TensorFlow, a community-maintained repository, or fork "
|
||||
"the required code."
|
||||
)
|
||||
|
||||
_FLAGS_WARNING = (
|
||||
ast_edits.ERROR,
|
||||
"tf.flags has been removed, please use the argparse or absl"
|
||||
" modules if you need command line parsing."
|
||||
)
|
||||
|
||||
_CONTRIB_CUDNN_RNN_WARNING = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
|
||||
"and the CuDNN kernel has been integrated with "
|
||||
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
|
||||
"and use that instead."
|
||||
)
|
||||
|
||||
_CONTRIB_RNN_WARNING = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
|
||||
"widely used cells/functions will be moved to tensorflow/addons "
|
||||
"repository. Please check it there and file Github issues if necessary."
|
||||
)
|
||||
|
||||
_CONTRIB_DIST_STRAT_WARNING = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual edit required) tf.contrib.distribute.* have been migrated to"
|
||||
"tf.distribute.*. Please check out the new module for updates APIs.")
|
||||
|
||||
MODULE_DEPRECATIONS = {
|
||||
"tf.contrib": _CONTRIB_WARNING,
|
||||
"tf.contrib.cudnn_rnn": _CONTRIB_CUDNN_RNN_WARNING,
|
||||
"tf.contrib.rnn": _CONTRIB_RNN_WARNING,
|
||||
"tf.flags": _FLAGS_WARNING,
|
||||
"tf.contrib.distribute": _CONTRIB_DIST_STRAT_WARNING
|
||||
}
|
@ -27,6 +27,7 @@ import pasta
|
||||
|
||||
from tensorflow.tools.compatibility import all_renames_v2
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import module_deprecations_v2
|
||||
from tensorflow.tools.compatibility import reorders_v2
|
||||
|
||||
# These pylint warnings are a mistake.
|
||||
@ -622,34 +623,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
self.function_reorders = dict(reorders_v2.reorders)
|
||||
self.function_reorders.update(self.manual_function_reorders)
|
||||
|
||||
contrib_warning = (
|
||||
ast_edits.ERROR,
|
||||
"<function name> cannot be converted automatically. tf.contrib will not"
|
||||
" be distributed with TensorFlow 2.0, please consider an alternative in"
|
||||
" non-contrib TensorFlow, a community-maintained repository, or fork "
|
||||
"the required code."
|
||||
)
|
||||
|
||||
flags_warning = (
|
||||
ast_edits.ERROR,
|
||||
"tf.flags has been removed, please use the argparse or absl"
|
||||
" modules if you need command line parsing.")
|
||||
|
||||
contrib_cudnn_rnn_warning = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
|
||||
"and the CuDNN kernel has been integrated with "
|
||||
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
|
||||
"and use that instead."
|
||||
)
|
||||
|
||||
contrib_rnn_warning = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
|
||||
"widely used cells/functions will be moved to tensorflow/addons "
|
||||
"repository. Please check it there and file Github issues if necessary."
|
||||
)
|
||||
|
||||
decay_function_comment = (
|
||||
ast_edits.INFO,
|
||||
"To use learning rate decay schedules with TensorFlow 2.0, switch to "
|
||||
@ -790,11 +763,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"default, instead of HDF5. To continue saving to HDF5, add the "
|
||||
"argument save_format='h5' to the save() function.")
|
||||
|
||||
contrib_dist_strat_warning = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual edit required) tf.contrib.distribute.* have been migrated to"
|
||||
"tf.distribute.*. Please check out the new module for updates APIs.")
|
||||
|
||||
distribute_strategy_api_changes = (
|
||||
"If you're using the strategy with a "
|
||||
"custom training loop, note the following changes in methods: "
|
||||
@ -1504,13 +1472,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
arg_value_ast=ast.Str("h5")),
|
||||
}
|
||||
|
||||
self.module_deprecations = {
|
||||
"tf.contrib": contrib_warning,
|
||||
"tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning,
|
||||
"tf.contrib.rnn": contrib_rnn_warning,
|
||||
"tf.flags": flags_warning,
|
||||
"tf.contrib.distribute": contrib_dist_strat_warning
|
||||
}
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
|
||||
|
||||
def _is_ast_str(node):
|
||||
|
@ -19,11 +19,22 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import tf_upgrade_v2
|
||||
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
|
||||
from tensorflow.tools.compatibility import ipynb
|
||||
|
||||
# Make straightforward changes to convert to 2.0. In harder cases,
|
||||
# use compat.v1.
|
||||
_DEFAULT_MODE = "DEFAULT"
|
||||
|
||||
# Convert to use compat.v1.
|
||||
# TODO(kaftan): remove EXPERIMENTAL_ prefix once safety mode is
|
||||
# implemented.
|
||||
_SAFETY_MODE = "EXPERIMENTAL_SAFETY"
|
||||
|
||||
|
||||
def process_file(in_filename, out_filename, upgrader):
|
||||
"""Process a file of type `.py` or `.ipynb`."""
|
||||
@ -91,9 +102,27 @@ Simple usage:
|
||||
"stored."
|
||||
"(default: %(default)s)"),
|
||||
default="report.txt")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
dest="mode",
|
||||
choices=[_DEFAULT_MODE, _SAFETY_MODE],
|
||||
help=("Upgrade script mode. Supported modes:\n"
|
||||
"%s: Perform only straightforward conversions to upgrade to "
|
||||
"2.0. In more difficult cases, switch to use compat.v1.\n"
|
||||
"%s: Keep 1.* code intact and import compat.v1 "
|
||||
"module. Note: safety mode is under development and not available "
|
||||
"yet." % (_DEFAULT_MODE, _SAFETY_MODE)),
|
||||
default=_DEFAULT_MODE)
|
||||
args = parser.parse_args()
|
||||
|
||||
upgrade = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
|
||||
if args.mode == _SAFETY_MODE:
|
||||
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
|
||||
sys.stderr.write(
|
||||
"%s mode is not fully implemented yet." % _SAFETY_MODE)
|
||||
else:
|
||||
change_spec = tf_upgrade_v2.TFAPIChangeSpec()
|
||||
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
|
||||
|
||||
report_text = None
|
||||
report_filename = args.report_filename
|
||||
files_processed = 0
|
||||
|
38
tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
Normal file
38
tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Upgrader for Python scripts from 1.* to 2.0 TensorFlow using SAFETY mode."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import module_deprecations_v2
|
||||
|
||||
|
||||
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"""List of maps that describe what changed in the API."""
|
||||
|
||||
def __init__(self):
|
||||
self.function_keyword_renames = {}
|
||||
self.symbol_renames = {}
|
||||
self.change_to_function = {}
|
||||
self.function_reorders = {}
|
||||
self.function_warnings = {}
|
||||
self.function_transformers = {}
|
||||
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||
|
||||
# TODO(kaftan,annarev): specify replacement from TensorFlow import to
|
||||
# compat.v1 import.
|
47
tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
Normal file
47
tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf 2.0 upgrader in safety mode."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
|
||||
|
||||
|
||||
class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _upgrade(self, old_file_text):
|
||||
in_file = six.StringIO(old_file_text)
|
||||
out_file = six.StringIO()
|
||||
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2_safety.TFAPIChangeSpec())
|
||||
count, report, errors = (
|
||||
upgrader.process_opened_file("test.py", in_file,
|
||||
"test_out.py", out_file))
|
||||
return count, report, errors, out_file.getvalue()
|
||||
|
||||
def testContribWarning(self):
|
||||
text = "tf.contrib.foo()"
|
||||
_, report, _, _ = self._upgrade(text)
|
||||
expected_info = "tf.contrib will not be distributed"
|
||||
self.assertIn(expected_info, report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
Loading…
Reference in New Issue
Block a user