Merge pull request #26785 from wangsiyu:fix_bug_vs_scope_merge_call
PiperOrigin-RevId: 239264902
This commit is contained in:
commit
e52328ebc5
@ -840,8 +840,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
||||
|
||||
# variable_scope.variable() respects name scopes when creating
|
||||
# variables. On the other hand variable_scope.get_variable() ignores name
|
||||
# scopes when creating variables. We test both methods of creating variables
|
||||
# to make sure that we have the same variable names in both cases.
|
||||
# scopes but respects variable scope when creating variables. We test both
|
||||
# methods of creating variables to make sure that we have the same
|
||||
# variable names in both cases.
|
||||
def testNameScopeWithVariable(self, distribution):
|
||||
def in_cross_replica(_):
|
||||
c = variable_scope.variable(1.0, name="c")
|
||||
@ -900,6 +901,36 @@ class MirroredStrategyNameScopeTest(test.TestCase):
|
||||
self.assertEqual("c:0", c0.name)
|
||||
self.assertEqual("c/replica_1:0", c1.name)
|
||||
|
||||
def testVariableScopeWithGetVariable(self, distribution):
|
||||
|
||||
def in_cross_replica(_):
|
||||
c = variable_scope.get_variable("c", [1])
|
||||
return c
|
||||
|
||||
def model_fn():
|
||||
b = variable_scope.get_variable("b", [1])
|
||||
with variable_scope.variable_scope("foo"):
|
||||
c = ds_context.get_replica_context().merge_call(in_cross_replica)
|
||||
return b, c
|
||||
|
||||
with context.graph_mode(), distribution.scope():
|
||||
with variable_scope.variable_scope("main"):
|
||||
a = variable_scope.get_variable("a", [1])
|
||||
result = distribution.extended.call_for_each_replica(model_fn)
|
||||
result_b = result[0]
|
||||
result_c = result[1]
|
||||
self.assertIsInstance(result_b, values.DistributedValues)
|
||||
self.assertIsInstance(result_c, values.DistributedValues)
|
||||
a0, a1 = distribution.experimental_local_results(a)
|
||||
b0, b1 = distribution.experimental_local_results(result_b)
|
||||
c0, c1 = distribution.experimental_local_results(result_c)
|
||||
self.assertEqual("main/a:0", a0.name)
|
||||
self.assertEqual("main/a/replica_1:0", a1.name)
|
||||
self.assertEqual("main/b:0", b0.name)
|
||||
self.assertEqual("main/b/replica_1:0", b1.name)
|
||||
self.assertEqual("main/foo/c:0", c0.name)
|
||||
self.assertEqual("main/foo/c/replica_1:0", c1.name)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
|
@ -177,12 +177,14 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
|
||||
# capture the name_scope from the first MRT and assume it is
|
||||
# the same for all other MRTs.
|
||||
mtt_captured_name_scope = threads[0].captured_name_scope
|
||||
mtt_captured_var_scope = threads[0].captured_var_scope
|
||||
# Capture and merge the control dependencies from all the threads.
|
||||
mtt_captured_control_deps = set()
|
||||
for t in threads:
|
||||
mtt_captured_control_deps.update(t.captured_control_deps)
|
||||
with ops.name_scope(mtt_captured_name_scope),\
|
||||
ops.control_dependencies(mtt_captured_control_deps):
|
||||
ops.control_dependencies(mtt_captured_control_deps), \
|
||||
variable_scope.variable_scope(mtt_captured_var_scope):
|
||||
merge_result = threads[0].merge_fn(distribution, *merge_args,
|
||||
**merge_kwargs)
|
||||
for r, t in enumerate(threads):
|
||||
@ -823,6 +825,7 @@ class _MirroredReplicaThread(threading.Thread):
|
||||
self.merge_kwargs = None
|
||||
self.merge_result = None
|
||||
self.captured_name_scope = None
|
||||
self.captured_var_scope = None
|
||||
# We use a thread.Event for the main thread to signal when this
|
||||
# thread should start running (`should_run`), and another for
|
||||
# this thread to transfer control back to the main thread
|
||||
@ -850,7 +853,7 @@ class _MirroredReplicaThread(threading.Thread):
|
||||
self._init_graph = ops.get_default_graph()
|
||||
|
||||
self._variable_creator_stack = self.graph._variable_creator_stack[:]
|
||||
self._captured_var_scope = variable_scope.get_variable_scope()
|
||||
self._var_scope = variable_scope.get_variable_scope()
|
||||
# Adding a "/" at end lets us re-enter this scope later.
|
||||
self._name_scope = self.graph.get_name_scope()
|
||||
if self._name_scope:
|
||||
@ -879,7 +882,7 @@ class _MirroredReplicaThread(threading.Thread):
|
||||
self.replica_id]), \
|
||||
ops.name_scope(self._name_scope), \
|
||||
variable_scope.variable_scope(
|
||||
self._captured_var_scope, reuse=self.replica_id > 0), \
|
||||
self._var_scope, reuse=self.replica_id > 0), \
|
||||
variable_scope.variable_creator_scope(self.variable_creator_fn):
|
||||
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
|
||||
self.done = True
|
||||
@ -926,6 +929,7 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext):
|
||||
if t.captured_name_scope:
|
||||
t.captured_name_scope += "/"
|
||||
|
||||
t.captured_var_scope = variable_scope.get_variable_scope()
|
||||
t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access
|
||||
|
||||
# NOTE(priyag): Throw an error if there is a merge call in the middle of a
|
||||
|
Loading…
Reference in New Issue
Block a user