Default use_resource to False if user does not specify it.
PiperOrigin-RevId: 236387319
This commit is contained in:
parent
3c3bf0873d
commit
269caf17e6
@ -1464,6 +1464,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.nn.fractional_avg_pool": _pool_seed_transformer,
|
||||
"tf.nn.fractional_max_pool": _pool_seed_transformer,
|
||||
"tf.name_scope": _name_scope_transformer,
|
||||
"tf.get_variable": _add_use_resource_transformer,
|
||||
"tf.device": functools.partial(
|
||||
_rename_if_arg_found_transformer, arg_name="device_name",
|
||||
arg_ok_predicate=_is_ast_str, remove_if_ok=False,
|
||||
@ -1966,6 +1967,40 @@ def _add_loss_reduction_transformer(parent, node, full_name, name, logs):
|
||||
return node
|
||||
|
||||
|
||||
def _add_use_resource_transformer(parent, node, full_name, name, logs):
|
||||
"""Adds a use_resource argument if not specified.
|
||||
|
||||
Default value for tf.get_variable use_resource argument is removed. So, we
|
||||
update existing calls to use_resource=False.
|
||||
|
||||
Args:
|
||||
parent: Parent of node.
|
||||
node: ast.Call node to maybe modify.
|
||||
full_name: full name of function to modify
|
||||
name: name of function to modify
|
||||
logs: list of logs to append to
|
||||
|
||||
Returns:
|
||||
node, if it was modified, else None.
|
||||
"""
|
||||
|
||||
for keyword_arg in node.keywords:
|
||||
if keyword_arg.arg == "use_resource":
|
||||
return node
|
||||
|
||||
default_value = "False"
|
||||
# Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
|
||||
ast_value = pasta.parse(default_value)
|
||||
node.keywords.append(ast.keyword(arg="use_resource", value=ast_value))
|
||||
logs.append((
|
||||
ast_edits.INFO, node.lineno, node.col_offset,
|
||||
"%s: Default use_resource to False. This will use (deprecated) reference"
|
||||
" variables. Removing this argument will work in most cases.\n"
|
||||
% (full_name or name)))
|
||||
|
||||
return node
|
||||
|
||||
|
||||
def _add_uniform_scaling_initializer_transformer(
|
||||
parent, node, full_name, name, logs):
|
||||
"""Updates references to uniform_unit_scaling_initializer.
|
||||
@ -1995,7 +2030,7 @@ def _add_uniform_scaling_initializer_transformer(
|
||||
node.func.attr = "VarianceScaling"
|
||||
return node
|
||||
|
||||
|
||||
|
||||
def _name_scope_transformer(parent, node, full_name, name, logs):
|
||||
"""Fix name scope invocation to use 'default_name' and omit 'values' args."""
|
||||
|
||||
|
@ -517,6 +517,16 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
_, report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testGetVariableWithUseResource(self):
|
||||
text = "tf.get_variable(name=\"a\")"
|
||||
expected_text = "tf.compat.v1.get_variable(name=\"a\", use_resource=False)"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
text = "tf.get_variable(name=\"a\", use_resource=None)"
|
||||
expected_text = "tf.compat.v1.get_variable(name=\"a\", use_resource=None)"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testExtractGlimpse(self):
|
||||
text = ("tf.image.extract_glimpse(x, size, off, False, "
|
||||
|
Loading…
Reference in New Issue
Block a user