Setup skeleton for adding SAFETY mode to tf_upgrade_v2.py script.
PiperOrigin-RevId: 247692270
This commit is contained in:
parent
a5a28d5778
commit
188791fc0b
tensorflow/tools/compatibility
@ -91,6 +91,12 @@ py_library(
|
|||||||
deps = [":renames_v2"],
|
deps = [":renames_v2"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "module_deprecations_v2",
|
||||||
|
srcs = ["module_deprecations_v2.py"],
|
||||||
|
deps = [":ast_edits"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "tf_upgrade_v2_lib",
|
name = "tf_upgrade_v2_lib",
|
||||||
srcs = ["tf_upgrade_v2.py"],
|
srcs = ["tf_upgrade_v2.py"],
|
||||||
@ -98,11 +104,22 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":all_renames_v2",
|
":all_renames_v2",
|
||||||
":ast_edits",
|
":ast_edits",
|
||||||
|
":module_deprecations_v2",
|
||||||
":reorders_v2",
|
":reorders_v2",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "tf_upgrade_v2_safety_lib",
|
||||||
|
srcs = ["tf_upgrade_v2_safety.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":ast_edits",
|
||||||
|
":module_deprecations_v2",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "tf_upgrade_v2",
|
name = "tf_upgrade_v2",
|
||||||
srcs = ["tf_upgrade_v2_main.py"],
|
srcs = ["tf_upgrade_v2_main.py"],
|
||||||
@ -113,6 +130,7 @@ py_binary(
|
|||||||
":ast_edits",
|
":ast_edits",
|
||||||
":ipynb",
|
":ipynb",
|
||||||
":tf_upgrade_v2_lib",
|
":tf_upgrade_v2_lib",
|
||||||
|
":tf_upgrade_v2_safety_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -133,6 +151,18 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "tf_upgrade_v2_safety_test",
|
||||||
|
srcs = ["tf_upgrade_v2_safety_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":tf_upgrade_v2_safety_lib",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Keep for reference, this test will succeed in 0.11 but fail in 1.0
|
# Keep for reference, this test will succeed in 0.11 but fail in 1.0
|
||||||
# py_test(
|
# py_test(
|
||||||
# name = "test_file_v0_11",
|
# name = "test_file_v0_11",
|
||||||
|
64
tensorflow/tools/compatibility/module_deprecations_v2.py
Normal file
64
tensorflow/tools/compatibility/module_deprecations_v2.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Module deprecation warnings for TensorFlow 2.0."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.tools.compatibility import ast_edits
|
||||||
|
|
||||||
|
|
||||||
|
_CONTRIB_WARNING = (
|
||||||
|
ast_edits.ERROR,
|
||||||
|
"<function name> cannot be converted automatically. tf.contrib will not"
|
||||||
|
" be distributed with TensorFlow 2.0, please consider an alternative in"
|
||||||
|
" non-contrib TensorFlow, a community-maintained repository, or fork "
|
||||||
|
"the required code."
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLAGS_WARNING = (
|
||||||
|
ast_edits.ERROR,
|
||||||
|
"tf.flags has been removed, please use the argparse or absl"
|
||||||
|
" modules if you need command line parsing."
|
||||||
|
)
|
||||||
|
|
||||||
|
_CONTRIB_CUDNN_RNN_WARNING = (
|
||||||
|
ast_edits.WARNING,
|
||||||
|
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
|
||||||
|
"and the CuDNN kernel has been integrated with "
|
||||||
|
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
|
||||||
|
"and use that instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
_CONTRIB_RNN_WARNING = (
|
||||||
|
ast_edits.WARNING,
|
||||||
|
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
|
||||||
|
"widely used cells/functions will be moved to tensorflow/addons "
|
||||||
|
"repository. Please check it there and file Github issues if necessary."
|
||||||
|
)
|
||||||
|
|
||||||
|
_CONTRIB_DIST_STRAT_WARNING = (
|
||||||
|
ast_edits.WARNING,
|
||||||
|
"(Manual edit required) tf.contrib.distribute.* have been migrated to"
|
||||||
|
"tf.distribute.*. Please check out the new module for updates APIs.")
|
||||||
|
|
||||||
|
MODULE_DEPRECATIONS = {
|
||||||
|
"tf.contrib": _CONTRIB_WARNING,
|
||||||
|
"tf.contrib.cudnn_rnn": _CONTRIB_CUDNN_RNN_WARNING,
|
||||||
|
"tf.contrib.rnn": _CONTRIB_RNN_WARNING,
|
||||||
|
"tf.flags": _FLAGS_WARNING,
|
||||||
|
"tf.contrib.distribute": _CONTRIB_DIST_STRAT_WARNING
|
||||||
|
}
|
@ -27,6 +27,7 @@ import pasta
|
|||||||
|
|
||||||
from tensorflow.tools.compatibility import all_renames_v2
|
from tensorflow.tools.compatibility import all_renames_v2
|
||||||
from tensorflow.tools.compatibility import ast_edits
|
from tensorflow.tools.compatibility import ast_edits
|
||||||
|
from tensorflow.tools.compatibility import module_deprecations_v2
|
||||||
from tensorflow.tools.compatibility import reorders_v2
|
from tensorflow.tools.compatibility import reorders_v2
|
||||||
|
|
||||||
# These pylint warnings are a mistake.
|
# These pylint warnings are a mistake.
|
||||||
@ -622,34 +623,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
self.function_reorders = dict(reorders_v2.reorders)
|
self.function_reorders = dict(reorders_v2.reorders)
|
||||||
self.function_reorders.update(self.manual_function_reorders)
|
self.function_reorders.update(self.manual_function_reorders)
|
||||||
|
|
||||||
contrib_warning = (
|
|
||||||
ast_edits.ERROR,
|
|
||||||
"<function name> cannot be converted automatically. tf.contrib will not"
|
|
||||||
" be distributed with TensorFlow 2.0, please consider an alternative in"
|
|
||||||
" non-contrib TensorFlow, a community-maintained repository, or fork "
|
|
||||||
"the required code."
|
|
||||||
)
|
|
||||||
|
|
||||||
flags_warning = (
|
|
||||||
ast_edits.ERROR,
|
|
||||||
"tf.flags has been removed, please use the argparse or absl"
|
|
||||||
" modules if you need command line parsing.")
|
|
||||||
|
|
||||||
contrib_cudnn_rnn_warning = (
|
|
||||||
ast_edits.WARNING,
|
|
||||||
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
|
|
||||||
"and the CuDNN kernel has been integrated with "
|
|
||||||
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
|
|
||||||
"and use that instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
contrib_rnn_warning = (
|
|
||||||
ast_edits.WARNING,
|
|
||||||
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
|
|
||||||
"widely used cells/functions will be moved to tensorflow/addons "
|
|
||||||
"repository. Please check it there and file Github issues if necessary."
|
|
||||||
)
|
|
||||||
|
|
||||||
decay_function_comment = (
|
decay_function_comment = (
|
||||||
ast_edits.INFO,
|
ast_edits.INFO,
|
||||||
"To use learning rate decay schedules with TensorFlow 2.0, switch to "
|
"To use learning rate decay schedules with TensorFlow 2.0, switch to "
|
||||||
@ -790,11 +763,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
"default, instead of HDF5. To continue saving to HDF5, add the "
|
"default, instead of HDF5. To continue saving to HDF5, add the "
|
||||||
"argument save_format='h5' to the save() function.")
|
"argument save_format='h5' to the save() function.")
|
||||||
|
|
||||||
contrib_dist_strat_warning = (
|
|
||||||
ast_edits.WARNING,
|
|
||||||
"(Manual edit required) tf.contrib.distribute.* have been migrated to"
|
|
||||||
"tf.distribute.*. Please check out the new module for updates APIs.")
|
|
||||||
|
|
||||||
distribute_strategy_api_changes = (
|
distribute_strategy_api_changes = (
|
||||||
"If you're using the strategy with a "
|
"If you're using the strategy with a "
|
||||||
"custom training loop, note the following changes in methods: "
|
"custom training loop, note the following changes in methods: "
|
||||||
@ -1504,13 +1472,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
arg_value_ast=ast.Str("h5")),
|
arg_value_ast=ast.Str("h5")),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.module_deprecations = {
|
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||||
"tf.contrib": contrib_warning,
|
|
||||||
"tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning,
|
|
||||||
"tf.contrib.rnn": contrib_rnn_warning,
|
|
||||||
"tf.flags": flags_warning,
|
|
||||||
"tf.contrib.distribute": contrib_dist_strat_warning
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _is_ast_str(node):
|
def _is_ast_str(node):
|
||||||
|
@ -19,11 +19,22 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
from tensorflow.tools.compatibility import ast_edits
|
from tensorflow.tools.compatibility import ast_edits
|
||||||
from tensorflow.tools.compatibility import tf_upgrade_v2
|
from tensorflow.tools.compatibility import tf_upgrade_v2
|
||||||
|
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
|
||||||
from tensorflow.tools.compatibility import ipynb
|
from tensorflow.tools.compatibility import ipynb
|
||||||
|
|
||||||
|
# Make straightforward changes to convert to 2.0. In harder cases,
|
||||||
|
# use compat.v1.
|
||||||
|
_DEFAULT_MODE = "DEFAULT"
|
||||||
|
|
||||||
|
# Convert to use compat.v1.
|
||||||
|
# TODO(kaftan): remove EXPERIMENTAL_ prefix once safety mode is
|
||||||
|
# implemented.
|
||||||
|
_SAFETY_MODE = "EXPERIMENTAL_SAFETY"
|
||||||
|
|
||||||
|
|
||||||
def process_file(in_filename, out_filename, upgrader):
|
def process_file(in_filename, out_filename, upgrader):
|
||||||
"""Process a file of type `.py` or `.ipynb`."""
|
"""Process a file of type `.py` or `.ipynb`."""
|
||||||
@ -91,9 +102,27 @@ Simple usage:
|
|||||||
"stored."
|
"stored."
|
||||||
"(default: %(default)s)"),
|
"(default: %(default)s)"),
|
||||||
default="report.txt")
|
default="report.txt")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
dest="mode",
|
||||||
|
choices=[_DEFAULT_MODE, _SAFETY_MODE],
|
||||||
|
help=("Upgrade script mode. Supported modes:\n"
|
||||||
|
"%s: Perform only straightforward conversions to upgrade to "
|
||||||
|
"2.0. In more difficult cases, switch to use compat.v1.\n"
|
||||||
|
"%s: Keep 1.* code intact and import compat.v1 "
|
||||||
|
"module. Note: safety mode is under development and not available "
|
||||||
|
"yet." % (_DEFAULT_MODE, _SAFETY_MODE)),
|
||||||
|
default=_DEFAULT_MODE)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
upgrade = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
|
if args.mode == _SAFETY_MODE:
|
||||||
|
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
|
||||||
|
sys.stderr.write(
|
||||||
|
"%s mode is not fully implemented yet." % _SAFETY_MODE)
|
||||||
|
else:
|
||||||
|
change_spec = tf_upgrade_v2.TFAPIChangeSpec()
|
||||||
|
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
|
||||||
|
|
||||||
report_text = None
|
report_text = None
|
||||||
report_filename = args.report_filename
|
report_filename = args.report_filename
|
||||||
files_processed = 0
|
files_processed = 0
|
||||||
|
38
tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
Normal file
38
tensorflow/tools/compatibility/tf_upgrade_v2_safety.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Upgrader for Python scripts from 1.* to 2.0 TensorFlow using SAFETY mode."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.tools.compatibility import ast_edits
|
||||||
|
from tensorflow.tools.compatibility import module_deprecations_v2
|
||||||
|
|
||||||
|
|
||||||
|
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||||
|
"""List of maps that describe what changed in the API."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.function_keyword_renames = {}
|
||||||
|
self.symbol_renames = {}
|
||||||
|
self.change_to_function = {}
|
||||||
|
self.function_reorders = {}
|
||||||
|
self.function_warnings = {}
|
||||||
|
self.function_transformers = {}
|
||||||
|
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
|
||||||
|
|
||||||
|
# TODO(kaftan,annarev): specify replacement from TensorFlow import to
|
||||||
|
# compat.v1 import.
|
47
tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
Normal file
47
tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf 2.0 upgrader in safety mode."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.platform import test as test_lib
|
||||||
|
from tensorflow.tools.compatibility import ast_edits
|
||||||
|
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
|
||||||
|
|
||||||
|
|
||||||
|
class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def _upgrade(self, old_file_text):
|
||||||
|
in_file = six.StringIO(old_file_text)
|
||||||
|
out_file = six.StringIO()
|
||||||
|
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2_safety.TFAPIChangeSpec())
|
||||||
|
count, report, errors = (
|
||||||
|
upgrader.process_opened_file("test.py", in_file,
|
||||||
|
"test_out.py", out_file))
|
||||||
|
return count, report, errors, out_file.getvalue()
|
||||||
|
|
||||||
|
def testContribWarning(self):
|
||||||
|
text = "tf.contrib.foo()"
|
||||||
|
_, report, _, _ = self._upgrade(text)
|
||||||
|
expected_info = "tf.contrib will not be distributed"
|
||||||
|
self.assertIn(expected_info, report)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_lib.main()
|
Loading…
Reference in New Issue
Block a user