Change tf_upgrade_v2 to be compatible with tf1 and tf2 for tf.nn.dropout
This commit is contained in:
parent
2894fd21c0
commit
ac5ace351a
@ -1845,7 +1845,10 @@ def _dropout_transformer(parent, node, full_name, name, logs):
|
|||||||
"automatic fix was disabled. tf.nn.dropout has changed "
|
"automatic fix was disabled. tf.nn.dropout has changed "
|
||||||
"the semantics of the second argument."))
|
"the semantics of the second argument."))
|
||||||
else:
|
else:
|
||||||
_replace_keep_prob_node(node, node.args[1])
|
rate_arg = ast.keyword(arg="rate", value=node.args[1])
|
||||||
|
_replace_keep_prob_node(rate_arg, rate_arg.value)
|
||||||
|
node.keywords.append(rate_arg)
|
||||||
|
del node.args[1]
|
||||||
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
|
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
|
||||||
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
"Changing keep_prob arg of tf.nn.dropout to rate, and "
|
||||||
"recomputing value.\n"))
|
"recomputing value.\n"))
|
||||||
|
@ -901,7 +901,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
|||||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
new_text,
|
new_text,
|
||||||
"tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n",
|
"tf.nn.dropout(x, rate=1 - (keep_prob), name=\"foo\")\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
|
text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n"
|
||||||
@ -934,7 +934,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
|||||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
new_text,
|
new_text,
|
||||||
"tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n",
|
"tf.nn.dropout(x, rate=1 - (1 - func(3 + 4.)), name=\"foo\")\n",
|
||||||
)
|
)
|
||||||
|
|
||||||
def testContribL1(self):
|
def testContribL1(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user