Merge pull request #26785 from wangsiyu:fix_bug_vs_scope_merge_call

PiperOrigin-RevId: 239264902
This commit is contained in:
TensorFlower Gardener 2019-03-19 14:15:20 -07:00
commit e52328ebc5
2 changed files with 40 additions and 5 deletions

View File

@ -840,8 +840,9 @@ class MirroredStrategyNameScopeTest(test.TestCase):
# variable_scope.variable() respects name scopes when creating
# variables. On the other hand variable_scope.get_variable() ignores name
# scopes when creating variables. We test both methods of creating variables
# to make sure that we have the same variable names in both cases.
# scopes but respects variable scope when creating variables. We test both
# methods of creating variables to make sure that we have the same
# variable names in both cases.
def testNameScopeWithVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.variable(1.0, name="c")
@ -900,6 +901,36 @@ class MirroredStrategyNameScopeTest(test.TestCase):
self.assertEqual("c:0", c0.name)
self.assertEqual("c/replica_1:0", c1.name)
def testVariableScopeWithGetVariable(self, distribution):
def in_cross_replica(_):
c = variable_scope.get_variable("c", [1])
return c
def model_fn():
b = variable_scope.get_variable("b", [1])
with variable_scope.variable_scope("foo"):
c = ds_context.get_replica_context().merge_call(in_cross_replica)
return b, c
with context.graph_mode(), distribution.scope():
with variable_scope.variable_scope("main"):
a = variable_scope.get_variable("a", [1])
result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
self.assertIsInstance(result_c, values.DistributedValues)
a0, a1 = distribution.experimental_local_results(a)
b0, b1 = distribution.experimental_local_results(result_b)
c0, c1 = distribution.experimental_local_results(result_c)
self.assertEqual("main/a:0", a0.name)
self.assertEqual("main/a/replica_1:0", a1.name)
self.assertEqual("main/b:0", b0.name)
self.assertEqual("main/b/replica_1:0", b1.name)
self.assertEqual("main/foo/c:0", c0.name)
self.assertEqual("main/foo/c/replica_1:0", c1.name)
@combinations.generate(
combinations.combine(

View File

@ -177,12 +177,14 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
# capture the name_scope from the first MRT and assume it is
# the same for all other MRTs.
mtt_captured_name_scope = threads[0].captured_name_scope
mtt_captured_var_scope = threads[0].captured_var_scope
# Capture and merge the control dependencies from all the threads.
mtt_captured_control_deps = set()
for t in threads:
mtt_captured_control_deps.update(t.captured_control_deps)
with ops.name_scope(mtt_captured_name_scope),\
ops.control_dependencies(mtt_captured_control_deps):
ops.control_dependencies(mtt_captured_control_deps), \
variable_scope.variable_scope(mtt_captured_var_scope):
merge_result = threads[0].merge_fn(distribution, *merge_args,
**merge_kwargs)
for r, t in enumerate(threads):
@ -823,6 +825,7 @@ class _MirroredReplicaThread(threading.Thread):
self.merge_kwargs = None
self.merge_result = None
self.captured_name_scope = None
self.captured_var_scope = None
# We use a thread.Event for the main thread to signal when this
# thread should start running (`should_run`), and another for
# this thread to transfer control back to the main thread
@ -850,7 +853,7 @@ class _MirroredReplicaThread(threading.Thread):
self._init_graph = ops.get_default_graph()
self._variable_creator_stack = self.graph._variable_creator_stack[:]
self._captured_var_scope = variable_scope.get_variable_scope()
self._var_scope = variable_scope.get_variable_scope()
# Adding a "/" at end lets us re-enter this scope later.
self._name_scope = self.graph.get_name_scope()
if self._name_scope:
@ -879,7 +882,7 @@ class _MirroredReplicaThread(threading.Thread):
self.replica_id]), \
ops.name_scope(self._name_scope), \
variable_scope.variable_scope(
self._captured_var_scope, reuse=self.replica_id > 0), \
self._var_scope, reuse=self.replica_id > 0), \
variable_scope.variable_creator_scope(self.variable_creator_fn):
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
self.done = True
@ -926,6 +929,7 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext):
if t.captured_name_scope:
t.captured_name_scope += "/"
t.captured_var_scope = variable_scope.get_variable_scope()
t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access
# NOTE(priyag): Throw an error if there is a merge call in the middle of a