Default use_resource to False if user does not specify it.

PiperOrigin-RevId: 236387319
This commit is contained in:
Zhenyu Tan 2019-03-01 15:16:08 -08:00 committed by TensorFlower Gardener
parent 3c3bf0873d
commit 269caf17e6
2 changed files with 46 additions and 1 deletions

View File

@ -1464,6 +1464,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nn.fractional_avg_pool": _pool_seed_transformer,
"tf.nn.fractional_max_pool": _pool_seed_transformer,
"tf.name_scope": _name_scope_transformer,
"tf.get_variable": _add_use_resource_transformer,
"tf.device": functools.partial(
_rename_if_arg_found_transformer, arg_name="device_name",
arg_ok_predicate=_is_ast_str, remove_if_ok=False,
@ -1966,6 +1967,40 @@ def _add_loss_reduction_transformer(parent, node, full_name, name, logs):
return node
def _add_use_resource_transformer(parent, node, full_name, name, logs):
"""Adds a use_resource argument if not specified.
Default value for tf.get_variable use_resource argument is removed. So, we
update existing calls to use_resource=False.
Args:
parent: Parent of node.
node: ast.Call node to maybe modify.
full_name: full name of function to modify
name: name of function to modify
logs: list of logs to append to
Returns:
node, if it was modified, else None.
"""
for keyword_arg in node.keywords:
if keyword_arg.arg == "use_resource":
return node
default_value = "False"
# Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
ast_value = pasta.parse(default_value)
node.keywords.append(ast.keyword(arg="use_resource", value=ast_value))
logs.append((
ast_edits.INFO, node.lineno, node.col_offset,
"%s: Default use_resource to False. This will use (deprecated) reference"
" variables. Removing this argument will work in most cases.\n"
% (full_name or name)))
return node
def _add_uniform_scaling_initializer_transformer(
parent, node, full_name, name, logs):
"""Updates references to uniform_unit_scaling_initializer.
@ -1995,7 +2030,7 @@ def _add_uniform_scaling_initializer_transformer(
node.func.attr = "VarianceScaling"
return node
def _name_scope_transformer(parent, node, full_name, name, logs):
"""Fix name scope invocation to use 'default_name' and omit 'values' args."""

View File

@ -517,6 +517,16 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
_, report, errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
def testGetVariableWithUseResource(self):
text = "tf.get_variable(name=\"a\")"
expected_text = "tf.compat.v1.get_variable(name=\"a\", use_resource=False)"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "tf.get_variable(name=\"a\", use_resource=None)"
expected_text = "tf.compat.v1.get_variable(name=\"a\", use_resource=None)"
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
def testExtractGlimpse(self):
text = ("tf.image.extract_glimpse(x, size, off, False, "