From 188791fc0b4a50e7591528f40023ac7c891f17c6 Mon Sep 17 00:00:00 2001 From: Anna R Date: Fri, 10 May 2019 16:04:20 -0700 Subject: [PATCH] Setup skeleton for adding SAFETY mode to tf_upgrade_v2.py script. PiperOrigin-RevId: 247692270 --- tensorflow/tools/compatibility/BUILD | 30 +++++++++ .../compatibility/module_deprecations_v2.py | 64 +++++++++++++++++++ .../tools/compatibility/tf_upgrade_v2.py | 42 +----------- .../tools/compatibility/tf_upgrade_v2_main.py | 31 ++++++++- .../compatibility/tf_upgrade_v2_safety.py | 38 +++++++++++ .../tf_upgrade_v2_safety_test.py | 47 ++++++++++++++ 6 files changed, 211 insertions(+), 41 deletions(-) create mode 100644 tensorflow/tools/compatibility/module_deprecations_v2.py create mode 100644 tensorflow/tools/compatibility/tf_upgrade_v2_safety.py create mode 100644 tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index 470b23072c0..4640132f1aa 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -91,6 +91,12 @@ py_library( deps = [":renames_v2"], ) +py_library( + name = "module_deprecations_v2", + srcs = ["module_deprecations_v2.py"], + deps = [":ast_edits"], +) + py_library( name = "tf_upgrade_v2_lib", srcs = ["tf_upgrade_v2.py"], @@ -98,11 +104,22 @@ py_library( deps = [ ":all_renames_v2", ":ast_edits", + ":module_deprecations_v2", ":reorders_v2", "@six_archive//:six", ], ) +py_library( + name = "tf_upgrade_v2_safety_lib", + srcs = ["tf_upgrade_v2_safety.py"], + srcs_version = "PY2AND3", + deps = [ + ":ast_edits", + ":module_deprecations_v2", + ], +) + py_binary( name = "tf_upgrade_v2", srcs = ["tf_upgrade_v2_main.py"], @@ -113,6 +130,7 @@ py_binary( ":ast_edits", ":ipynb", ":tf_upgrade_v2_lib", + ":tf_upgrade_v2_safety_lib", ], ) @@ -133,6 +151,18 @@ py_test( ], ) +py_test( + name = "tf_upgrade_v2_safety_test", + srcs = ["tf_upgrade_v2_safety_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":tf_upgrade_v2_safety_lib", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", + ], +) + # Keep for reference, this test will succeed in 0.11 but fail in 1.0 # py_test( # name = "test_file_v0_11", diff --git a/tensorflow/tools/compatibility/module_deprecations_v2.py b/tensorflow/tools/compatibility/module_deprecations_v2.py new file mode 100644 index 00000000000..ba542954a1d --- /dev/null +++ b/tensorflow/tools/compatibility/module_deprecations_v2.py @@ -0,0 +1,64 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module deprecation warnings for TensorFlow 2.0.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.tools.compatibility import ast_edits + + +_CONTRIB_WARNING = ( + ast_edits.ERROR, + " cannot be converted automatically. tf.contrib will not" + " be distributed with TensorFlow 2.0, please consider an alternative in" + " non-contrib TensorFlow, a community-maintained repository, or fork " + "the required code." +) + +_FLAGS_WARNING = ( + ast_edits.ERROR, + "tf.flags has been removed, please use the argparse or absl" + " modules if you need command line parsing." +) + +_CONTRIB_CUDNN_RNN_WARNING = ( + ast_edits.WARNING, + "(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, " + "and the CuDNN kernel has been integrated with " + "tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API " + "and use that instead." +) + +_CONTRIB_RNN_WARNING = ( + ast_edits.WARNING, + "(Manual edit required) tf.contrib.rnn.* has been deprecated, and " + "widely used cells/functions will be moved to tensorflow/addons " + "repository. Please check it there and file Github issues if necessary." +) + +_CONTRIB_DIST_STRAT_WARNING = ( + ast_edits.WARNING, + "(Manual edit required) tf.contrib.distribute.* have been migrated to" + "tf.distribute.*. Please check out the new module for updates APIs.") + +MODULE_DEPRECATIONS = { + "tf.contrib": _CONTRIB_WARNING, + "tf.contrib.cudnn_rnn": _CONTRIB_CUDNN_RNN_WARNING, + "tf.contrib.rnn": _CONTRIB_RNN_WARNING, + "tf.flags": _FLAGS_WARNING, + "tf.contrib.distribute": _CONTRIB_DIST_STRAT_WARNING +} diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index d4044a2a0c4..e55ad592bff 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -27,6 +27,7 @@ import pasta from tensorflow.tools.compatibility import all_renames_v2 from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import module_deprecations_v2 from tensorflow.tools.compatibility import reorders_v2 # These pylint warnings are a mistake. @@ -622,34 +623,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): self.function_reorders = dict(reorders_v2.reorders) self.function_reorders.update(self.manual_function_reorders) - contrib_warning = ( - ast_edits.ERROR, - " cannot be converted automatically. tf.contrib will not" - " be distributed with TensorFlow 2.0, please consider an alternative in" - " non-contrib TensorFlow, a community-maintained repository, or fork " - "the required code." - ) - - flags_warning = ( - ast_edits.ERROR, - "tf.flags has been removed, please use the argparse or absl" - " modules if you need command line parsing.") - - contrib_cudnn_rnn_warning = ( - ast_edits.WARNING, - "(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, " - "and the CuDNN kernel has been integrated with " - "tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API " - "and use that instead." - ) - - contrib_rnn_warning = ( - ast_edits.WARNING, - "(Manual edit required) tf.contrib.rnn.* has been deprecated, and " - "widely used cells/functions will be moved to tensorflow/addons " - "repository. Please check it there and file Github issues if necessary." - ) - decay_function_comment = ( ast_edits.INFO, "To use learning rate decay schedules with TensorFlow 2.0, switch to " @@ -790,11 +763,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "default, instead of HDF5. To continue saving to HDF5, add the " "argument save_format='h5' to the save() function.") - contrib_dist_strat_warning = ( - ast_edits.WARNING, - "(Manual edit required) tf.contrib.distribute.* have been migrated to" - "tf.distribute.*. Please check out the new module for updates APIs.") - distribute_strategy_api_changes = ( "If you're using the strategy with a " "custom training loop, note the following changes in methods: " @@ -1504,13 +1472,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): arg_value_ast=ast.Str("h5")), } - self.module_deprecations = { - "tf.contrib": contrib_warning, - "tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning, - "tf.contrib.rnn": contrib_rnn_warning, - "tf.flags": flags_warning, - "tf.contrib.distribute": contrib_dist_strat_warning - } + self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS def _is_ast_str(node): diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py index 36e30f559e3..3c4263ed809 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py @@ -19,11 +19,22 @@ from __future__ import division from __future__ import print_function import argparse +import sys from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import tf_upgrade_v2 +from tensorflow.tools.compatibility import tf_upgrade_v2_safety from tensorflow.tools.compatibility import ipynb +# Make straightforward changes to convert to 2.0. In harder cases, +# use compat.v1. +_DEFAULT_MODE = "DEFAULT" + +# Convert to use compat.v1. +# TODO(kaftan): remove EXPERIMENTAL_ prefix once safety mode is +# implemented. +_SAFETY_MODE = "EXPERIMENTAL_SAFETY" + def process_file(in_filename, out_filename, upgrader): """Process a file of type `.py` or `.ipynb`.""" @@ -91,9 +102,27 @@ Simple usage: "stored." "(default: %(default)s)"), default="report.txt") + parser.add_argument( + "--mode", + dest="mode", + choices=[_DEFAULT_MODE, _SAFETY_MODE], + help=("Upgrade script mode. Supported modes:\n" + "%s: Perform only straightforward conversions to upgrade to " + "2.0. In more difficult cases, switch to use compat.v1.\n" + "%s: Keep 1.* code intact and import compat.v1 " + "module. Note: safety mode is under development and not available " + "yet." % (_DEFAULT_MODE, _SAFETY_MODE)), + default=_DEFAULT_MODE) args = parser.parse_args() - upgrade = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec()) + if args.mode == _SAFETY_MODE: + change_spec = tf_upgrade_v2_safety.TFAPIChangeSpec() + sys.stderr.write( + "%s mode is not fully implemented yet." % _SAFETY_MODE) + else: + change_spec = tf_upgrade_v2.TFAPIChangeSpec() + upgrade = ast_edits.ASTCodeUpgrader(change_spec) + report_text = None report_filename = args.report_filename files_processed = 0 diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py b/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py new file mode 100644 index 00000000000..02ade7bb812 --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_safety.py @@ -0,0 +1,38 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts from 1.* to 2.0 TensorFlow using SAFETY mode.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import module_deprecations_v2 + + +class TFAPIChangeSpec(ast_edits.APIChangeSpec): + """List of maps that describe what changed in the API.""" + + def __init__(self): + self.function_keyword_renames = {} + self.symbol_renames = {} + self.change_to_function = {} + self.function_reorders = {} + self.function_warnings = {} + self.function_transformers = {} + self.module_deprecations = module_deprecations_v2.MODULE_DEPRECATIONS + + # TODO(kaftan,annarev): specify replacement from TensorFlow import to + # compat.v1 import. diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py new file mode 100644 index 00000000000..8890d631c34 --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_safety_test.py @@ -0,0 +1,47 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf 2.0 upgrader in safety mode.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib +from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import tf_upgrade_v2_safety + + +class TfUpgradeV2SafetyTest(test_util.TensorFlowTestCase): + + def _upgrade(self, old_file_text): + in_file = six.StringIO(old_file_text) + out_file = six.StringIO() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2_safety.TFAPIChangeSpec()) + count, report, errors = ( + upgrader.process_opened_file("test.py", in_file, + "test_out.py", out_file)) + return count, report, errors, out_file.getvalue() + + def testContribWarning(self): + text = "tf.contrib.foo()" + _, report, _, _ = self._upgrade(text) + expected_info = "tf.contrib will not be distributed" + self.assertIn(expected_info, report) + + +if __name__ == "__main__": + test_lib.main()