Add get_local_variable
, which returns a local variable that respects variable scopes.
Change: 132338671
This commit is contained in:
parent
a4f6a51a84
commit
f0b904afef
@ -605,6 +605,27 @@ class VariableScopeTest(tf.test.TestCase):
|
||||
with tf.name_scope("scope2") as sc2:
|
||||
self.assertEqual(sc2, "outer_1/default/scope2/")
|
||||
|
||||
def testGetLocalVar(self):
|
||||
with self.test_session():
|
||||
# Check that local variable respects naming.
|
||||
with tf.variable_scope("outer") as outer:
|
||||
with tf.variable_scope(outer, "default", []):
|
||||
local_var = variable_scope.get_local_variable(
|
||||
"w", [], collections=["foo"])
|
||||
self.assertEqual(local_var.name, "outer/w:0")
|
||||
|
||||
# Since variable is local, it should be in the local variable collection
|
||||
# but not the the trainable collection.
|
||||
self.assertIn(local_var, tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
|
||||
self.assertIn(local_var, tf.get_collection("foo"))
|
||||
self.assertNotIn(
|
||||
local_var, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
|
||||
|
||||
# Check that local variable respects `reuse`.
|
||||
with tf.variable_scope(outer, "default", reuse=True):
|
||||
self.assertEqual(variable_scope.get_local_variable("w", []).name,
|
||||
"outer/w:0")
|
||||
|
||||
|
||||
def axis0_into1_partitioner(shape=None, **unused_kwargs):
|
||||
part = [1] * len(shape)
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import collections as collections_lib
|
||||
import contextlib
|
||||
import functools
|
||||
import traceback
|
||||
|
||||
import six
|
||||
@ -35,8 +36,8 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
__all__ = ["VariableScope", "get_variable_scope",
|
||||
"get_variable", "variable_scope", "variable_op_scope",
|
||||
"no_regularizer"]
|
||||
"get_variable", "get_local_variable", "variable_scope",
|
||||
"variable_op_scope", "no_regularizer"]
|
||||
|
||||
|
||||
class _PartitionInfo(object):
|
||||
@ -1012,6 +1013,19 @@ def get_variable(name,
|
||||
custom_getter=custom_getter)
|
||||
|
||||
|
||||
@functools.wraps(get_variable)
|
||||
def get_local_variable(*args, **kwargs):
|
||||
kwargs["trainable"] = False
|
||||
if "collections" in kwargs:
|
||||
kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
|
||||
else:
|
||||
kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
|
||||
get_local_variable.__doc__ = (
|
||||
"Gets an existing local variable or creates a new one.\n\n" +
|
||||
get_local_variable.__doc__)
|
||||
return get_variable(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_partitioned_variable(name,
|
||||
shape=None,
|
||||
dtype=None,
|
||||
|
Loading…
Reference in New Issue
Block a user