Update deprecation message for contrib/{cudnn_}rnn in tf upgrade script.
PiperOrigin-RevId: 242178768
This commit is contained in:
parent
f8c0a3ff45
commit
112e128f39
@ -1051,6 +1051,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
"tf.flags has been removed, please use the argparse or absl"
|
"tf.flags has been removed, please use the argparse or absl"
|
||||||
" modules if you need command line parsing.")
|
" modules if you need command line parsing.")
|
||||||
|
|
||||||
|
contrib_cudnn_rnn_warning = (
|
||||||
|
ast_edits.WARNING,
|
||||||
|
"(Manual edit required) tf.contrib.cudnn_rnn.* has been deprecated, "
|
||||||
|
"and the CuDNN kernel has been integrated with "
|
||||||
|
"tf.keras.layers.LSTM/GRU in TensorFlow 2.0. Please check the new API "
|
||||||
|
"and use that instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
contrib_rnn_warning = (
|
||||||
|
ast_edits.WARNING,
|
||||||
|
"(Manual edit required) tf.contrib.rnn.* has been deprecated, and "
|
||||||
|
"widely used cells/functions will be moved to tensorflow/addons "
|
||||||
|
"repository. Please check it there and file Github issues if necessary."
|
||||||
|
)
|
||||||
|
|
||||||
decay_function_comment = (
|
decay_function_comment = (
|
||||||
ast_edits.INFO,
|
ast_edits.INFO,
|
||||||
"To use learning rate decay schedules with TensorFlow 2.0, switch to "
|
"To use learning rate decay schedules with TensorFlow 2.0, switch to "
|
||||||
@ -1706,6 +1721,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
|||||||
|
|
||||||
self.module_deprecations = {
|
self.module_deprecations = {
|
||||||
"tf.contrib": contrib_warning,
|
"tf.contrib": contrib_warning,
|
||||||
|
"tf.contrib.cudnn_rnn": contrib_cudnn_rnn_warning,
|
||||||
|
"tf.contrib.rnn": contrib_rnn_warning,
|
||||||
"tf.flags": flags_warning,
|
"tf.flags": flags_warning,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1351,6 +1351,14 @@ def _log_prob(self, x):
|
|||||||
_, _, errors, _ = self._upgrade("tf.flags.FLAGS")
|
_, _, errors, _ = self._upgrade("tf.flags.FLAGS")
|
||||||
self.assertIn("tf.flags has been removed", errors[0])
|
self.assertIn("tf.flags has been removed", errors[0])
|
||||||
|
|
||||||
|
def test_contrib_rnn_deprecation(self):
|
||||||
|
_, report, _, _ = self._upgrade("tf.contrib.rnn")
|
||||||
|
self.assertIn("tf.contrib.rnn.* has been deprecated", report)
|
||||||
|
|
||||||
|
def test_contrib_cudnn_rnn_deprecation(self):
|
||||||
|
_, report, _, _ = self._upgrade("tf.contrib.cudnn_rnn")
|
||||||
|
self.assertIn("tf.contrib.cudnn_rnn.* has been deprecated", report)
|
||||||
|
|
||||||
def test_max_pool_2d(self):
|
def test_max_pool_2d(self):
|
||||||
text = "tf.nn.max_pool(value=4)"
|
text = "tf.nn.max_pool(value=4)"
|
||||||
expected_text = "tf.nn.max_pool2d(input=4)"
|
expected_text = "tf.nn.max_pool2d(input=4)"
|
||||||
|
Loading…
Reference in New Issue
Block a user