From f0b904afefc0cc2e30fdd4c03a88a5f13654f22f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Sep 2016 10:02:05 -0800 Subject: [PATCH] Add `get_local_variable`, which returns a local variable that respects variable scopes. Change: 132338671 --- .../kernel_tests/variable_scope_test.py | 21 +++++++++++++++++++ tensorflow/python/ops/variable_scope.py | 18 ++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index bce249d8c5e..e48ed569118 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -605,6 +605,27 @@ class VariableScopeTest(tf.test.TestCase): with tf.name_scope("scope2") as sc2: self.assertEqual(sc2, "outer_1/default/scope2/") + def testGetLocalVar(self): + with self.test_session(): + # Check that local variable respects naming. + with tf.variable_scope("outer") as outer: + with tf.variable_scope(outer, "default", []): + local_var = variable_scope.get_local_variable( + "w", [], collections=["foo"]) + self.assertEqual(local_var.name, "outer/w:0") + + # Since variable is local, it should be in the local variable collection + # but not the the trainable collection. + self.assertIn(local_var, tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES)) + self.assertIn(local_var, tf.get_collection("foo")) + self.assertNotIn( + local_var, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)) + + # Check that local variable respects `reuse`. + with tf.variable_scope(outer, "default", reuse=True): + self.assertEqual(variable_scope.get_local_variable("w", []).name, + "outer/w:0") + def axis0_into1_partitioner(shape=None, **unused_kwargs): part = [1] * len(shape) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 217556df3c7..e94ff063b7b 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections as collections_lib import contextlib +import functools import traceback import six @@ -35,8 +36,8 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging __all__ = ["VariableScope", "get_variable_scope", - "get_variable", "variable_scope", "variable_op_scope", - "no_regularizer"] + "get_variable", "get_local_variable", "variable_scope", + "variable_op_scope", "no_regularizer"] class _PartitionInfo(object): @@ -1012,6 +1013,19 @@ def get_variable(name, custom_getter=custom_getter) +@functools.wraps(get_variable) +def get_local_variable(*args, **kwargs): + kwargs["trainable"] = False + if "collections" in kwargs: + kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES] + else: + kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES] + get_local_variable.__doc__ = ( + "Gets an existing local variable or creates a new one.\n\n" + + get_local_variable.__doc__) + return get_variable(*args, **kwargs) + + def _get_partitioned_variable(name, shape=None, dtype=None,