Default naming of scopes should continue, not reset, after jumps to reused scopes.
Change: 144556462
This commit is contained in:
parent
5ab7874f70
commit
e72067896c
tensorflow/python
@ -453,6 +453,25 @@ class VariableScopeTest(test.TestCase):
|
||||
variable_scope.get_variable("w", []).name,
|
||||
"defaultScope1_2/layer/w:0")
|
||||
|
||||
def testVarOpScopeUniqueNamesWithJump(self):
|
||||
with self.test_session():
|
||||
with variable_scope.variable_scope("default") as default:
|
||||
with variable_scope.variable_scope(None, "layer"):
|
||||
self.assertEqual(
|
||||
variable_scope.get_variable("w", []).name,
|
||||
"default/layer/w:0")
|
||||
with variable_scope.variable_scope(None, "layer"):
|
||||
self.assertEqual(
|
||||
variable_scope.get_variable("w", []).name,
|
||||
"default/layer_1/w:0")
|
||||
with variable_scope.variable_scope(default):
|
||||
pass
|
||||
# No matter the jump in the middle, unique numbering continues.
|
||||
with variable_scope.variable_scope(None, "layer"):
|
||||
self.assertEqual(
|
||||
variable_scope.get_variable("w", []).name,
|
||||
"default/layer_2/w:0")
|
||||
|
||||
def testVarOpScopeReuse(self):
|
||||
with self.test_session():
|
||||
with variable_scope.variable_scope("outer") as outer:
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import collections as collections_lib
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import traceback
|
||||
|
||||
@ -182,21 +183,21 @@ class _VariableStore(object):
|
||||
"""Create a variable store."""
|
||||
self._vars = {} # A dictionary of the stored TensorFlow variables.
|
||||
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
|
||||
self._variable_scopes_count = {} # Count re-used variable scopes.
|
||||
self.variable_scopes_count = {} # Count re-used variable scopes.
|
||||
|
||||
def open_variable_scope(self, scope_name):
|
||||
if scope_name in self._variable_scopes_count:
|
||||
self._variable_scopes_count[scope_name] += 1
|
||||
if scope_name in self.variable_scopes_count:
|
||||
self.variable_scopes_count[scope_name] += 1
|
||||
else:
|
||||
self._variable_scopes_count[scope_name] = 1
|
||||
self.variable_scopes_count[scope_name] = 1
|
||||
|
||||
def close_variable_subscopes(self, scope_name):
|
||||
for k in self._variable_scopes_count:
|
||||
for k in self.variable_scopes_count:
|
||||
if not scope_name or k.startswith(scope_name + "/"):
|
||||
self._variable_scopes_count[k] = 0
|
||||
self.variable_scopes_count[k] = 0
|
||||
|
||||
def variable_scope_count(self, scope_name):
|
||||
return self._variable_scopes_count.get(scope_name, 0)
|
||||
return self.variable_scopes_count.get(scope_name, 0)
|
||||
|
||||
def get_variable(self, name, shape=None, dtype=dtypes.float32,
|
||||
initializer=None, regularizer=None, reuse=None,
|
||||
@ -1222,6 +1223,7 @@ def _pure_variable_scope(name_or_scope,
|
||||
try:
|
||||
var_store.open_variable_scope(new_name)
|
||||
if isinstance(name_or_scope, VariableScope):
|
||||
old_subscopes = copy.copy(var_store.variable_scopes_count)
|
||||
name_scope = name_or_scope._name_scope # pylint: disable=protected-access
|
||||
# Handler for the case when we jump to a shared scope.
|
||||
# We create a new VariableScope (default_varscope[0]) that contains
|
||||
@ -1280,6 +1282,9 @@ def _pure_variable_scope(name_or_scope,
|
||||
yield default_varscope[0]
|
||||
finally:
|
||||
var_store.close_variable_subscopes(new_name)
|
||||
# If jumping out from a non-prolonged scope, restore counts.
|
||||
if isinstance(name_or_scope, VariableScope):
|
||||
var_store.variable_scopes_count = old_subscopes
|
||||
default_varscope[0] = old
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user