Change tf_upgrade_v2 to be compatible with tf1 and tf2 for tf.nn.dropout
This commit is contained in:
parent
2894fd21c0
commit
ac5ace351a
@ -1845,7 +1845,10 @@ def _dropout_transformer(parent, node, full_name, name, logs):
|
||||
"automatic fix was disabled. tf.nn.dropout has changed "
|
||||
"the semantics of the second argument."))
|
||||
else:
|
||||
_replace_keep_prob_node(node, node.args[1])
|
||||
rate_arg = ast.keyword(arg="rate", value=node.args[1])
|
||||
_replace_keep_prob_node(rate_arg, rate_arg.value)
|
||||
node.keywords.append(rate_arg)
|
||||
del node.args[1]
|
||||
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
|
||||
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
||||
"recomputing value.\n"))
|
||||
|
@ -901,7 +901,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
new_text,
|
||||
"tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n",
|
||||
"tf.nn.dropout(x, rate=1 - (keep_prob), name=\"foo\")\n",
|
||||
)
|
||||
|
||||
text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
|
||||
@ -934,7 +934,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
new_text,
|
||||
"tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n",
|
||||
"tf.nn.dropout(x, rate=1 - (1 - func(3 + 4.)), name=\"foo\")\n",
|
||||
)
|
||||
|
||||
def testContribL1(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user