diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 1b674e27e54..2c3d5651ab8 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -1051,6 +1051,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.flags has been removed, please use the argparse or absl" " modules if you need command line parsing.") + contrib_cudnn_rnn_warning = ( + ast_edits.WARNING, + "(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, " + "and the CuDNN kernel has been integrated with " + "tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API " + "and use that instead." + ) + + contrib_rnn_warning = ( + ast_edits.WARNING, + "(Manual edit required) tf.contrib.rnn.* has been deprecated, and " + "widely used cells/functions will be moved to tensorflow/addons " + "repository. Please check it there and file Github issues if necessary." + ) + decay_function_comment = ( ast_edits.INFO, "To use learning rate decay schedules with TensorFlow 2.0, switch to " @@ -1706,6 +1721,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): self.module_deprecations = { "tf.contrib": contrib_warning, + "tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning, + "tf.contrib.rnn": contrib_rnn_warning, "tf.flags": flags_warning, } diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 8cc0c546a65..a39baea1c04 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -1351,6 +1351,14 @@ def _log_prob(self, x): _, _, errors, _ = self._upgrade("tf.flags.FLAGS") self.assertIn("tf.flags has been removed", errors[0]) + def test_contrib_rnn_deprecation(self): + _, report, _, _ = self._upgrade("tf.contrib.rnn") + self.assertIn("tf.contrib.rnn.* has been deprecated", report) + + def test_contrib_cudnn_rnn_deprecation(self): + _, report, _, _ = self._upgrade("tf.contrib.cudnn_rnn") + self.assertIn("tf.contrib.cudnn_rnn.* has been deprecated", report) + def test_max_pool_2d(self): text = "tf.nn.max_pool(value=4)" expected_text = "tf.nn.max_pool2d(input=4)"