Merge pull request #26785 from wangsiyu:fix_bug_vs_scope_merge_call
PiperOrigin-RevId: 239264902
This commit is contained in:
commit
e52328ebc5
@ -840,8 +840,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
|||||||
|
|
||||||
# variable_scope.variable() respects name scopes when creating
|
# variable_scope.variable() respects name scopes when creating
|
||||||
# variables. On the other hand variable_scope.get_variable() ignores name
|
# variables. On the other hand variable_scope.get_variable() ignores name
|
||||||
# scopes when creating variables. We test both methods of creating variables
|
# scopes but respects variable scope when creating variables. We test both
|
||||||
# to make sure that we have the same variable names in both cases.
|
# methods of creating variables to make sure that we have the same
|
||||||
|
# variable names in both cases.
|
||||||
def testNameScopeWithVariable(self, distribution):
|
def testNameScopeWithVariable(self, distribution):
|
||||||
def in_cross_replica(_):
|
def in_cross_replica(_):
|
||||||
c = variable_scope.variable(1.0, name="c")
|
c = variable_scope.variable(1.0, name="c")
|
||||||
@ -900,6 +901,36 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
|||||||
self.assertEqual("c:0", c0.name)
|
self.assertEqual("c:0", c0.name)
|
||||||
self.assertEqual("c/replica_1:0", c1.name)
|
self.assertEqual("c/replica_1:0", c1.name)
|
||||||
|
|
||||||
|
def testVariableScopeWithGetVariable(self, distribution):
|
||||||
|
|
||||||
|
def in_cross_replica(_):
|
||||||
|
c = variable_scope.get_variable("c", [1])
|
||||||
|
return c
|
||||||
|
|
||||||
|
def model_fn():
|
||||||
|
b = variable_scope.get_variable("b", [1])
|
||||||
|
with variable_scope.variable_scope("foo"):
|
||||||
|
c = ds_context.get_replica_context().merge_call(in_cross_replica)
|
||||||
|
return b, c
|
||||||
|
|
||||||
|
with context.graph_mode(), distribution.scope():
|
||||||
|
with variable_scope.variable_scope("main"):
|
||||||
|
a = variable_scope.get_variable("a", [1])
|
||||||
|
result = distribution.extended.call_for_each_replica(model_fn)
|
||||||
|
result_b = result[0]
|
||||||
|
result_c = result[1]
|
||||||
|
self.assertIsInstance(result_b, values.DistributedValues)
|
||||||
|
self.assertIsInstance(result_c, values.DistributedValues)
|
||||||
|
a0, a1 = distribution.experimental_local_results(a)
|
||||||
|
b0, b1 = distribution.experimental_local_results(result_b)
|
||||||
|
c0, c1 = distribution.experimental_local_results(result_c)
|
||||||
|
self.assertEqual("main/a:0", a0.name)
|
||||||
|
self.assertEqual("main/a/replica_1:0", a1.name)
|
||||||
|
self.assertEqual("main/b:0", b0.name)
|
||||||
|
self.assertEqual("main/b/replica_1:0", b1.name)
|
||||||
|
self.assertEqual("main/foo/c:0", c0.name)
|
||||||
|
self.assertEqual("main/foo/c/replica_1:0", c1.name)
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
|
@ -177,12 +177,14 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
|
|||||||
# capture the name_scope from the first MRT and assume it is
|
# capture the name_scope from the first MRT and assume it is
|
||||||
# the same for all other MRTs.
|
# the same for all other MRTs.
|
||||||
mtt_captured_name_scope = threads[0].captured_name_scope
|
mtt_captured_name_scope = threads[0].captured_name_scope
|
||||||
|
mtt_captured_var_scope = threads[0].captured_var_scope
|
||||||
# Capture and merge the control dependencies from all the threads.
|
# Capture and merge the control dependencies from all the threads.
|
||||||
mtt_captured_control_deps = set()
|
mtt_captured_control_deps = set()
|
||||||
for t in threads:
|
for t in threads:
|
||||||
mtt_captured_control_deps.update(t.captured_control_deps)
|
mtt_captured_control_deps.update(t.captured_control_deps)
|
||||||
with ops.name_scope(mtt_captured_name_scope),\
|
with ops.name_scope(mtt_captured_name_scope),\
|
||||||
ops.control_dependencies(mtt_captured_control_deps):
|
ops.control_dependencies(mtt_captured_control_deps), \
|
||||||
|
variable_scope.variable_scope(mtt_captured_var_scope):
|
||||||
merge_result = threads[0].merge_fn(distribution, *merge_args,
|
merge_result = threads[0].merge_fn(distribution, *merge_args,
|
||||||
**merge_kwargs)
|
**merge_kwargs)
|
||||||
for r, t in enumerate(threads):
|
for r, t in enumerate(threads):
|
||||||
@ -823,6 +825,7 @@ class _MirroredReplicaThread(threading.Thread):
|
|||||||
self.merge_kwargs = None
|
self.merge_kwargs = None
|
||||||
self.merge_result = None
|
self.merge_result = None
|
||||||
self.captured_name_scope = None
|
self.captured_name_scope = None
|
||||||
|
self.captured_var_scope = None
|
||||||
# We use a thread.Event for the main thread to signal when this
|
# We use a thread.Event for the main thread to signal when this
|
||||||
# thread should start running (`should_run`), and another for
|
# thread should start running (`should_run`), and another for
|
||||||
# this thread to transfer control back to the main thread
|
# this thread to transfer control back to the main thread
|
||||||
@ -850,7 +853,7 @@ class _MirroredReplicaThread(threading.Thread):
|
|||||||
self._init_graph = ops.get_default_graph()
|
self._init_graph = ops.get_default_graph()
|
||||||
|
|
||||||
self._variable_creator_stack = self.graph._variable_creator_stack[:]
|
self._variable_creator_stack = self.graph._variable_creator_stack[:]
|
||||||
self._captured_var_scope = variable_scope.get_variable_scope()
|
self._var_scope = variable_scope.get_variable_scope()
|
||||||
# Adding a "/" at end lets us re-enter this scope later.
|
# Adding a "/" at end lets us re-enter this scope later.
|
||||||
self._name_scope = self.graph.get_name_scope()
|
self._name_scope = self.graph.get_name_scope()
|
||||||
if self._name_scope:
|
if self._name_scope:
|
||||||
@ -879,7 +882,7 @@ class _MirroredReplicaThread(threading.Thread):
|
|||||||
self.replica_id]), \
|
self.replica_id]), \
|
||||||
ops.name_scope(self._name_scope), \
|
ops.name_scope(self._name_scope), \
|
||||||
variable_scope.variable_scope(
|
variable_scope.variable_scope(
|
||||||
self._captured_var_scope, reuse=self.replica_id > 0), \
|
self._var_scope, reuse=self.replica_id > 0), \
|
||||||
variable_scope.variable_creator_scope(self.variable_creator_fn):
|
variable_scope.variable_creator_scope(self.variable_creator_fn):
|
||||||
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
|
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
|
||||||
self.done = True
|
self.done = True
|
||||||
@ -926,6 +929,7 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext):
|
|||||||
if t.captured_name_scope:
|
if t.captured_name_scope:
|
||||||
t.captured_name_scope += "/"
|
t.captured_name_scope += "/"
|
||||||
|
|
||||||
|
t.captured_var_scope = variable_scope.get_variable_scope()
|
||||||
t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access
|
t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access
|
||||||
|
|
||||||
# NOTE(priyag): Throw an error if there is a merge call in the middle of a
|
# NOTE(priyag): Throw an error if there is a merge call in the middle of a
|
||||||
|
Loading…
Reference in New Issue
Block a user