Separate mirrored call_for_each_replica to its own file
Both ParameterServerStrategy and MirroredStrategy use this, and it is complicated enough to deserve its own file. This also fixes an issue where you couldn't call strategy.run(tf.function) under CentralStorageStrategy, by applying the same workaround we have in MirroredStrategy.

PiperOrigin-RevId: 304437119
Change-Id: I6a7a67b88e7a5b7217aa9ffe05882d0ef4097896
parent f2bde78c4c
commit b16d24a342
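The user-visible effect of the fix can be sketched as follows. This is an assumed usage example, not part of the commit: passing a `tf.function` directly to `strategy.run` under `tf.distribute.experimental.CentralStorageStrategy` now takes the same re-wrapping path that MirroredStrategy already used.

```python
# Hedged sketch of the usage this change enables; API names follow TF 2.x
# (in releases of this era strategy.run may still be experimental_run_v2).
import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()

with strategy.scope():
  v = tf.Variable(0.0)

@tf.function
def step(x):
  v.assign_add(x)
  return v.read_value()

# Previously this combination did not work under CentralStorageStrategy,
# because only MirroredStrategy re-wrapped the tf.function; now both
# strategies go through mirrored_run.call_for_each_replica.
result = strategy.run(step, args=(tf.constant(1.0),))
print(strategy.experimental_local_results(result))
```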
@@ -255,22 +255,17 @@ py_library(
 )
 
 py_library(
-    name = "mirrored_strategy",
-    srcs = ["mirrored_strategy.py"],
+    name = "mirrored_run",
+    srcs = ["mirrored_run.py"],
     deps = [
-        ":cross_device_ops",
         ":device_util",
         ":distribute_lib",
-        ":input_lib",
-        ":multi_worker_util",
-        ":numpy_dataset",
         ":reduce_util",
         ":shared_variable_creator",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:config",
         "//tensorflow/python:constant_op",
-        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:device",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
@@ -284,9 +279,33 @@ py_library(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/autograph/core",
         "//tensorflow/python/autograph/impl",
-        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
+    ],
+)
+
+py_library(
+    name = "mirrored_strategy",
+    srcs = ["mirrored_strategy.py"],
+    deps = [
+        ":cross_device_ops",
+        ":device_util",
+        ":distribute_lib",
+        ":input_lib",
+        ":mirrored_run",
+        ":multi_worker_util",
+        ":numpy_dataset",
+        ":reduce_util",
+        ":values",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:device",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:util",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
     ],
 )
@@ -297,7 +316,7 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":input_lib",
-        ":mirrored_strategy",
+        ":mirrored_run",
         ":numpy_dataset",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
@@ -0,0 +1,454 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class MirroredStrategy implementing tf.distribute.Strategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import functools
+import threading
+import weakref
+
+from tensorflow.python import pywrap_tfe
+from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
+from tensorflow.python.autograph.impl import api as autograph
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import shared_variable_creator
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import coordinator
+
+
+def call_for_each_replica(strategy, fn, args=None, kwargs=None):
+  """Call `fn` on each worker devices(replica).
+
+  It's highly recommended to wrap the call to this function inside a
+  `tf.function`, otherwise the performance is poor.
+
+  Args:
+    strategy: `tf.distribute.Strategy`.
+    fn: function to call on each worker devices.
+    args: positional arguments to `fn`.
+    kwargs: keyword arguments to `fn`.
+
+  Returns:
+    Wrapped returned value of `fn` from all replicas.
+  """
+  if args is None:
+    args = ()
+  if kwargs is None:
+    kwargs = {}
+
+  if isinstance(fn, def_function.Function):
+    if strategy not in _cfer_fn_cache:
+      _cfer_fn_cache[strategy] = weakref.WeakKeyDictionary()
+    wrapped = _cfer_fn_cache[strategy].get(fn)
+    if wrapped is None:
+      # We need to wrap fn such that it triggers _call_for_each_replica inside
+      # the tf.function. We use _clone() instead of @tf.function wrapped
+      # call_for_each_replica() because we would like to retain the arguments to
+      # the @tf.function decorator of fn.
+      wrapped = fn._clone(  # pylint: disable=protected-access
+          python_function=functools.partial(call_for_each_replica, strategy,
+                                            fn.python_function))
+      _cfer_fn_cache[strategy][fn] = wrapped
+    return wrapped(args, kwargs)
+
+  if context.executing_eagerly():
+    logging.log_first_n(
+        logging.WARN, "Using %s eagerly has significant "
+        "overhead currently. We will be working on improving "
+        "this in the future, but for now please wrap "
+        "`call_for_each_replica` or `experimental_run` or "
+        "`experimental_run_v2` inside a tf.function to get "
+        "the best performance." % strategy.__class__.__name__, 5)
+  else:
+    # When a tf.function is wrapped to trigger _call_for_each_replica (see
+    # the other branch above), AutoGraph stops conversion at
+    # _call_for_each_replica itself (TF library functions are whitelisted).
+    # This makes sure that the Python function that originally passed to
+    # the tf.function is still converted.
+    fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
+
+  return _call_for_each_replica(strategy, fn, args, kwargs)
+
+
+# Per strategy cache for call_for_each_replica def_function.Function objects.
+_cfer_fn_cache = weakref.WeakKeyDictionary()
+
+
+@contextlib.contextmanager
+def _enter_graph(g, eager, creator_stack=None):
+  """Context manager for selecting a graph and maybe eager mode."""
+  if eager:
+    with g.as_default(), context.eager_mode():
+      if creator_stack is not None:
+        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
+      yield
+  else:
+    with g.as_default():
+      if creator_stack is not None:
+        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
+      yield
+
+
+def _cpu_device(device):
+  cpu_device = tf_device.DeviceSpec.from_string(device)
+  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
+  return cpu_device.to_string()
+
+
+class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
+  pass
+
+
+def _call_for_each_replica(distribution, fn, args, kwargs):
+  """Run `fn` in separate threads, once per replica/worker device.
+
+  Args:
+    distribution: the DistributionStrategy object.
+    fn: function to run (will be run once per replica, each in its own thread).
+    args: positional arguments for `fn`
+    kwargs: keyword arguments for `fn`.
+
+  Returns:
+    Merged return value of `fn` across all replicas.
+
+  Raises:
+    RuntimeError: If fn() calls get_replica_context().merge_call() a different
+        number of times from the available devices.
+  """
+  # TODO(josh11b): Add this option once we add synchronization to variable
+  # creation. Until then, this is pretty unsafe to use.
+  run_concurrently = False
+  if not context.executing_eagerly():
+    # Needed for per-thread device, etc. contexts in graph mode.
+    ops.get_default_graph().switch_to_thread_local()
+
+  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
+
+  shared_variable_store = {}
+  devices = distribution.extended.worker_devices
+
+  # TODO(isaprykin): Create these threads once instead of during every call.
+  threads = []
+  for index in range(len(devices)):
+    variable_creator_fn = shared_variable_creator.make_fn(
+        shared_variable_store, index)
+    t = _MirroredReplicaThread(
+        distribution, coord, index, devices, variable_creator_fn, fn,
+        values.select_replica(index, args),
+        values.select_replica(index, kwargs))
+    threads.append(t)
+
+  for t in threads:
+    t.start()
+
+  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
+  # (`MRT`) threads. The execution waits until
+  # `MRT.has_paused` is set, which indicates that either `fn` is
+  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
+  # complete, then `MRT.done` is set to True. Otherwise, arguments
+  # of `get_replica_context().merge_call` from all paused threads are grouped
+  # and the `merge_fn` is performed. Results of the
+  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
+  # Each such `get_replica_context().merge_call` call returns the
+  # `MRT.merge_result` for that thread when `MRT.should_run` event
+  # is reset again. Execution of `fn` resumes.
+
+  try:
+    with coord.stop_on_exception():
+      all_done = False
+      while not all_done and not coord.should_stop():
+        done = []
+        if run_concurrently:
+          for t in threads:
+            t.should_run.set()
+          for t in threads:
+            t.has_paused.wait()
+            t.has_paused.clear()
+            if coord.should_stop():
+              return None
+            done.append(t.done)
+        else:
+          for t in threads:
+            t.should_run.set()
+            t.has_paused.wait()
+            t.has_paused.clear()
+            if coord.should_stop():
+              return None
+            done.append(t.done)
+        if coord.should_stop():
+          return None
+        all_done = all(done)
+        if not all_done:
+          if any(done):
+            raise RuntimeError("Some replicas made a different number of "
+                               "replica_context().merge_call() calls.")
+          # get_replica_context().merge_call() case
+          merge_args = values.regroup(tuple(t.merge_args for t in threads))
+          merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
+          # We capture the name_scope of the MRT when we call merge_fn
+          # to ensure that if we have opened a name scope in the MRT,
+          # it will be respected when executing the merge function. We only
+          # capture the name_scope from the first MRT and assume it is
+          # the same for all other MRTs.
+          mtt_captured_name_scope = threads[0].captured_name_scope
+          mtt_captured_var_scope = threads[0].captured_var_scope
+          # Capture and merge the control dependencies from all the threads.
+          mtt_captured_control_deps = set()
+          for t in threads:
+            mtt_captured_control_deps.update(t.captured_control_deps)
+          with ops.name_scope(mtt_captured_name_scope),\
+              ops.control_dependencies(mtt_captured_control_deps), \
+              variable_scope.variable_scope(mtt_captured_var_scope):
+            merge_result = threads[0].merge_fn(distribution, *merge_args,
+                                               **merge_kwargs)
+            for r, t in enumerate(threads):
+              t.merge_result = values.select_replica(r, merge_result)
+  finally:
+    for t in threads:
+      t.should_run.set()
+    coord.join(threads)
+
+  return values.regroup(tuple(t.main_result for t in threads))
+
+
+class _MirroredReplicaThread(threading.Thread):
+  """A thread that runs() a function on a device."""
+
+  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn,
+               fn, args, kwargs):
+    super(_MirroredReplicaThread, self).__init__()
+    self.coord = coord
+    self.distribution = dist
+    self.devices = devices
+    self.replica_id = replica_id
+    self.variable_creator_fn = variable_creator_fn
+    # State needed to run and return the results of `fn`.
+    self.main_fn = fn
+    self.main_args = args
+    self.main_kwargs = kwargs
+    self.main_result = None
+    self.done = False
+    # State needed to run the next merge_call() (if any) requested via
+    # ReplicaContext.
+    self.merge_fn = None
+    self.merge_args = None
+    self.merge_kwargs = None
+    self.merge_result = None
+    self.captured_name_scope = None
+    self.captured_var_scope = None
+    # We use a thread.Event for the main thread to signal when this
+    # thread should start running (`should_run`), and another for
+    # this thread to transfer control back to the main thread
+    # (`has_paused`, either when it gets to a
+    # `get_replica_context().merge_call` or when `fn` returns). In
+    # either case the event starts cleared, is signaled by calling
+    # set(). The receiving thread waits for the signal by calling
+    # wait() and then immediately clearing the event using clear().
+    self.should_run = threading.Event()
+    self.has_paused = threading.Event()
+    # These fields have to do with inheriting various contexts from the
+    # parent thread:
+    context.ensure_initialized()
+    ctx = context.context()
+    self.in_eager = ctx.executing_eagerly()
+    self.record_thread_local_summary_state()
+    self.record_thread_local_eager_context_state()
+    self.context_device_policy = (
+        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
+            ctx._context_handle))  # pylint: disable=protected-access
+    self.graph = ops.get_default_graph()
+    with ops.init_scope():
+      self._init_in_eager = context.executing_eagerly()
+      self._init_graph = ops.get_default_graph()
+    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
+    self._var_scope = variable_scope.get_variable_scope()
+    # Adding a "/" at end lets us re-enter this scope later.
+    self._name_scope = self.graph.get_name_scope()
+    if self._name_scope:
+      self._name_scope += "/"
+    if self.replica_id > 0:
+      if not self._name_scope:
+        self._name_scope = ""
+      self._name_scope += "replica_%d/" % self.replica_id
+
+  def run(self):
+    self.should_run.wait()
+    self.should_run.clear()
+    try:
+      if self.coord.should_stop():
+        return
+      self.restore_thread_local_summary_state()
+      self.restore_thread_local_eager_context_state()
+      # TODO(josh11b): Use current logical device instead of 0 here.
+      with self.coord.stop_on_exception(), \
+          _enter_graph(self._init_graph, self._init_in_eager), \
+          _enter_graph(self.graph, self.in_eager,
+                       self._variable_creator_stack), \
+          context.device_policy(self.context_device_policy), \
+          _MirroredReplicaContext(self.distribution, constant_op.constant(
+              self.replica_id, dtypes.int32)), \
+          ops.device(self.devices[self.replica_id]), \
+          ops.name_scope(self._name_scope), \
+          variable_scope.variable_scope(
+              self._var_scope, reuse=self.replica_id > 0), \
+          variable_scope.variable_creator_scope(self.variable_creator_fn):
+        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
+        self.done = True
+    finally:
+      self.has_paused.set()
+
+  def record_thread_local_summary_state(self):
+    """Record the thread local summary state in self."""
+    # TODO(slebedev): is this still relevant? the referenced bug is closed.
+    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
+    self._summary_step = summary_state.step
+    self._summary_writer = summary_state.writer
+    self._summary_recording = summary_state.is_recording
+    self._summary_recording_distribution_strategy = (
+        summary_state.is_recording_distribution_strategy)
+
+  def restore_thread_local_summary_state(self):
+    """Restore thread local summary state from self."""
+    # TODO(slebedev): is this still relevant? the referenced bug is closed.
+    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
+    summary_state.step = self._summary_step
+    summary_state.writer = self._summary_writer
+    summary_state.is_recording = self._summary_recording
+    summary_state.is_recording_distribution_strategy = (
+        self._summary_recording_distribution_strategy)
+
+  def record_thread_local_eager_context_state(self):
+    ctx = context.context()
+    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
+    self._eager_context_op_callbacks = eager_context_state.op_callbacks
+    # TODO(b/125892694): record other fields in EagerContext.
+
+  def restore_thread_local_eager_context_state(self):
+    ctx = context.context()
+    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
+    eager_context_state.op_callbacks = self._eager_context_op_callbacks
+    # TODO(b/125892694): record other fields in EagerContext.
+
+
+class _MirroredReplicaContext(distribute_lib.ReplicaContext):
+  """ReplicaContext for synchronized replica."""
+
+  def _merge_call(self, fn, args, kwargs):
+    """`merge_call()` implementation for synchronized replica.
+
+    This pauses the current replica thread and passes `fn` and its arguments to
+    the main thread. The main thread will wait until all replicas pause, then
+    invoke `fn` with grouped arugments. The current replica thread will continue
+    after `fn` completes.
+
+    See `_call_for_each_replica` for the logic in the main thread.
+
+    Args:
+      fn: a function that is called in cross replica context with grouped
+        arguments from each replica. `fn` should returns grouped values.
+      args: positional arguments to `fn`.
+      kwargs: keyward arguments to `fn`.
+
+    Returns:
+      Return value of `fn` for the current replica.
+
+    Raises:
+      RuntimeError: when merge_call happens in a different graph, e.g. in a
+        different tf.function, which is not supported now.
+      _RequestedStop: when stop is requested.
+    """
+    t = threading.current_thread()
+    assert isinstance(t, _MirroredReplicaThread)
+    t.merge_fn = fn
+    t.merge_args = args
+    t.merge_kwargs = kwargs
+    t.captured_name_scope = t.graph.get_name_scope()
+    # Adding a "/" at end lets us re-enter this scope later.
+    if t.captured_name_scope:
+      t.captured_name_scope += "/"
+
+    t.captured_var_scope = variable_scope.get_variable_scope()
+    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access
+
+    # It is problematic if `merge_call` is called under a different graph other
+    # than the one that `_call_for_each_replica` is called under, there are
+    # 3 cases this can happen:
+    #
+    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
+    #   `tf.function` and there is a `merge_call` in `fn`. Since
+    #   MirroredStrategy traces a separate function per thread (per device),
+    #   and each trace takes a shared lock, the lock is never released by the
+    #   first thread and subsequent replica threads cannot proceed to trace
+    #   their own functions. This issue is addressed by always converting
+    #   `_call_for_each_replica(tf.function(f))` to
+    #   ``tf.function(_call_for_each_replica(f))`.` in
+    #   `MirroredStrategy._call_for_each_replica`.
+    #
+    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
+    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
+    #   In this case each thread can successfully trace its own function, but
+    #   since the `merge_fn` passed to `merge_call` is executed in the main
+    #   thread (where `_call_for_each_replica` is executed), it can't access
+    #   the tensors that come from different graphs.
+    #
+    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
+    #   statement, and there is a `merge_call` inside the control-flow body,
+    #   `fn` or `_call_for_each_replica` is decorated with `tf.function`.
+    #   Control flow statement creates a separate graph for its body, similar
+    #   to #2, `merge_fn` executed in the main thread can't access the
+    #   tensors that come from different graphs.
+    #
+    # We raise an error for #2 and #3.
+    if ops.get_default_graph() != t.graph:
+      raise RuntimeError(
+          "`merge_call` called while defining a new graph or a tf.function."
+          " This can often happen if the function `fn` passed to"
+          " `strategy.run()` contains a nested `@tf.function`, and the nested "
+          "`@tf.function` contains a synchronization point, such as aggregating"
+          " gradients (e.g, optimizer.apply_gradients), or if the function `fn`"
+          " uses a control flow statement which contains a synchronization"
+          " point in the body. Such behaviors are not yet supported. Instead,"
+          " please avoid nested `tf.function`s or control flow statements that"
+          " may potentially cross a synchronization boundary, for example,"
+          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
+          " inside a `tf.function` or move the control flow out of `fn`")
+
+    t.has_paused.set()
+    t.should_run.wait()
+    t.should_run.clear()
+    if t.coord.should_stop():
+      raise _RequestedStop()
+    return t.merge_result
+
+  @property
+  def devices(self):
+    distribute_lib.require_replica_context(self)
+    replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
+    return [self._strategy.extended.worker_devices_by_replica[replica_id]]
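The `should_run`/`has_paused` handshake described in the comments of `_call_for_each_replica` above can be illustrated with a framework-free sketch. This is an assumed simplification, not code from the commit: all class and function names below are hypothetical, the real implementation additionally uses `tf.train.Coordinator` for error handling and raises if replicas make mismatched numbers of merge calls.

```python
import threading

class ReplicaThread(threading.Thread):
  """Minimal stand-in for _MirroredReplicaThread: run fn, pause at merge_call."""

  def __init__(self, fn, index):
    super().__init__()
    self.fn, self.index = fn, index
    self.should_run = threading.Event()   # main thread -> replica: resume
    self.has_paused = threading.Event()   # replica -> main thread: paused/done
    self.merge_args = None
    self.merge_result = None
    self.main_result = None
    self.done = False

  def merge_call(self, *args):
    # Hand arguments to the main thread, pause, and wait to be resumed.
    self.merge_args = args
    self.has_paused.set()
    self.should_run.wait()
    self.should_run.clear()
    return self.merge_result

  def run(self):
    self.should_run.wait()
    self.should_run.clear()
    try:
      self.main_result = self.fn(self)
      self.done = True
    finally:
      self.has_paused.set()

def call_for_each_replica(fn, num_replicas, merge_fn):
  threads = [ReplicaThread(fn, i) for i in range(num_replicas)]
  for t in threads:
    t.start()
  while True:
    # Drive each replica until it either finishes or pauses in merge_call.
    for t in threads:
      t.should_run.set()
      t.has_paused.wait()
      t.has_paused.clear()
    if all(t.done for t in threads):
      break
    # All replicas are paused: group their arguments and run the merge once.
    merged = merge_fn([t.merge_args for t in threads])
    for t in threads:
      t.merge_result = merged
  for t in threads:
    t.join()
  return [t.main_result for t in threads]

# Each replica contributes its index + 1; merge_call sums the contributions.
print(call_for_each_replica(
    fn=lambda ctx: ctx.merge_call(ctx.index + 1) * 10,
    num_replicas=4,
    merge_fn=lambda all_args: sum(a[0] for a in all_args)))  # [100, 100, 100, 100]
```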
@@ -18,191 +18,33 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import contextlib
 import copy
-import functools
-import threading
-import weakref
 
-from tensorflow.python import pywrap_tfe
-from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
-from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import input_lib
+from tensorflow.python.distribute import mirrored_run
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import shared_variable_creator
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.eager import context
-from tensorflow.python.eager import def_function
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import summary_ops_v2
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import coordinator
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
 # TODO(josh11b): Replace asserts in this file with if ...: raise ...
 
 
-@contextlib.contextmanager
-def _enter_graph(g, eager, creator_stack=None):
-  """Context manager for selecting a graph and maybe eager mode."""
-  if eager:
-    with g.as_default(), context.eager_mode():
-      if creator_stack is not None:
-        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
-      yield
-  else:
-    with g.as_default():
-      if creator_stack is not None:
-        g._variable_creator_stack = creator_stack  # pylint: disable=protected-access
-      yield
-
-
-def _cpu_device(device):
-  cpu_device = tf_device.DeviceSpec.from_string(device)
-  cpu_device = cpu_device.replace(device_type="CPU", device_index=0)
-  return cpu_device.to_string()
-
-
-class _RequestedStop(Exception):  # pylint: disable=g-bad-exception-name
-  pass
-
-
-# _call_for_each_replica is not a member of MirroredStrategy so that it is
-# not allowed to use anything specific to MirroredStrategy and thus
-# can be shared with other distribution strategies.
-
-
-# TODO(yuefengz): maybe create a common class for those who need to call this
-# _call_for_each_replica.
-def _call_for_each_replica(distribution, devices, fn, args, kwargs):
-  """Run `fn` in separate threads, once per replica/worker device.
-
-  Args:
-    distribution: the DistributionStrategy object.
-    devices: the devices to run `fn` on (logical device 0 for each replica).
-    fn: function to run (will be run once per replica, each in its own thread).
-    args: positional arguments for `fn`
-    kwargs: keyword arguments for `fn`.
-
-  Returns:
-    Merged return value of `fn` across all replicas.
-
-  Raises:
-    RuntimeError: If fn() calls get_replica_context().merge_call() a different
-        number of times from the available devices.
-  """
-  # TODO(josh11b): Add this option once we add synchronization to variable
-  # creation. Until then, this is pretty unsafe to use.
-  run_concurrently = False
-  if not context.executing_eagerly():
-    # Needed for per-thread device, etc. contexts in graph mode.
-    ops.get_default_graph().switch_to_thread_local()
-
-  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
-
-  shared_variable_store = {}
-
-  # TODO(isaprykin): Create these threads once instead of during every call.
-  threads = []
-  for index in range(len(devices)):
-    variable_creator_fn = shared_variable_creator.make_fn(
-        shared_variable_store, index)
-    t = _MirroredReplicaThread(
-        distribution, coord, index, devices, variable_creator_fn, fn,
-        values.select_replica(index, args),
-        values.select_replica(index, kwargs))
-    threads.append(t)
-
-  for t in threads:
-    t.start()
-
-  # When `fn` starts `should_run` event is set on _MirroredReplicaThread
-  # (`MRT`) threads. The execution waits until
-  # `MRT.has_paused` is set, which indicates that either `fn` is
-  # complete or a `get_replica_context().merge_call()` is called. If `fn` is
-  # complete, then `MRT.done` is set to True. Otherwise, arguments
-  # of `get_replica_context().merge_call` from all paused threads are grouped
-  # and the `merge_fn` is performed. Results of the
-  # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
-  # Each such `get_replica_context().merge_call` call returns the
-  # `MRT.merge_result` for that thread when `MRT.should_run` event
-  # is reset again. Execution of `fn` resumes.
-
-  try:
-    with coord.stop_on_exception():
-      all_done = False
-      while not all_done and not coord.should_stop():
-        done = []
-        if run_concurrently:
-          for t in threads:
-            t.should_run.set()
-          for t in threads:
-            t.has_paused.wait()
-            t.has_paused.clear()
-            if coord.should_stop():
-              return None
-            done.append(t.done)
-        else:
-          for t in threads:
-            t.should_run.set()
-            t.has_paused.wait()
-            t.has_paused.clear()
-            if coord.should_stop():
-              return None
-            done.append(t.done)
-        if coord.should_stop():
-          return None
-        all_done = all(done)
-        if not all_done:
-          if any(done):
-            raise RuntimeError("Some replicas made a different number of "
-                               "replica_context().merge_call() calls.")
-          # get_replica_context().merge_call() case
-          merge_args = values.regroup(tuple(t.merge_args for t in threads))
-          merge_kwargs = values.regroup(tuple(t.merge_kwargs for t in threads))
-          # We capture the name_scope of the MRT when we call merge_fn
-          # to ensure that if we have opened a name scope in the MRT,
-          # it will be respected when executing the merge function. We only
-          # capture the name_scope from the first MRT and assume it is
-          # the same for all other MRTs.
-          mtt_captured_name_scope = threads[0].captured_name_scope
-          mtt_captured_var_scope = threads[0].captured_var_scope
-          # Capture and merge the control dependencies from all the threads.
-          mtt_captured_control_deps = set()
-          for t in threads:
-            mtt_captured_control_deps.update(t.captured_control_deps)
-          with ops.name_scope(mtt_captured_name_scope),\
-              ops.control_dependencies(mtt_captured_control_deps), \
-              variable_scope.variable_scope(mtt_captured_var_scope):
-            merge_result = threads[0].merge_fn(distribution, *merge_args,
-                                               **merge_kwargs)
-            for r, t in enumerate(threads):
-              t.merge_result = values.select_replica(r, merge_result)
-  finally:
-    for t in threads:
-      t.should_run.set()
-    coord.join(threads)
-
-  return values.regroup(tuple(t.main_result for t in threads))
-
-
 def _is_device_list_single_worker(devices):
   """Checks whether the devices list is for single or multi-worker.
 
@@ -469,7 +311,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
                        "any local devices.")
     self._cross_device_ops = cross_device_ops
     self._initialize_strategy(devices)
-    self._cfer_fn_cache = weakref.WeakKeyDictionary()
 
     # TODO(b/128995245): Enable last partial batch support in graph mode.
     if ops.executing_eagerly_outside_functions():
@@ -739,35 +580,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     return self._get_cross_device_ops().broadcast(tensor, destinations)
 
   def _call_for_each_replica(self, fn, args, kwargs):
-    if isinstance(fn, def_function.Function):
-      wrapped = self._cfer_fn_cache.get(fn)
-      if wrapped is None:
-        # We need to wrap fn such that it triggers _call_for_each_replica inside
-        # the tf.function.
-        wrapped = fn._clone(  # pylint: disable=protected-access
-            python_function=functools.partial(self._call_for_each_replica,
-                                              fn.python_function))
-        self._cfer_fn_cache[fn] = wrapped
-      return wrapped(args, kwargs)
-
-    if context.executing_eagerly():
-      logging.log_first_n(
-          logging.WARN, "Using %s eagerly has significant "
-          "overhead currently. We will be working on improving "
-          "this in the future, but for now please wrap "
-          "`call_for_each_replica` or `experimental_run` or "
-          "`run` inside a tf.function to get the best performance." %
-          self._container_strategy().__class__.__name__, 5)
-    else:
-      # When a tf.function is wrapped to trigger _call_for_each_replica (see
-      # the other branch above), AutoGraph stops conversion at
-      # _call_for_each_replica itself (TF library functions are whitelisted).
-      # This makes sure that the Python function that originally passed to
-      # the tf.function is still converted.
-      fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
-
-    return _call_for_each_replica(self._container_strategy(), self._devices,
-                                  fn, args, kwargs)
+    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
+                                              args, kwargs)
 
   def _configure(self,
                  session_config=None,
@@ -912,203 +726,3 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _in_multi_worker_mode(self):
     """Whether this strategy indicates working in multi-worker settings."""
     return False
-
-
-class _MirroredReplicaThread(threading.Thread):
-  """A thread that runs() a function on a device."""
-
-  def __init__(self, dist, coord, replica_id, devices, variable_creator_fn,
-               fn, args, kwargs):
-    super(_MirroredReplicaThread, self).__init__()
-    self.coord = coord
-    self.distribution = dist
-    self.devices = devices
-    self.replica_id = replica_id
-    self.variable_creator_fn = variable_creator_fn
-    # State needed to run and return the results of `fn`.
-    self.main_fn = fn
-    self.main_args = args
-    self.main_kwargs = kwargs
-    self.main_result = None
-    self.done = False
-    # State needed to run the next merge_call() (if any) requested via
-    # ReplicaContext.
-    self.merge_fn = None
-    self.merge_args = None
-    self.merge_kwargs = None
-    self.merge_result = None
-    self.captured_name_scope = None
-    self.captured_var_scope = None
-    # We use a thread.Event for the main thread to signal when this
-    # thread should start running (`should_run`), and another for
-    # this thread to transfer control back to the main thread
-    # (`has_paused`, either when it gets to a
-    # `get_replica_context().merge_call` or when `fn` returns). In
-    # either case the event starts cleared, is signaled by calling
-    # set(). The receiving thread waits for the signal by calling
-    # wait() and then immediately clearing the event using clear().
-    self.should_run = threading.Event()
-    self.has_paused = threading.Event()
-    # These fields have to do with inheriting various contexts from the
-    # parent thread:
-    context.ensure_initialized()
-    ctx = context.context()
-    self.in_eager = ctx.executing_eagerly()
-    self.record_thread_local_summary_state()
-    self.record_thread_local_eager_context_state()
-    self.context_device_policy = (
-        pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(
-            ctx._context_handle))  # pylint: disable=protected-access
-    self.graph = ops.get_default_graph()
-    with ops.init_scope():
-      self._init_in_eager = context.executing_eagerly()
-      self._init_graph = ops.get_default_graph()
-    self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
-    self._var_scope = variable_scope.get_variable_scope()
-    # Adding a "/" at end lets us re-enter this scope later.
-    self._name_scope = self.graph.get_name_scope()
-    if self._name_scope:
-      self._name_scope += "/"
-    if self.replica_id > 0:
-      if not self._name_scope:
-        self._name_scope = ""
-      self._name_scope += "replica_%d/" % self.replica_id
-
-  def run(self):
-    self.should_run.wait()
-    self.should_run.clear()
-    try:
-      if self.coord.should_stop():
-        return
-      self.restore_thread_local_summary_state()
-      self.restore_thread_local_eager_context_state()
-      # TODO(josh11b): Use current logical device instead of 0 here.
-      with self.coord.stop_on_exception(), \
-          _enter_graph(self._init_graph, self._init_in_eager), \
-          _enter_graph(self.graph, self.in_eager,
-                       self._variable_creator_stack), \
-          context.device_policy(self.context_device_policy), \
-          MirroredReplicaContext(self.distribution, constant_op.constant(
-              self.replica_id, dtypes.int32)), \
-          ops.device(self.devices[self.replica_id]), \
-          ops.name_scope(self._name_scope), \
-          variable_scope.variable_scope(
-              self._var_scope, reuse=self.replica_id > 0), \
-          variable_scope.variable_creator_scope(self.variable_creator_fn):
-        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
-        self.done = True
-    finally:
-      self.has_paused.set()
-
-  def record_thread_local_summary_state(self):
-    """Record the thread local summary state in self."""
-    # TODO(slebedev): is this still relevant? the referenced bug is closed.
-    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    self._summary_step = summary_state.step
-    self._summary_writer = summary_state.writer
-    self._summary_recording = summary_state.is_recording
-    self._summary_recording_distribution_strategy = (
-        summary_state.is_recording_distribution_strategy)
-
-  def restore_thread_local_summary_state(self):
-    """Restore thread local summary state from self."""
-    # TODO(slebedev): is this still relevant? the referenced bug is closed.
-    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    summary_state.step = self._summary_step
-    summary_state.writer = self._summary_writer
-    summary_state.is_recording = self._summary_recording
-    summary_state.is_recording_distribution_strategy = (
-        self._summary_recording_distribution_strategy)
-
-  def record_thread_local_eager_context_state(self):
-    ctx = context.context()
-    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
-    self._eager_context_op_callbacks = eager_context_state.op_callbacks
-    # TODO(b/125892694): record other fields in EagerContext.
-
-  def restore_thread_local_eager_context_state(self):
-    ctx = context.context()
-    eager_context_state = ctx._thread_local_data  # pylint: disable=protected-access
-    eager_context_state.op_callbacks = self._eager_context_op_callbacks
-    # TODO(b/125892694): record other fields in EagerContext.
-
-
-class MirroredReplicaContext(distribute_lib.ReplicaContext):
-  """ReplicaContext used in MirroredStrategy.extended.call_for_each_replica().
-
-  Opened in `_MirroredReplicaThread`, to allow the user to invoke
-  `MirroredStrategy`'s specific implementation of `merge_call()`,
-  which works by delegating the function and its arguments to
-  the main thread (the one that invoked
-  `MirroredStrategy.extended.call_for_each_replica()`).
-  """
-
-  def _merge_call(self, fn, args, kwargs):
-    """Delegate to the main thread to actually perform merge_call()."""
-    t = threading.current_thread()  # a _MirroredReplicaThread
-    t.merge_fn = fn
-    t.merge_args = args
-    t.merge_kwargs = kwargs
-    t.captured_name_scope = t.graph.get_name_scope()
-    # Adding a "/" at end lets us re-enter this scope later.
-    if t.captured_name_scope:
-      t.captured_name_scope += "/"
-
-    t.captured_var_scope = variable_scope.get_variable_scope()
-    t.captured_control_deps = t.graph._current_control_dependencies()  # pylint: disable=protected-access
-
-    # It is problematic if `merge_call` is called under a different graph other
-    # than the one that `_call_for_each_replica` is called under, there are
-    # 3 cases this can happen:
-    #
-    #   1. The `fn` passed to `_call_for_each_replica` is decorated with
-    #   `tf.function` and there is a `merge_call` in `fn`. Since
-    #   MirroredStrategy traces a separate function per thread (per device),
-    #   and each trace takes a shared lock, the lock is never released by the
-    #   first thread and subsequent replica threads cannot proceed to trace
-    #   their own functions. This issue is addressed by always converting
-    #   `_call_for_each_replica(tf.function(f))` to
-    #   ``tf.function(_call_for_each_replica(f))`.` in
-    #   `MirroredStrategy._call_for_each_replica`.
-    #
-    #   2. The `fn` passed to `_call_for_each_replica` contains a nested
-    #   `tf.function`, and there is a `merge_call` in the nested `tf.function`.
-    #   In this case each thread can successfully trace its own function, but
-    #   since the `merge_fn` passed to `merge_call` is executed in the main
-    #   thread (where `_call_for_each_replica` is executed), it can't access
-    #   the tensors that come from different graphs.
-    #
-    #   3. The `fn` passed to `_call_for_each_replica` contains a control-flow
-    #   statement, and there is a `merge_call` inside the control-flow body,
-    #   `fn` or `_call_for_each_replica` is decorated with `tf.function`.
-    #   Control flow statement creates a separate graph for its body, similar
-    #   to #2, `merge_fn` executed in the main thread can't access the
-    #   tensors that come from different graphs.
-    #
-    # We raise an error for #2 and #3.
-    if ops.get_default_graph() != t.graph:
-      raise RuntimeError(
-          "`merge_call` called while defining a new graph or a tf.function."
-          " This can often happen if the function `fn` passed to"
-          " `strategy.run()` contains a nested `@tf.function`, and the nested "
-          "`@tf.function` contains a synchronization point, such as aggregating"
-          " gradients (e.g, optimizer.apply_gradients), or if the function `fn`"
-          " uses a control flow statement which contains a synchronization"
-          " point in the body. Such behaviors are not yet supported. Instead,"
-          " please avoid nested `tf.function`s or control flow statements that"
-          " may potentially cross a synchronization boundary, for example,"
-          " wrap the `fn` passed to `strategy.run` or the entire `strategy.run`"
-          " inside a `tf.function` or move the control flow out of `fn`")
-
-    t.has_paused.set()
-    t.should_run.wait()
-    t.should_run.clear()
-    if t.coord.should_stop():
-      raise _RequestedStop()
-    return t.merge_result
-
-  @property
-  def devices(self):
-    distribute_lib.require_replica_context(self)
-    replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
-    return [self._strategy.extended.worker_devices_by_replica[replica_id]]
@@ -25,7 +25,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import input_lib
-from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import mirrored_run
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import values
@@ -456,9 +456,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     return var_creator(**kwargs)
 
   def _call_for_each_replica(self, fn, args, kwargs):
-    # pylint: disable=protected-access
-    return mirrored_strategy._call_for_each_replica(
-        self._container_strategy(), self._compute_devices, fn, args, kwargs)
+    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
+                                              args, kwargs)
 
   def _verify_destinations_not_different_worker(self, destinations):
     if not self._cluster_spec:
@@ -1841,7 +1841,6 @@ class AggregatingVariableTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(self.evaluate(aggregating._v.read_value()), 3.)
 
   def testAssignAdd(self, distribution):
-    self.skipTest("b/151250566")
     with distribution.scope():
       v = variable_scope.variable(
           1, aggregation=variables_lib.VariableAggregation.MEAN)
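The error text added in `_merge_call` above describes a usage pattern that is worth illustrating. The sketch below is an assumed user-level example, not code from the commit: first the unsupported shape (a nested `tf.function` that hits a synchronization point), then the supported restructuring where the whole per-replica step runs under one `tf.function`.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  dense = tf.keras.layers.Dense(1)
  optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def apply_grads(grads):
  # Synchronization point inside a nested tf.function: apply_gradients
  # triggers merge_call under a different graph than the replica fn's,
  # which is what the RuntimeError above rejects.
  optimizer.apply_gradients(zip(grads, dense.trainable_variables))

def unsupported_step(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(dense(x) - y))
  apply_grads(tape.gradient(loss, dense.trainable_variables))
  # strategy.run(unsupported_step, ...) would raise the RuntimeError.

@tf.function
def supported_step(x, y):
  def replica_fn(x, y):
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(dense(x) - y))
    grads = tape.gradient(loss, dense.trainable_variables)
    # The synchronization point stays in the same trace as the replica fn.
    optimizer.apply_gradients(zip(grads, dense.trainable_variables))
    return loss
  return strategy.run(replica_fn, args=(x, y))
```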