diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 609f7d15ccc..9bb1f56ce1f 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -1464,6 +1464,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.nn.fractional_avg_pool": _pool_seed_transformer, "tf.nn.fractional_max_pool": _pool_seed_transformer, "tf.name_scope": _name_scope_transformer, + "tf.get_variable": _add_use_resource_transformer, "tf.device": functools.partial( _rename_if_arg_found_transformer, arg_name="device_name", arg_ok_predicate=_is_ast_str, remove_if_ok=False, @@ -1966,6 +1967,40 @@ def _add_loss_reduction_transformer(parent, node, full_name, name, logs): return node +def _add_use_resource_transformer(parent, node, full_name, name, logs): + """Adds a use_resource argument if not specified. + + Default value for tf.get_variable use_resource argument is removed. So, we + update existing calls to use_resource=False. + + Args: + parent: Parent of node. + node: ast.Call node to maybe modify. + full_name: full name of function to modify + name: name of function to modify + logs: list of logs to append to + + Returns: + node, if it was modified, else None. + """ + + for keyword_arg in node.keywords: + if keyword_arg.arg == "use_resource": + return node + + default_value = "False" + # Parse with pasta instead of ast to avoid emitting a spurious trailing \n. + ast_value = pasta.parse(default_value) + node.keywords.append(ast.keyword(arg="use_resource", value=ast_value)) + logs.append(( + ast_edits.INFO, node.lineno, node.col_offset, + "%s: Default use_resource to False. This will use (deprecated) reference" + " variables. Removing this argument will work in most cases.\n" + % (full_name or name))) + + return node + + def _add_uniform_scaling_initializer_transformer( parent, node, full_name, name, logs): """Updates references to uniform_unit_scaling_initializer. @@ -1995,7 +2030,7 @@ def _add_uniform_scaling_initializer_transformer( node.func.attr = "VarianceScaling" return node - + def _name_scope_transformer(parent, node, full_name, name, logs): """Fix name scope invocation to use 'default_name' and omit 'values' args.""" diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 52d5fbf0603..ec82323331c 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -517,6 +517,16 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map _, report, errors, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) + def testGetVariableWithUseResource(self): + text = "tf.get_variable(name=\"a\")" + expected_text = "tf.compat.v1.get_variable(name=\"a\", use_resource=False)" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(expected_text, new_text) + + text = "tf.get_variable(name=\"a\", use_resource=None)" + expected_text = "tf.compat.v1.get_variable(name=\"a\", use_resource=None)" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(expected_text, new_text) def testExtractGlimpse(self): text = ("tf.image.extract_glimpse(x, size, off, False, "