From b16d24a342c5de1384dcb9ee408a74f206d332b2 Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Thu, 2 Apr 2020 11:08:10 -0700
Subject: [PATCH] Separate mirrored call_for_each_replica to its own file

Both ParameterServerStrategy and MirroredStrategy use this, and it is
complicated enough on its own.

This also fixes an issue where you can't call strategy.run(tf.function) under
CentralStorageStrategy, by applying the same workaround we have in
MirroredStrategy.

PiperOrigin-RevId: 304437119
Change-Id: I6a7a67b88e7a5b7217aa9ffe05882d0ef4097896
---
 tensorflow/python/distribute/BUILD           |  37 +-
 tensorflow/python/distribute/mirrored_run.py | 454 ++++++++++++++++++
 .../python/distribute/mirrored_strategy.py   | 392 +--------
 .../distribute/parameter_server_strategy.py  |   7 +-
 tensorflow/python/distribute/values_test.py  |   1 -
 5 files changed, 488 insertions(+), 403 deletions(-)
 create mode 100644 tensorflow/python/distribute/mirrored_run.py

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index b0f982e3d5c..d0a70f18294 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -255,22 +255,17 @@ py_library(
 )
 
 py_library(
-    name = "mirrored_strategy",
-    srcs = ["mirrored_strategy.py"],
+    name = "mirrored_run",
+    srcs = ["mirrored_run.py"],
     deps = [
-        ":cross_device_ops",
         ":device_util",
         ":distribute_lib",
-        ":input_lib",
-        ":multi_worker_util",
-        ":numpy_dataset",
         ":reduce_util",
         ":shared_variable_creator",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:config",
         "//tensorflow/python:constant_op",
-        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:device",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
@@ -284,9 +279,33 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/autograph/core",
         "//tensorflow/python/autograph/impl",
-        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
+    ],
+)
+
+py_library(
+    name = "mirrored_strategy",
+    srcs = ["mirrored_strategy.py"],
+    deps = [
+        ":cross_device_ops",
+        ":device_util",
+        ":distribute_lib",
+        ":input_lib",
+        ":mirrored_run",
+        ":multi_worker_util",
+        ":numpy_dataset",
+        ":reduce_util",
+        ":values",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:device",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:util",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
     ],
 )
@@ -297,7 +316,7 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":input_lib",
-        ":mirrored_strategy",
+        ":mirrored_run",
         ":numpy_dataset",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/python/distribute/mirrored_run.py b/tensorflow/python/distribute/mirrored_run.py
new file mode 100644
index 00000000000..2cd139c387f
--- /dev/null
+++ b/tensorflow/python/distribute/mirrored_run.py
@@ -0,0 +1,454 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of call_for_each_replica, shared by multiple strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import functools
+import threading
+import weakref
+
+from tensorflow.python import pywrap_tfe
+from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
+from tensorflow.python.autograph.impl import api as autograph
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import shared_variable_creator
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
+
+
+def call_for_each_replica(strategy, fn, args=None, kwargs=None):
+  """Call `fn` on each worker device (replica).
+
+  It's highly recommended to wrap the call to this function inside a
+  `tf.function`, otherwise the performance is poor.
+
+  Args:
+    strategy: `tf.distribute.Strategy`.
+    fn: function to call on each worker device.
+    args: positional arguments to `fn`.
+    kwargs: keyword arguments to `fn`.
+
+  Returns:
+    Wrapped return value of `fn` from all replicas.
+  """
+  if args is None:
+    args = ()
+  if kwargs is None:
+    kwargs = {}
+
+  if isinstance(fn, def_function.Function):
+    if strategy not in _cfer_fn_cache:
+      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
+    wrapped = _cfer_fn_cache[strategy].get(fn)
+    if wrapped is None:
+      # We need to wrap fn such that it triggers _call_for_each_replica inside
+      # the tf.function. We use _clone() instead of @tf.function wrapped
+      # call_for_each_replica() because we would like to retain the arguments to
+      # the @tf.function decorator of fn.
+      wrapped = fn._clone(  # pylint: disable=protected-access
+          python_function=functools.partial(call_for_each_replica, strategy,
+                                            fn.python_function))
+      _cfer_fn_cache[strategy][fn] = wrapped
+    return wrapped(args, kwargs)
+
+  if context.executing_eagerly():
+    logging.log_first_n(
+        logging.WARN, "Using %s eagerly has significant "
+        "overhead currently. We will be working on improving "
+        "this in the future, but for now please wrap "
+        "`call_for_each_replica` or `experimental_run` or "
+        "`experimental_run_v2` inside a tf.function to get "
+        "the best performance." % strategy.__class__.__name__, 5)
+  else:
+    # When a tf.function is wrapped to trigger _call_for_each_replica (see
+    # the other branch above), AutoGraph stops conversion at
+    # _call_for_each_replica itself (TF library functions are whitelisted).
+    # This makes sure that the Python function that was originally passed to
+    # the tf.function is still converted.
+    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
+
+  return _call_for_each_replica(strategy, fn, args, kwargs)
+
+
+# Per strategy cache for call_for_each_replica def_function.Function objects.
+_cfer_fn_cache = weakref.WeakKeyDictionary()
+
+
+@contextlib.contextmanager
+def _enter_graph(g, eager, creator_stack=None):
+  """Context manager for selecting a graph and maybe eager mode."""
+  if eager:
+    with g.as_default(), context.eager_mode():
+      if creator_stack is not None:
+        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
+      yield
+  else:
+    with g.as_default():
+      if creator_stack is not None:
+        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
+      yield
+
+
+def _cpu_device(device):
+  cpu_device = tf_device.DeviceSpec.from_string(device)
+  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
+  return cpu_device.to_string()
+
+
+class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
+  pass
+
+
+def _call_for_each_replica(distribution, fn, args, kwargs):
+  """Run `fn` in separate threads, once per replica/worker device.
+
+  Args:
+    distribution: the DistributionStrategy object.
+    fn: function to run (will be run once per replica, each in its own thread).
+    args: positional arguments for `fn`.
+    kwargs: keyword arguments for `fn`.
+
+  Returns:
+    Merged return value of `fn` across all replicas.
+
+  Raises:
+    RuntimeError: if `fn()` calls `get_replica_context().merge_call()` a
+      different number of times on different replicas.
+  """
+  # TODO(josh11b): Add this option once we add synchronization to variable
+  # creation. Until then, this is pretty unsafe to use.
+  run_concurrently = False
+  if not context.executing_eagerly():
+    # Needed for per-thread device, etc. contexts in graph mode.
+    ops.get_default_graph().switch_to_thread_local()
+
+  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
+
+  shared_variable_store = {}
+  devices = distribution.extended.worker_devices
+
+  # TODO(isaprykin): Create these threads once instead of during every call.
+  threads = []
+  for index in range(len(devices)):
+    variable_creator_fn = shared_variable_creator.make_fn(
+        shared_variable_store, index)
+    t = _MirroredReplicaThread(
+        distribution, coord, index, devices, variable_creator_fn, fn,
+        values.select_replica(index, args),
+        values.select_replica(index, kwargs))
+    threads.append(t)
+
+  for t in threads:
+    t.start()
+
+  # To start running `fn`, the `should_run` event is set on each
+  # _MirroredReplicaThread (`MRT`). The main thread then waits until
+  # `MRT.has_paused` is set, which indicates that either `fn` is
+  # complete or a `get_replica_context().merge_call()` was called. If `fn` is
+  # complete, then `MRT.done` is set to True. Otherwise, the arguments
+  # of `get_replica_context().merge_call` from all paused threads are grouped
+  # and the `merge_fn` is performed. Results of the
+  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
+  # Each such `get_replica_context().merge_call` call returns the
+  # `MRT.merge_result` for that thread when the `MRT.should_run` event
+  # is set again. Execution of `fn` then resumes.
+ + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some replicas made a different number of " + "replica_context().merge_call() calls.") + # get_replica_context().merge_call() case + merge_args = values.regroup(tuple(t.merge_args for t in threads)) + merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads)) + # We capture the name_scope of the MRT when we call merge_fn + # to ensure that if we have opened a name scope in the MRT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MRT and assume it is + # the same for all other MRTs. + mtt_captured_name_scope = threads[0].captured_name_scope + mtt_captured_var_scope = threads[0].captured_var_scope + # Capture and merge the control dependencies from all the threads. + mtt_captured_control_deps = set() + for t in threads: + mtt_captured_control_deps.update(t.captured_control_deps) + with ops.name_scope(mtt_captured_name_scope),\ + ops.control_dependencies(mtt_captured_control_deps), \ + variable_scope.variable_scope(mtt_captured_var_scope): + merge_result = threads[0].merge_fn(distribution, *merge_args, + **merge_kwargs) + for r, t in enumerate(threads): + t.merge_result = values.select_replica(r, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup(tuple(t.main_result for t in threads)) + + +class _MirroredReplicaThread(threading.Thread): + """A thread that runs() a function on a device.""" + + def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, + fn, args, kwargs): + super(_MirroredReplicaThread, self).__init__() + self.coord = coord + self.distribution = dist + self.devices = devices + self.replica_id = replica_id + self.variable_creator_fn = variable_creator_fn + # State needed to run and return the results of `fn`. + self.main_fn = fn + self.main_args = args + self.main_kwargs = kwargs + self.main_result = None + self.done = False + # State needed to run the next merge_call() (if any) requested via + # ReplicaContext. + self.merge_fn = None + self.merge_args = None + self.merge_kwargs = None + self.merge_result = None + self.captured_name_scope = None + self.captured_var_scope = None + # We use a thread.Event for the main thread to signal when this + # thread should start running (`should_run`), and another for + # this thread to transfer control back to the main thread + # (`has_paused`, either when it gets to a + # `get_replica_context().merge_call` or when `fn` returns). In + # either case the event starts cleared, is signaled by calling + # set(). The receiving thread waits for the signal by calling + # wait() and then immediately clearing the event using clear(). 
+ self.should_run = threading.Event() + self.has_paused = threading.Event() + # These fields have to do with inheriting various contexts from the + # parent thread: + context.ensure_initialized() + ctx = context.context() + self.in_eager = ctx.executing_eagerly() + self.record_thread_local_summary_state() + self.record_thread_local_eager_context_state() + self.context_device_policy = ( + pywrap_tfe.TFE_ContextGetDevicePlacementPolicy( + ctx._context_handle)) # pylint: disable=protected-access + self.graph = ops.get_default_graph() + with ops.init_scope(): + self._init_in_eager = context.executing_eagerly() + self._init_graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] # pylint: disable=protected-access + self._var_scope = variable_scope.get_variable_scope() + # Adding a "/" at end lets us re-enter this scope later. + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" + if self.replica_id > 0: + if not self._name_scope: + self._name_scope = "" + self._name_scope += "replica_%d/" % self.replica_id + + def run(self): + self.should_run.wait() + self.should_run.clear() + try: + if self.coord.should_stop(): + return + self.restore_thread_local_summary_state() + self.restore_thread_local_eager_context_state() + # TODO(josh11b): Use current logical device instead of 0 here. + with self.coord.stop_on_exception(), \ + _enter_graph(self._init_graph, self._init_in_eager), \ + _enter_graph(self.graph, self.in_eager, + self._variable_creator_stack), \ + context.device_policy(self.context_device_policy), \ + _MirroredReplicaContext(self.distribution, constant_op.constant( + self.replica_id, dtypes.int32)), \ + ops.device(self.devices[self.replica_id]), \ + ops.name_scope(self._name_scope), \ + variable_scope.variable_scope( + self._var_scope, reuse=self.replica_id > 0), \ + variable_scope.variable_creator_scope(self.variable_creator_fn): + self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) + self.done = True + finally: + self.has_paused.set() + + def record_thread_local_summary_state(self): + """Record the thread local summary state in self.""" + # TODO(slebedev): is this still relevant? the referenced bug is closed. + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + self._summary_step = summary_state.step + self._summary_writer = summary_state.writer + self._summary_recording = summary_state.is_recording + self._summary_recording_distribution_strategy = ( + summary_state.is_recording_distribution_strategy) + + def restore_thread_local_summary_state(self): + """Restore thread local summary state from self.""" + # TODO(slebedev): is this still relevant? the referenced bug is closed. + summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access + summary_state.step = self._summary_step + summary_state.writer = self._summary_writer + summary_state.is_recording = self._summary_recording + summary_state.is_recording_distribution_strategy = ( + self._summary_recording_distribution_strategy) + + def record_thread_local_eager_context_state(self): + ctx = context.context() + eager_context_state = ctx._thread_local_data # pylint: disable=protected-access + self._eager_context_op_callbacks = eager_context_state.op_callbacks + # TODO(b/125892694): record other fields in EagerContext. 
+
+  def restore_thread_local_eager_context_state(self):
+    ctx = context.context()
+    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
+    eager_context_state.op_callbacks = self._eager_context_op_callbacks
+    # TODO(b/125892694): record other fields in EagerContext.
+
+
+class _MirroredReplicaContext(distribute_lib.ReplicaContext):
+  """ReplicaContext for synchronized replicas."""
+
+  def _merge_call(self, fn, args, kwargs):
+    """`merge_call()` implementation for synchronized replicas.
+
+    This pauses the current replica thread and passes `fn` and its arguments to
+    the main thread. The main thread will wait until all replicas pause, then
+    invoke `fn` with the grouped arguments. The current replica thread will
+    continue after `fn` completes.
+
+    See `_call_for_each_replica` for the logic in the main thread.
+
+    Args:
+      fn: a function that is called in cross-replica context with grouped
+        arguments from each replica. `fn` should return grouped values.
+      args: positional arguments to `fn`.
+      kwargs: keyword arguments to `fn`.
+
+    Returns:
+      Return value of `fn` for the current replica.
+
+    Raises:
+      RuntimeError: when merge_call happens in a different graph, e.g. in a
+        different tf.function, which is currently not supported.
+      _RequestedStop: when stop is requested.
+    """
+    t = threading.current_thread()
+    assert isinstance(t, _MirroredReplicaThread)
+    t.merge_fn = fn
+    t.merge_args = args
+    t.merge_kwargs = kwargs
+    t.captured_name_scope = t.graph.get_name_scope()
+    # Adding a "/" at end lets us re-enter this scope later.
+    if t.captured_name_scope:
+      t.captured_name_scope += "/"
+
+    t.captured_var_scope = variable_scope.get_variable_scope()
+    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access
+
+    # It is problematic if `merge_call` is called under a different graph than
+    # the one that `_call_for_each_replica` is called under. There are three
+    # cases where this can happen:
+    #
+    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
+    #   `tf.function` and there is a `merge_call` in `fn`. Since
+    #   MirroredStrategy traces a separate function per thread (per device),
+    #   and each trace takes a shared lock, the lock is never released by the
+    #   first thread and subsequent replica threads cannot proceed to trace
+    #   their own functions. This issue is addressed by always converting
+    #   `_call_for_each_replica(tf.function(f))` to
+    #   `tf.function(_call_for_each_replica(f))` in `call_for_each_replica`.
+    #
+    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
+    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
+    #   In this case each thread can successfully trace its own function, but
+    #   since the `merge_fn` passed to `merge_call` is executed in the main
+    #   thread (where `_call_for_each_replica` is executed), it can't access
+    #   the tensors that come from different graphs.
+    #
+    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
+    #   statement, and there is a `merge_call` inside the control-flow body,
+    #   and `fn` or `_call_for_each_replica` is decorated with `tf.function`.
+    #   A control-flow statement creates a separate graph for its body;
+    #   similar to #2, the `merge_fn` executed in the main thread can't access
+    #   the tensors that come from different graphs.
+    #
+    # We raise an error for #2 and #3.
+    if ops.get_default_graph() != t.graph:
+      raise RuntimeError(
+          "`merge_call` called while defining a new graph or a tf.function."
+ " This can often happen if the function `fn` passed to" + " `strategy.run()` contains a nested `@tf.function`, and the nested " + "`@tf.function` contains a synchronization point, such as aggregating" + " gradients (e.g, optimizer.apply_gradients), or if the function `fn`" + " uses a control flow statement which contains a synchronization" + " point in the body. Such behaviors are not yet supported. Instead," + " please avoid nested `tf.function`s or control flow statements that" + " may potentially cross a synchronization boundary, for example," + " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`" + " inside a `tf.function` or move the control flow out of `fn`") + + t.has_paused.set() + t.should_run.wait() + t.should_run.clear() + if t.coord.should_stop(): + raise _RequestedStop() + return t.merge_result + + @property + def devices(self): + distribute_lib.require_replica_context(self) + replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) + return [self._strategy.extended.worker_devices_by_replica[replica_id]] diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 698bf2c2ce6..de66128ce37 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -18,191 +18,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import copy -import functools -import threading -import weakref -from tensorflow.python import pywrap_tfe -from tensorflow.python.autograph.core import ag_ctx as autograph_ctx -from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import mirrored_run from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util -from tensorflow.python.distribute import shared_variable_creator from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver from tensorflow.python.eager import context -from tensorflow.python.eager import def_function from tensorflow.python.eager import tape from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import summary_ops_v2 -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import coordinator from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export - # TODO(josh11b): Replace asserts in this file with if ...: raise ... 
-@contextlib.contextmanager -def _enter_graph(g, eager, creator_stack=None): - """Context manager for selecting a graph and maybe eager mode.""" - if eager: - with g.as_default(), context.eager_mode(): - if creator_stack is not None: - g._variable_creator_stack = creator_stack # pylint: disable=protected-access - yield - else: - with g.as_default(): - if creator_stack is not None: - g._variable_creator_stack = creator_stack # pylint: disable=protected-access - yield - - -def _cpu_device(device): - cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device = cpu_device.replace(device_type="CPU", device_index=0) - return cpu_device.to_string() - - -class _RequestedStop(Exception): # pylint: disable=g-bad-exception-name - pass - - -# _call_for_each_replica is not a member of MirroredStrategy so that it is -# not allowed to use anything specific to MirroredStrategy and thus -# can be shared with other distribution strategies. - - -# TODO(yuefengz): maybe create a common class for those who need to call this -# _call_for_each_replica. -def _call_for_each_replica(distribution, devices, fn, args, kwargs): - """Run `fn` in separate threads, once per replica/worker device. - - Args: - distribution: the DistributionStrategy object. - devices: the devices to run `fn` on (logical device 0 for each replica). - fn: function to run (will be run once per replica, each in its own thread). - args: positional arguments for `fn` - kwargs: keyword arguments for `fn`. - - Returns: - Merged return value of `fn` across all replicas. - - Raises: - RuntimeError: If fn() calls get_replica_context().merge_call() a different - number of times from the available devices. - """ - # TODO(josh11b): Add this option once we add synchronization to variable - # creation. Until then, this is pretty unsafe to use. - run_concurrently = False - if not context.executing_eagerly(): - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - - coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every call. - threads = [] - for index in range(len(devices)): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = _MirroredReplicaThread( - distribution, coord, index, devices, variable_creator_fn, fn, - values.select_replica(index, args), - values.select_replica(index, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredReplicaThread - # (`MRT`) threads. The execution waits until - # `MRT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_replica_context().merge_call()` is called. If `fn` is - # complete, then `MRT.done` is set to True. Otherwise, arguments - # of `get_replica_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_replica_context().merge_call` are then set to `MRT.merge_result`. - # Each such `get_replica_context().merge_call` call returns the - # `MRT.merge_result` for that thread when `MRT.should_run` event - # is reset again. Execution of `fn` resumes. 
- - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some replicas made a different number of " - "replica_context().merge_call() calls.") - # get_replica_context().merge_call() case - merge_args = values.regroup(tuple(t.merge_args for t in threads)) - merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads)) - # We capture the name_scope of the MRT when we call merge_fn - # to ensure that if we have opened a name scope in the MRT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MRT and assume it is - # the same for all other MRTs. - mtt_captured_name_scope = threads[0].captured_name_scope - mtt_captured_var_scope = threads[0].captured_var_scope - # Capture and merge the control dependencies from all the threads. - mtt_captured_control_deps = set() - for t in threads: - mtt_captured_control_deps.update(t.captured_control_deps) - with ops.name_scope(mtt_captured_name_scope),\ - ops.control_dependencies(mtt_captured_control_deps), \ - variable_scope.variable_scope(mtt_captured_var_scope): - merge_result = threads[0].merge_fn(distribution, *merge_args, - **merge_kwargs) - for r, t in enumerate(threads): - t.merge_result = values.select_replica(r, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup(tuple(t.main_result for t in threads)) - - def _is_device_list_single_worker(devices): """Checks whether the devices list is for single or multi-worker. @@ -469,7 +311,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): "any local devices.") self._cross_device_ops = cross_device_ops self._initialize_strategy(devices) - self._cfer_fn_cache = weakref.WeakKeyDictionary() # TODO(b/128995245): Enable last partial batch support in graph mode. if ops.executing_eagerly_outside_functions(): @@ -739,35 +580,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): return self._get_cross_device_ops().broadcast(tensor, destinations) def _call_for_each_replica(self, fn, args, kwargs): - if isinstance(fn, def_function.Function): - wrapped = self._cfer_fn_cache.get(fn) - if wrapped is None: - # We need to wrap fn such that it triggers _call_for_each_replica inside - # the tf.function. - wrapped = fn._clone( # pylint: disable=protected-access - python_function=functools.partial(self._call_for_each_replica, - fn.python_function)) - self._cfer_fn_cache[fn] = wrapped - return wrapped(args, kwargs) - - if context.executing_eagerly(): - logging.log_first_n( - logging.WARN, "Using %s eagerly has significant " - "overhead currently. We will be working on improving " - "this in the future, but for now please wrap " - "`call_for_each_replica` or `experimental_run` or " - "`run` inside a tf.function to get the best performance." 
% - self._container_strategy().__class__.__name__, 5) - else: - # When a tf.function is wrapped to trigger _call_for_each_replica (see - # the other branch above), AutoGraph stops conversion at - # _call_for_each_replica itself (TF library functions are whitelisted). - # This makes sure that the Python function that originally passed to - # the tf.function is still converted. - fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) - - return _call_for_each_replica(self._container_strategy(), self._devices, - fn, args, kwargs) + return mirrored_run.call_for_each_replica(self._container_strategy(), fn, + args, kwargs) def _configure(self, session_config=None, @@ -912,203 +726,3 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings.""" return False - - -class _MirroredReplicaThread(threading.Thread): - """A thread that runs() a function on a device.""" - - def __init__(self, dist, coord, replica_id, devices, variable_creator_fn, - fn, args, kwargs): - super(_MirroredReplicaThread, self).__init__() - self.coord = coord - self.distribution = dist - self.devices = devices - self.replica_id = replica_id - self.variable_creator_fn = variable_creator_fn - # State needed to run and return the results of `fn`. - self.main_fn = fn - self.main_args = args - self.main_kwargs = kwargs - self.main_result = None - self.done = False - # State needed to run the next merge_call() (if any) requested via - # ReplicaContext. - self.merge_fn = None - self.merge_args = None - self.merge_kwargs = None - self.merge_result = None - self.captured_name_scope = None - self.captured_var_scope = None - # We use a thread.Event for the main thread to signal when this - # thread should start running (`should_run`), and another for - # this thread to transfer control back to the main thread - # (`has_paused`, either when it gets to a - # `get_replica_context().merge_call` or when `fn` returns). In - # either case the event starts cleared, is signaled by calling - # set(). The receiving thread waits for the signal by calling - # wait() and then immediately clearing the event using clear(). - self.should_run = threading.Event() - self.has_paused = threading.Event() - # These fields have to do with inheriting various contexts from the - # parent thread: - context.ensure_initialized() - ctx = context.context() - self.in_eager = ctx.executing_eagerly() - self.record_thread_local_summary_state() - self.record_thread_local_eager_context_state() - self.context_device_policy = ( - pywrap_tfe.TFE_ContextGetDevicePlacementPolicy( - ctx._context_handle)) # pylint: disable=protected-access - self.graph = ops.get_default_graph() - with ops.init_scope(): - self._init_in_eager = context.executing_eagerly() - self._init_graph = ops.get_default_graph() - self._variable_creator_stack = self.graph._variable_creator_stack[:] # pylint: disable=protected-access - self._var_scope = variable_scope.get_variable_scope() - # Adding a "/" at end lets us re-enter this scope later. 
- self._name_scope = self.graph.get_name_scope() - if self._name_scope: - self._name_scope += "/" - if self.replica_id > 0: - if not self._name_scope: - self._name_scope = "" - self._name_scope += "replica_%d/" % self.replica_id - - def run(self): - self.should_run.wait() - self.should_run.clear() - try: - if self.coord.should_stop(): - return - self.restore_thread_local_summary_state() - self.restore_thread_local_eager_context_state() - # TODO(josh11b): Use current logical device instead of 0 here. - with self.coord.stop_on_exception(), \ - _enter_graph(self._init_graph, self._init_in_eager), \ - _enter_graph(self.graph, self.in_eager, - self._variable_creator_stack), \ - context.device_policy(self.context_device_policy), \ - MirroredReplicaContext(self.distribution, constant_op.constant( - self.replica_id, dtypes.int32)), \ - ops.device(self.devices[self.replica_id]), \ - ops.name_scope(self._name_scope), \ - variable_scope.variable_scope( - self._var_scope, reuse=self.replica_id > 0), \ - variable_scope.variable_creator_scope(self.variable_creator_fn): - self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) - self.done = True - finally: - self.has_paused.set() - - def record_thread_local_summary_state(self): - """Record the thread local summary state in self.""" - # TODO(slebedev): is this still relevant? the referenced bug is closed. - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - self._summary_step = summary_state.step - self._summary_writer = summary_state.writer - self._summary_recording = summary_state.is_recording - self._summary_recording_distribution_strategy = ( - summary_state.is_recording_distribution_strategy) - - def restore_thread_local_summary_state(self): - """Restore thread local summary state from self.""" - # TODO(slebedev): is this still relevant? the referenced bug is closed. - summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access - summary_state.step = self._summary_step - summary_state.writer = self._summary_writer - summary_state.is_recording = self._summary_recording - summary_state.is_recording_distribution_strategy = ( - self._summary_recording_distribution_strategy) - - def record_thread_local_eager_context_state(self): - ctx = context.context() - eager_context_state = ctx._thread_local_data # pylint: disable=protected-access - self._eager_context_op_callbacks = eager_context_state.op_callbacks - # TODO(b/125892694): record other fields in EagerContext. - - def restore_thread_local_eager_context_state(self): - ctx = context.context() - eager_context_state = ctx._thread_local_data # pylint: disable=protected-access - eager_context_state.op_callbacks = self._eager_context_op_callbacks - # TODO(b/125892694): record other fields in EagerContext. - - -class MirroredReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext used in MirroredStrategy.extended.call_for_each_replica(). - - Opened in `_MirroredReplicaThread`, to allow the user to invoke - `MirroredStrategy`'s specific implementation of `merge_call()`, - which works by delegating the function and its arguments to - the main thread (the one that invoked - `MirroredStrategy.extended.call_for_each_replica()`). 
- """ - - def _merge_call(self, fn, args, kwargs): - """Delegate to the main thread to actually perform merge_call().""" - t = threading.current_thread() # a _MirroredReplicaThread - t.merge_fn = fn - t.merge_args = args - t.merge_kwargs = kwargs - t.captured_name_scope = t.graph.get_name_scope() - # Adding a "/" at end lets us re-enter this scope later. - if t.captured_name_scope: - t.captured_name_scope += "/" - - t.captured_var_scope = variable_scope.get_variable_scope() - t.captured_control_deps = t.graph._current_control_dependencies() # pylint: disable=protected-access - - # It is problematic if `merge_call` is called under a different graph other - # than the one that `_call_for_each_replica` is called under, there are - # 3 cases this can happen: - # - # 1. The `fn` passed to `_call_for_each_replica` is decorated with - # `tf.function` and there is a `merge_call` in `fn`. Since - # MirroredStrategy traces a separate function per thread (per device), - # and each trace takes a shared lock, the lock is never released by the - # first thread and subsequent replica threads cannot proceed to trace - # their own functions. This issue is addressed by always converting - # `_call_for_each_replica(tf.function(f))` to - # ``tf.function(_call_for_each_replica(f))`.` in - # `MirroredStrategy._call_for_each_replica`. - # - # 2. The `fn` passed to `_call_for_each_replica` contains a nested - # `tf.function`, and there is a `merge_call` in the nested `tf.function`. - # In this case each thread can successfully trace its own function, but - # since the `merge_fn` passed to `merge_call` is executed in the main - # thread (where `_call_for_each_replica` is executed), it can't access - # the tensors that come from different graphs. - # - # 3. The `fn` passed to `_call_for_each_replica` contains a control-flow - # statement, and there is a `merge_call` inside the control-flow body, - # `fn` or `_call_for_each_replica` is decorated with `tf.function`. - # Control flow statement creates a separate graph for its body, similar - # to #2, `merge_fn` executed in the main thread can't access the - # tensors that come from different graphs. - # - # We raise an error for #2 and #3. - if ops.get_default_graph() != t.graph: - raise RuntimeError( - "`merge_call` called while defining a new graph or a tf.function." - " This can often happen if the function `fn` passed to" - " `strategy.run()` contains a nested `@tf.function`, and the nested " - "`@tf.function` contains a synchronization point, such as aggregating" - " gradients (e.g, optimizer.apply_gradients), or if the function `fn`" - " uses a control flow statement which contains a synchronization" - " point in the body. Such behaviors are not yet supported. 
Instead," - " please avoid nested `tf.function`s or control flow statements that" - " may potentially cross a synchronization boundary, for example," - " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`" - " inside a `tf.function` or move the control flow out of `fn`") - - t.has_paused.set() - t.should_run.wait() - t.should_run.clear() - if t.coord.should_stop(): - raise _RequestedStop() - return t.merge_result - - @property - def devices(self): - distribute_lib.require_replica_context(self) - replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return [self._strategy.extended.worker_devices_by_replica[replica_id]] diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 7099e4a6390..e1f0f41b393 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -25,7 +25,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib -from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import mirrored_run from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import values @@ -456,9 +456,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): return var_creator(**kwargs) def _call_for_each_replica(self, fn, args, kwargs): - # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica( - self._container_strategy(), self._compute_devices, fn, args, kwargs) + return mirrored_run.call_for_each_replica(self._container_strategy(), fn, + args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index c3c3e0d5286..96089a67b00 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -1841,7 +1841,6 @@ class AggregatingVariableTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.) def testAssignAdd(self, distribution): - self.skipTest("b/151250566") with distribution.scope(): v = variable_scope.variable( 1, aggregation=variables_lib.VariableAggregation.MEAN)
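
A minimal usage sketch of the behavior this change enables, assuming a TF 2.x build with this patch applied: calling strategy.run on a tf.function under CentralStorageStrategy now routes through the shared mirrored_run.call_for_each_replica. The replica function and the reduction below are illustrative assumptions, not code from this commit.

# Illustrative sketch only (not part of the patch): strategy.run with a
# tf.function under CentralStorageStrategy, the case this change fixes by
# reusing mirrored_run.call_for_each_replica.
import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()

@tf.function
def replica_fn(x):
  # Runs once per replica; any merge_call-based synchronization inside `fn`
  # is handled by the shared call_for_each_replica machinery.
  return x * 2.0

per_replica = strategy.run(replica_fn, args=(tf.constant(3.0),))
# Reduce the per-replica results to a single tensor.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print(total.numpy())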