Setup skeleton for adding SAFETY mode to tf_upgrade_v2.py script.

PiperOrigin-RevId: 247692270
This commit is contained in:
Anna R 2019-05-10 16:04:20 -07:00 committed by TensorFlower Gardener
parent a5a28d5778
commit 188791fc0b
6 changed files with 211 additions and 41 deletions

View File

@ -91,6 +91,12 @@ py_library(
deps = [":renames_v2"],
)
py_library(
name = "module_deprecations_v2",
srcs = ["module_deprecations_v2.py"],
deps = [":ast_edits"],
)
py_library(
name = "tf_upgrade_v2_lib",
srcs = ["tf_upgrade_v2.py"],
@ -98,11 +104,22 @@ py_library(
deps = [
":all_renames_v2",
":ast_edits",
":module_deprecations_v2",
":reorders_v2",
"@six_archive//:six",
],
)
py_library(
name = "tf_upgrade_v2_safety_lib",
srcs = ["tf_upgrade_v2_safety.py"],
srcs_version = "PY2AND3",
deps = [
":ast_edits",
":module_deprecations_v2",
],
)
py_binary(
name = "tf_upgrade_v2",
srcs = ["tf_upgrade_v2_main.py"],
@ -113,6 +130,7 @@ py_binary(
":ast_edits",
":ipynb",
":tf_upgrade_v2_lib",
":tf_upgrade_v2_safety_lib",
],
)
@ -133,6 +151,18 @@ py_test(
],
)
py_test(
name = "tf_upgrade_v2_safety_test",
srcs = ["tf_upgrade_v2_safety_test.py"],
srcs_version = "PY2AND3",
deps = [
":tf_upgrade_v2_safety_lib",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",
],
)
# Keep for reference, this test will succeed in 0.11 but fail in 1.0
# py_test(
# name = "test_file_v0_11",

View File

@ -0,0 +1,64 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module deprecation warnings for TensorFlow 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.tools.compatibility import ast_edits
_CONTRIB_WARNING = (
ast_edits.ERROR,
"<function name> cannot be converted automatically. tf.contrib will not"
" be distributed with TensorFlow 2.0, please consider an alternative in"
" non-contrib TensorFlow, a community-maintained repository, or fork "
"the required code."
)
_FLAGS_WARNING = (
ast_edits.ERROR,
"tf.flags has been removed, please use the argparse or absl"
" modules if you need command line parsing."
)
_CONTRIB_CUDNN_RNN_WARNING = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
"and the CuDNN kernel has been integrated with "
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
"and use that instead."
)
_CONTRIB_RNN_WARNING = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
"widely used cells/functions will be moved to tensorflow/addons "
"repository. Please check it there and file Github issues if necessary."
)
_CONTRIB_DIST_STRAT_WARNING = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.distribute.* have been migrated to"
"tf.distribute.*. Please check out the new module for updates APIs.")
MODULE_DEPRECATIONS = {
"tf.contrib": _CONTRIB_WARNING,
"tf.contrib.cudnn_rnn": _CONTRIB_CUDNN_RNN_WARNING,
"tf.contrib.rnn": _CONTRIB_RNN_WARNING,
"tf.flags": _FLAGS_WARNING,
"tf.contrib.distribute": _CONTRIB_DIST_STRAT_WARNING
}

View File

@ -27,6 +27,7 @@ import pasta
from tensorflow.tools.compatibility import all_renames_v2
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import module_deprecations_v2
from tensorflow.tools.compatibility import reorders_v2
# These pylint warnings are a mistake.
@ -622,34 +623,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
self.function_reorders = dict(reorders_v2.reorders)
self.function_reorders.update(self.manual_function_reorders)
contrib_warning = (
ast_edits.ERROR,
"<function name> cannot be converted automatically. tf.contrib will not"
" be distributed with TensorFlow 2.0, please consider an alternative in"
" non-contrib TensorFlow, a community-maintained repository, or fork "
"the required code."
)
flags_warning = (
ast_edits.ERROR,
"tf.flags has been removed, please use the argparse or absl"
" modules if you need command line parsing.")
contrib_cudnn_rnn_warning = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
"and the CuDNN kernel has been integrated with "
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
"and use that instead."
)
contrib_rnn_warning = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
"widely used cells/functions will be moved to tensorflow/addons "
"repository. Please check it there and file Github issues if necessary."
)
decay_function_comment = (
ast_edits.INFO,
"To use learning rate decay schedules with TensorFlow 2.0, switch to "
@ -790,11 +763,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"default, instead of HDF5. To continue saving to HDF5, add the "
"argument save_format='h5' to the save() function.")
contrib_dist_strat_warning = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.distribute.* have been migrated to"
"tf.distribute.*. Please check out the new module for updates APIs.")
distribute_strategy_api_changes = (
"If you're using the strategy with a "
"custom training loop, note the following changes in methods: "
@ -1504,13 +1472,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
arg_value_ast=ast.Str("h5")),
}
self.module_deprecations = {
"tf.contrib": contrib_warning,
"tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning,
"tf.contrib.rnn": contrib_rnn_warning,
"tf.flags": flags_warning,
"tf.contrib.distribute": contrib_dist_strat_warning
}
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
def _is_ast_str(node):

View File

@ -19,11 +19,22 @@ from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade_v2
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
from tensorflow.tools.compatibility import ipynb
# Make straightforward changes to convert to 2.0. In harder cases,
# use compat.v1.
_DEFAULT_MODE = "DEFAULT"
# Convert to use compat.v1.
# TODO(kaftan): remove EXPERIMENTAL_ prefix once safety mode is
# implemented.
_SAFETY_MODE = "EXPERIMENTAL_SAFETY"
def process_file(in_filename, out_filename, upgrader):
"""Process a file of type `.py` or `.ipynb`."""
@ -91,9 +102,27 @@ Simple usage:
"stored."
"(default: %(default)s)"),
default="report.txt")
parser.add_argument(
"--mode",
dest="mode",
choices=[_DEFAULT_MODE, _SAFETY_MODE],
help=("Upgrade script mode. Supported modes:\n"
"%s: Perform only straightforward conversions to upgrade to "
"2.0. In more difficult cases, switch to use compat.v1.\n"
"%s: Keep 1.* code intact and import compat.v1 "
"module. Note: safety mode is under development and not available "
"yet." % (_DEFAULT_MODE, _SAFETY_MODE)),
default=_DEFAULT_MODE)
args = parser.parse_args()
upgrade = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
if args.mode == _SAFETY_MODE:
change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec()
sys.stderr.write(
"%s mode is not fully implemented yet." % _SAFETY_MODE)
else:
change_spec = tf_upgrade_v2.TFAPIChangeSpec()
upgrade = ast_edits.ASTCodeUpgrader(change_spec)
report_text = None
report_filename = args.report_filename
files_processed = 0

View File

@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upgrader for Python scripts from 1.* to 2.0 TensorFlow using SAFETY mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import module_deprecations_v2
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"""List of maps that describe what changed in the API."""
def __init__(self):
self.function_keyword_renames = {}
self.symbol_renames = {}
self.change_to_function = {}
self.function_reorders = {}
self.function_warnings = {}
self.function_transformers = {}
self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS
# TODO(kaftan,annarev): specify replacement from TensorFlow import to
# compat.v1 import.

View File

@ -0,0 +1,47 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf 2.0 upgrader in safety mode."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase):
def _upgrade(self, old_file_text):
in_file = six.StringIO(old_file_text)
out_file = six.StringIO()
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2_safety.TFAPIChangeSpec())
count, report, errors = (
upgrader.process_opened_file("test.py", in_file,
"test_out.py", out_file))
return count, report, errors, out_file.getvalue()
def testContribWarning(self):
text = "tf.contrib.foo()"
_, report, _, _ = self._upgrade(text)
expected_info = "tf.contrib will not be distributed"
self.assertIn(expected_info, report)
if __name__ == "__main__":
test_lib.main()