Add get_local_variable, which returns a local variable that respects variable scopes.

Change: 132338671
This commit is contained in:
A. Unique TensorFlower 2016-09-06 10:02:05 -08:00 committed by TensorFlower Gardener
parent a4f6a51a84
commit f0b904afef
2 changed files with 37 additions and 2 deletions

View File

@ -605,6 +605,27 @@ class VariableScopeTest(tf.test.TestCase):
with tf.name_scope("scope2") as sc2:
self.assertEqual(sc2, "outer_1/default/scope2/")
def testGetLocalVar(self):
with self.test_session():
# Check that local variable respects naming.
with tf.variable_scope("outer") as outer:
with tf.variable_scope(outer, "default", []):
local_var = variable_scope.get_local_variable(
"w", [], collections=["foo"])
self.assertEqual(local_var.name, "outer/w:0")
# Since variable is local, it should be in the local variable collection
# but not the the trainable collection.
self.assertIn(local_var, tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
self.assertIn(local_var, tf.get_collection("foo"))
self.assertNotIn(
local_var, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
# Check that local variable respects `reuse`.
with tf.variable_scope(outer, "default", reuse=True):
self.assertEqual(variable_scope.get_local_variable("w", []).name,
"outer/w:0")
def axis0_into1_partitioner(shape=None, **unused_kwargs):
part = [1] * len(shape)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import collections as collections_lib
import contextlib
import functools
import traceback
import six
@ -35,8 +36,8 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
__all__ = ["VariableScope", "get_variable_scope",
"get_variable", "variable_scope", "variable_op_scope",
"no_regularizer"]
"get_variable", "get_local_variable", "variable_scope",
"variable_op_scope", "no_regularizer"]
class _PartitionInfo(object):
@ -1012,6 +1013,19 @@ def get_variable(name,
custom_getter=custom_getter)
@functools.wraps(get_variable)
def get_local_variable(*args, **kwargs):
kwargs["trainable"] = False
if "collections" in kwargs:
kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
else:
kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
get_local_variable.__doc__ = (
"Gets an existing local variable or creates a new one.\n\n" +
get_local_variable.__doc__)
return get_variable(*args, **kwargs)
def _get_partitioned_variable(name,
shape=None,
dtype=None,