From 46598b8a38c1cecb9d99b182dd70019fe45e62c8 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Sat, 16 Mar 2019 18:41:43 +0800 Subject: [PATCH] [Bug fix] Fix bug of not respecting outer variable scope when using to create variable in merge_call because of thread local storage --- .../python/mirrored_strategy_multigpu_test.py | 34 +++++++++++++++++-- .../python/distribute/mirrored_strategy.py | 10 ++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 8a1772b7f22..924d819ce28 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -840,8 +840,9 @@ class MirroredStrategyNameScopeTest(test.TestCase): # variable_scope.variable() respects name scopes when creating # variables. On the other hand variable_scope.get_variable() ignores name - # scopes when creating variables. We test both methods of creating variables - # to make sure that we have the same variable names in both cases. + # scopes but respects variable scope when creating variables. We test both + # methods of creating variables to make sure that we have the same + # variable names in both cases. def testNameScopeWithVariable(self, distribution): def in_cross_replica(_): c = variable_scope.variable(1.0, name="c") @@ -900,6 +901,35 @@ class MirroredStrategyNameScopeTest(test.TestCase): self.assertEqual("c:0", c0.name) self.assertEqual("c/replica_1:0", c1.name) + def testVariableScopeWithGetVariable(self, distribution): + def in_cross_replica(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with variable_scope.variable_scope("foo"): + c = ds_context.get_replica_context().merge_call(in_cross_replica) + return b, c + + with context.graph_mode(), distribution.scope(): + with variable_scope.variable_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = distribution.extended.call_for_each_replica(model_fn) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = distribution.experimental_local_results(a) + b0, b1 = distribution.experimental_local_results(result_b) + c0, c1 = distribution.experimental_local_results(result_c) + self.assertEqual("main/a:0", a0.name) + self.assertEqual("main/a/replica_1:0", a1.name) + self.assertEqual("main/b:0", b0.name) + self.assertEqual("main/b/replica_1:0", b1.name) + self.assertEqual("main/foo/c:0", c0.name) + self.assertEqual("main/foo/c/replica_1:0", c1.name) + @combinations.generate( combinations.combine( diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 3b34732cee3..0f733e8ab27 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -177,12 +177,14 @@ def _call_for_each_replica(distribution, device_map, fn, args, kwargs): # capture the name_scope from the first MRT and assume it is # the same for all other MRTs. mtt_captured_name_scope = threads[0].captured_name_scope + mtt_captured_var_scope = threads[0].captured_var_scope # Capture and merge the control dependencies from all the threads. mtt_captured_control_deps = set() for t in threads: mtt_captured_control_deps.update(t.captured_control_deps) with ops.name_scope(mtt_captured_name_scope),\ - ops.control_dependencies(mtt_captured_control_deps): + ops.control_dependencies(mtt_captured_control_deps), \ + variable_scope.variable_scope(mtt_captured_var_scope): merge_result = threads[0].merge_fn(distribution, *merge_args, **merge_kwargs) for r, t in enumerate(threads): @@ -823,6 +825,7 @@ class _MirroredReplicaThread(threading.Thread): self.merge_kwargs = None self.merge_result = None self.captured_name_scope = None + self.captured_var_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -850,7 +853,7 @@ class _MirroredReplicaThread(threading.Thread): self._init_graph = ops.get_default_graph() self._variable_creator_stack = self.graph._variable_creator_stack[:] - self._captured_var_scope = variable_scope.get_variable_scope() + self._var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. self._name_scope = self.graph.get_name_scope() if self._name_scope: @@ -879,7 +882,7 @@ class _MirroredReplicaThread(threading.Thread): self.replica_id]), \ ops.name_scope(self._name_scope), \ variable_scope.variable_scope( - self._captured_var_scope, reuse=self.replica_id > 0), \ + self._var_scope, reuse=self.replica_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) self.done = True @@ -926,6 +929,7 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext): if t.captured_name_scope: t.captured_name_scope += "/" + t.captured_var_scope = variable_scope.get_variable_scope() t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access # NOTE(priyag): Throw an error if there is a merge call in the middle of a