Default naming of scopes should continue, not reset, after jumps to reused scopes.

Change: 144556462
This commit is contained in:
Lukasz Kaiser 2017-01-14 22:24:39 -08:00 committed by TensorFlower Gardener
parent 5ab7874f70
commit e72067896c
2 changed files with 31 additions and 7 deletions
tensorflow/python

View File

@ -453,6 +453,25 @@ class VariableScopeTest(test.TestCase):
variable_scope.get_variable("w", []).name,
"defaultScope1_2/layer/w:0")
def testVarOpScopeUniqueNamesWithJump(self):
with self.test_session():
with variable_scope.variable_scope("default") as default:
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
variable_scope.get_variable("w", []).name,
"default/layer/w:0")
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
variable_scope.get_variable("w", []).name,
"default/layer_1/w:0")
with variable_scope.variable_scope(default):
pass
# No matter the jump in the middle, unique numbering continues.
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
variable_scope.get_variable("w", []).name,
"default/layer_2/w:0")
def testVarOpScopeReuse(self):
with self.test_session():
with variable_scope.variable_scope("outer") as outer:

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import collections as collections_lib
import contextlib
import copy
import functools
import traceback
@ -182,21 +183,21 @@ class _VariableStore(object):
"""Create a variable store."""
self._vars = {} # A dictionary of the stored TensorFlow variables.
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self._variable_scopes_count = {} # Count re-used variable scopes.
self.variable_scopes_count = {} # Count re-used variable scopes.
def open_variable_scope(self, scope_name):
if scope_name in self._variable_scopes_count:
self._variable_scopes_count[scope_name] += 1
if scope_name in self.variable_scopes_count:
self.variable_scopes_count[scope_name] += 1
else:
self._variable_scopes_count[scope_name] = 1
self.variable_scopes_count[scope_name] = 1
def close_variable_subscopes(self, scope_name):
for k in self._variable_scopes_count:
for k in self.variable_scopes_count:
if not scope_name or k.startswith(scope_name + "/"):
self._variable_scopes_count[k] = 0
self.variable_scopes_count[k] = 0
def variable_scope_count(self, scope_name):
return self._variable_scopes_count.get(scope_name, 0)
return self.variable_scopes_count.get(scope_name, 0)
def get_variable(self, name, shape=None, dtype=dtypes.float32,
initializer=None, regularizer=None, reuse=None,
@ -1222,6 +1223,7 @@ def _pure_variable_scope(name_or_scope,
try:
var_store.open_variable_scope(new_name)
if isinstance(name_or_scope, VariableScope):
old_subscopes = copy.copy(var_store.variable_scopes_count)
name_scope = name_or_scope._name_scope # pylint: disable=protected-access
# Handler for the case when we jump to a shared scope.
# We create a new VariableScope (default_varscope[0]) that contains
@ -1280,6 +1282,9 @@ def _pure_variable_scope(name_or_scope,
yield default_varscope[0]
finally:
var_store.close_variable_subscopes(new_name)
# If jumping out from a non-prolonged scope, restore counts.
if isinstance(name_or_scope, VariableScope):
var_store.variable_scopes_count = old_subscopes
default_varscope[0] = old