Update deprecation message for contrib/{cudnn_}rnn in tf upgrade script.

PiperOrigin-RevId: 242178768
This commit is contained in:
Scott Zhu 2019-04-05 13:03:33 -07:00 committed by TensorFlower Gardener
parent f8c0a3ff45
commit 112e128f39
2 changed files with 25 additions and 0 deletions

View File

@ -1051,6 +1051,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.flags has been removed, please use the argparse or absl" "tf.flags has been removed, please use the argparse or absl"
" modules if you need command line parsing.") " modules if you need command line parsing.")
contrib_cudnn_rnn_warning = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
"and the CuDNN kernel has been integrated with "
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
"and use that instead."
)
contrib_rnn_warning = (
ast_edits.WARNING,
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
"widely used cells/functions will be moved to tensorflow/addons "
"repository. Please check it there and file Github issues if necessary."
)
decay_function_comment = ( decay_function_comment = (
ast_edits.INFO, ast_edits.INFO,
"To use learning rate decay schedules with TensorFlow 2.0, switch to " "To use learning rate decay schedules with TensorFlow 2.0, switch to "
@ -1706,6 +1721,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
self.module_deprecations = { self.module_deprecations = {
"tf.contrib": contrib_warning, "tf.contrib": contrib_warning,
"tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning,
"tf.contrib.rnn": contrib_rnn_warning,
"tf.flags": flags_warning, "tf.flags": flags_warning,
} }

View File

@ -1351,6 +1351,14 @@ def _log_prob(self, x):
_, _, errors, _ = self._upgrade("tf.flags.FLAGS") _, _, errors, _ = self._upgrade("tf.flags.FLAGS")
self.assertIn("tf.flags has been removed", errors[0]) self.assertIn("tf.flags has been removed", errors[0])
def test_contrib_rnn_deprecation(self):
_, report, _, _ = self._upgrade("tf.contrib.rnn")
self.assertIn("tf.contrib.rnn.* has been deprecated", report)
def test_contrib_cudnn_rnn_deprecation(self):
_, report, _, _ = self._upgrade("tf.contrib.cudnn_rnn")
self.assertIn("tf.contrib.cudnn_rnn.* has been deprecated", report)
def test_max_pool_2d(self): def test_max_pool_2d(self):
text = "tf.nn.max_pool(value=4)" text = "tf.nn.max_pool(value=4)"
expected_text = "tf.nn.max_pool2d(input=4)" expected_text = "tf.nn.max_pool2d(input=4)"