Change tf_upgrade_v2 to be compatible with tf1 and tf2 for tf.nn.dropout

This commit is contained in:
Yuta Fukasawa 2020-11-15 12:13:37 +09:00
parent 2894fd21c0
commit ac5ace351a
2 changed files with 6 additions and 3 deletions

View File

@ -1845,7 +1845,10 @@ def _dropout_transformer(parent, node, full_name, name, logs):
"automatic fix was disabled. tf.nn.dropout has changed "
"the semantics of the second argument."))
else:
_replace_keep_prob_node(node, node.args[1])
rate_arg = ast.keyword(arg="rate", value=node.args[1])
_replace_keep_prob_node(rate_arg, rate_arg.value)
node.keywords.append(rate_arg)
del node.args[1]
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changing keep_prob arg of tf.nn.dropout to rate, and "
"recomputing value.\n"))

View File

@ -901,7 +901,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(
new_text,
"tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n",
"tf.nn.dropout(x, rate=1 - (keep_prob), name=\"foo\")\n",
)
text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
@ -934,7 +934,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(
new_text,
"tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n",
"tf.nn.dropout(x, rate=1 - (1 - func(3 + 4.)), name=\"foo\")\n",
)
def testContribL1(self):