From ac5ace351a188d74a089610eb8e288fa686d6bf8 Mon Sep 17 00:00:00 2001 From: Yuta Fukasawa Date: Sun, 15 Nov 2020 12:13:37 +0900 Subject: [PATCH] Change tf_upgrade_v2 to be compatible with tf1 and tf2 for tf.nn.dropout --- tensorflow/tools/compatibility/tf_upgrade_v2.py | 5 ++++- tensorflow/tools/compatibility/tf_upgrade_v2_test.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 5d795de68c5..4c099e2c66d 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -1845,7 +1845,10 @@ def _dropout_transformer(parent, node, full_name, name, logs): "automatic fix was disabled. tf.nn.dropout has changed " "the semantics of the second argument.")) else: - _replace_keep_prob_node(node, node.args[1]) + rate_arg = ast.keyword(arg="rate", value=node.args[1]) + _replace_keep_prob_node(rate_arg, rate_arg.value) + node.keywords.append(rate_arg) + del node.args[1] logs.append((ast_edits.INFO, node.lineno, node.col_offset, "Changing keep_prob arg of tf.nn.dropout to rate, and " "recomputing value.\n")) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 6b65785fe32..71d7b6eb471 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -901,7 +901,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map _, unused_report, unused_errors, new_text = self._upgrade(text) self.assertEqual( new_text, - "tf.nn.dropout(x, 1 - (keep_prob), name=\"foo\")\n", + "tf.nn.dropout(x, rate=1 - (keep_prob), name=\"foo\")\n", ) text = "tf.nn.dropout(x, keep_prob=.4, name=\"foo\")\n" @@ -934,7 +934,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map _, unused_report, unused_errors, new_text = self._upgrade(text) self.assertEqual( new_text, - "tf.nn.dropout(x, 1 - (1 - func(3 + 4.)), name=\"foo\")\n", + "tf.nn.dropout(x, rate=1 - (1 - func(3 + 4.)), name=\"foo\")\n", ) def testContribL1(self):