Introduce MPI allreduce and allgather in a new contrib project (#12299)
* Allreduce: Rebase to TF 1.3-rc1 (#3) * Introduce MPI allreduce in a new contrib project. This commit adds the tensorflow.contrib.mpi namespace and contrib project, which has a variety of ops that work with MPI. The MPI system works by starting a background thread which communicates between the different processes at a regular interval and schedules asynchronous reductions. At every tick, every rank will notify rank zero of the tensors it is ready to reduce, signifying completion with an empty DONE message. Rank zero will count how many ranks are ready to reduce every tensor, and, whenever a tensor is ready to reduce (that is, every rank is ready to reduce it), rank zero will issue a message to all other ranks directing them to reduce that tensor. This repeats for all the tensors that are ready to reduce, after which rank zero sends all other ranks a DONE message indicating that the tick is complete. Reviewed-by: Joel Hestness <jthestness@gmail.com> * Allreduce/Allgather: Major changes and fixes (#2) This commit constitutes many major updates to the TF MPI allreduce and allgather ops. Specifically, the following changes are included in this commit: 1) The allreduce and allgather ops had race conditions, which this commit fixes. Specifically, the BackgroundThreadLoop previously allocated temporary and output tensors after the main graph traversal thread has completed its call to MPIAll*::ComputeAsync(). Unfortunately, the ops kernel context's memory allocator is only guaranteed to be valid during the ComputeAsync call. This constraint requires ComputeAsync to allocate all tensors before returning; Otherwise, the memory allocator state may reflect allocations and deallocations from further ops that can cause races for the memory locations. To fix this, hoist the memory allocations to ComputeAsync. In this process, introduce a collective op record, which tracks the parameters of the op (e.g. input, output, and configurations). 2) Many models require capability to allreduce or allgather int64 tensors. We add functionality to handle long long data type (64-bit ints). 3) Eliminate the thread sleep. A major to-do item is to eliminate the need for polling between coordinator threads and other ranks. This change will require the coordinator rank to be able to wake up all other ranks when a collective is ready to be performed, but also for all ranks (i.e. background threads) to be woken up by graph traversal threads. In the meantime, remove the thread sleep, because it introduces significant run time overhead (e.g. >20%) for models with quick-running layers (e.g. few recurrent time-steps or few hidden nodes per layer). * mpi_ops.cc: Move toward more TF nature This commit changes a few bits and pieces to align more closely with Tensorflow structures and organization: 1) Use TF mutexes. TF mutexes provide nice scoping and management around std::mutex, and using them is consistent with other TF code. 2) Remove thread sleep at MPI initialization time. Thread sleep should not be used for polling activity. Instead, this commit replaces sleep-polling with a condition variable: The compute graph traversal thread waits on the condition variable until the background thread has completed initialization and signals the graph traversal thread that initialization is complete. 3) Slim MPI initialization check: Since TF permits many threads to be traversing the compute graph concurrently (e.g. with inter_op_parallelism_threads > 1), some graph traversal threads may not have set their GPU device ID. If such a thread executes an MPI op, it would fail the check in InitializedMPIOnSameDevice, because the background thread would be controlling a GPU with ID other than the default (0). Since graph traversal threads do not perform GPU activity, this GPU ID check was unnecessary. Remove it and refactor to just check whether MPI is initialized (IsMPIInitialized). * Rebase to TF 1.3.0-rc1 complete and tested * Allreduce: Rebase to TF 1.3-rc1 (#3) * Introduce MPI allreduce in a new contrib project. This commit adds the tensorflow.contrib.mpi namespace and contrib project, which has a variety of ops that work with MPI. The MPI system works by starting a background thread which communicates between the different processes at a regular interval and schedules asynchronous reductions. At every tick, every rank will notify rank zero of the tensors it is ready to reduce, signifying completion with an empty DONE message. Rank zero will count how many ranks are ready to reduce every tensor, and, whenever a tensor is ready to reduce (that is, every rank is ready to reduce it), rank zero will issue a message to all other ranks directing them to reduce that tensor. This repeats for all the tensors that are ready to reduce, after which rank zero sends all other ranks a DONE message indicating that the tick is complete. Reviewed-by: Joel Hestness <jthestness@gmail.com> * Allreduce/Allgather: Major changes and fixes (#2) This commit constitutes many major updates to the TF MPI allreduce and allgather ops. Specifically, the following changes are included in this commit: 1) The allreduce and allgather ops had race conditions, which this commit fixes. Specifically, the BackgroundThreadLoop previously allocated temporary and output tensors after the main graph traversal thread has completed its call to MPIAll*::ComputeAsync(). Unfortunately, the ops kernel context's memory allocator is only guaranteed to be valid during the ComputeAsync call. This constraint requires ComputeAsync to allocate all tensors before returning; Otherwise, the memory allocator state may reflect allocations and deallocations from further ops that can cause races for the memory locations. To fix this, hoist the memory allocations to ComputeAsync. In this process, introduce a collective op record, which tracks the parameters of the op (e.g. input, output, and configurations). 2) Many models require capability to allreduce or allgather int64 tensors. We add functionality to handle long long data type (64-bit ints). 3) Eliminate the thread sleep. A major to-do item is to eliminate the need for polling between coordinator threads and other ranks. This change will require the coordinator rank to be able to wake up all other ranks when a collective is ready to be performed, but also for all ranks (i.e. background threads) to be woken up by graph traversal threads. In the meantime, remove the thread sleep, because it introduces significant run time overhead (e.g. >20%) for models with quick-running layers (e.g. few recurrent time-steps or few hidden nodes per layer). * mpi_ops.cc: Move toward more TF nature This commit changes a few bits and pieces to align more closely with Tensorflow structures and organization: 1) Use TF mutexes. TF mutexes provide nice scoping and management around std::mutex, and using them is consistent with other TF code. 2) Remove thread sleep at MPI initialization time. Thread sleep should not be used for polling activity. Instead, this commit replaces sleep-polling with a condition variable: The compute graph traversal thread waits on the condition variable until the background thread has completed initialization and signals the graph traversal thread that initialization is complete. 3) Slim MPI initialization check: Since TF permits many threads to be traversing the compute graph concurrently (e.g. with inter_op_parallelism_threads > 1), some graph traversal threads may not have set their GPU device ID. If such a thread executes an MPI op, it would fail the check in InitializedMPIOnSameDevice, because the background thread would be controlling a GPU with ID other than the default (0). Since graph traversal threads do not perform GPU activity, this GPU ID check was unnecessary. Remove it and refactor to just check whether MPI is initialized (IsMPIInitialized). * Rebase to TF 1.3.0-rc1 complete and tested * Minor fixes * Point MPI message proto at contrib/mpi package * MPI Session: Fix graph handling * Pylint fixes * More pylint fixes * Python 2 pylint fix * MPI Collectives Ops: Fix coordinator shut down * Update copyrights to 2017 * Remove MPIDataType and switch to TF DataType * Add Allgather test, fix Allreduce test config * Fix BUILD file for TF sanity checks * Try guarding MPI collectives C++ files with TENSORFLOW_USE_MPI The TF build system on Github tries to build C++ source files in tensorflow/contrib/mpi_collectives even when configured with TF_NEED_MPI=0. This leads to a build failure when the mpi_collectives C++ files try to link against MPI third party headers, which are not set up. Unable to reproduce in contributor's build environment, we try guarding the MPI collectives C++ code with defines for TENSORFLOW_USE_MPI, similar to tensorflow/contrib/mpi. * Comment formatting Hopefully, this will trigger googlebot.
This commit is contained in:
parent
a0bbeb10e2
commit
c8535f36d1
@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//third_party/mpi:mpi.bzl", "if_mpi")
|
||||
|
||||
py_library(
|
||||
name = "contrib_py",
|
||||
srcs = glob(["**/*.py"]),
|
||||
@ -84,7 +86,7 @@ py_library(
|
||||
"//tensorflow/contrib/tpu",
|
||||
"//tensorflow/contrib/training:training_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
],
|
||||
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
80
tensorflow/contrib/mpi_collectives/BUILD
Normal file
80
tensorflow/contrib/mpi_collectives/BUILD
Normal file
@ -0,0 +1,80 @@
|
||||
# Ops that communicate with other processes via MPI.
|
||||
|
||||
package(default_visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_proto_library_cc",
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "mpi_message_proto",
|
||||
srcs = ["mpi_message.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = ["//tensorflow/core:protos_all"],
|
||||
visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "mpi_collectives.so",
|
||||
srcs = [
|
||||
"mpi_ops.cc",
|
||||
"ring.cc",
|
||||
"ring.h",
|
||||
],
|
||||
gpu_srcs = [
|
||||
"ring.cu.cc",
|
||||
"ring.h",
|
||||
],
|
||||
deps = [
|
||||
":mpi_message_proto_cc",
|
||||
"//third_party/mpi",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "mpi_ops_test",
|
||||
srcs = ["mpi_ops_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
data = [
|
||||
":mpi_collectives.so",
|
||||
],
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mpi_ops_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"mpi_ops.py",
|
||||
],
|
||||
data = [
|
||||
":mpi_collectives.so",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
5
tensorflow/contrib/mpi_collectives/README.md
Normal file
5
tensorflow/contrib/mpi_collectives/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# MPI TensorFlow integration
|
||||
|
||||
Tensorflow MPI integration allows communicating between different TensorFlow
|
||||
processes using MPI. This enables training across multiple nodes and GPUs
|
||||
using high-speed interconnects.
|
273
tensorflow/contrib/mpi_collectives/__init__.py
Normal file
273
tensorflow/contrib/mpi_collectives/__init__.py
Normal file
@ -0,0 +1,273 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=g-short-docstring-punctuation
|
||||
"""## Communicating Between Processes with MPI
|
||||
|
||||
TensorFlow natively provides inter-device communication through send and
|
||||
receive ops and inter-node communication through Distributed TensorFlow, based
|
||||
on the same send and receive abstractions. On HPC clusters where Infiniband or
|
||||
other high-speed node interconnects are available, these can end up being
|
||||
insufficient for synchronous data-parallel training (without asynchronous
|
||||
gradient descent). This module implements a variety of MPI ops which can take
|
||||
advantage of hardware-specific MPI libraries for efficient communication.
|
||||
|
||||
In order to use this module, TensorFlow must be built with an MPI library,
|
||||
which can be provided to the `./configure` script at build time. As a user of
|
||||
TensorFlow, you will need to build TensorFlow yourself to select the MPI
|
||||
library to use; to do so, follow the [instructions for building TensorFlow from
|
||||
source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources).
|
||||
|
||||
### Utility Ops
|
||||
|
||||
In addition to reductions and gathers, this module provides utility operations
|
||||
for detecting the running MPI configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from tensorflow.contrib import mpi
|
||||
|
||||
# Use `mpi.Session` instead of `tf.Session`
|
||||
with mpi.Session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
print("My MPI Rank:", rank)
|
||||
|
||||
if rank == 0:
|
||||
print("MPI Size:", session.run(mpi.size()))
|
||||
```
|
||||
|
||||
@@rank
|
||||
@@size
|
||||
|
||||
### Ring Allreduce and Allgather
|
||||
|
||||
When summing or averaging tensors across many processes, communication can
|
||||
easily become a bottleneck. A naive implementation will send all the tensor
|
||||
values to the same process, perform the reduction, and then broadcast the
|
||||
values back to all other processes, effectively creating a synchronous
|
||||
parameter server in one process. However, the process responsible for
|
||||
performing the reduction will have to receive and send a massive amount of data
|
||||
which scales with the number of processes *and* the number of parameters in the
|
||||
model.
|
||||
|
||||
Instead of centralizing the reduction and having one primary reducer, we can
|
||||
implement a distributed allreduce or allgather. A bandwidth-optimal allreduce
|
||||
will end up sending 2(N - 1) values for every value in the input tensor,
|
||||
and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce
|
||||
requires at least (N - 1) sends between the different nodes, and a broadcast of
|
||||
the result also requires (N - 1) sends, for a total of 2 (N - 1); these two
|
||||
steps cannot be combined in a clever way to reduce the number of required
|
||||
sends.) This module implements bandwidth-optimal ring allreduce and ring
|
||||
allgather operations using MPI; by choosing a hardware-appropriate MPI
|
||||
implementation (such as OpenMPI with CUDA-IPC support), you can train large
|
||||
models with synchronous gradient descent with minimal communication overhead.
|
||||
|
||||
In addition to the `allreduce` and `allgather` functions, a convenience
|
||||
`DistributedOptimizer` wrapper is provided to simplify using these functions
|
||||
for reducing model gradients.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib import mpi_collectives as mpi
|
||||
|
||||
# Construct a simple linear regression model to optimize
|
||||
W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32)
|
||||
B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32)
|
||||
inputs = tf.placeholder("Inputs", shape=[None, 20])
|
||||
outputs = tf.placeholder("Outputs", shape=[None, 1])
|
||||
loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs)
|
||||
|
||||
# Training using MPI allreduce with DistributedOptimizer
|
||||
optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer())
|
||||
train = optimizer.minimize(loss)
|
||||
|
||||
# Average loss over all ranks, for printing.
|
||||
# Do not pass this to an optimizer!
|
||||
avg_loss = mpi.allreduce(loss)
|
||||
|
||||
# On different ranks, feed different input data.
|
||||
with mpi.Session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
batch_inputs, batch_outputs = construct_batch_for_rank(rank)
|
||||
feed_dict = {inputs: batch_inputs, outputs: batch_outputs}
|
||||
_, l = session.run([train, avg_loss], feed_dict=feed_dict)
|
||||
print("Average Loss:", l)
|
||||
```
|
||||
|
||||
[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms
|
||||
for Clusters of Workstations".
|
||||
|
||||
@@Session
|
||||
@@DistributedOptimizer
|
||||
@@allreduce
|
||||
@@allgather
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.mpi_collectives.mpi_ops import size
|
||||
from tensorflow.contrib.mpi_collectives.mpi_ops import rank
|
||||
from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank
|
||||
from tensorflow.contrib.mpi_collectives.mpi_ops import allgather
|
||||
from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce
|
||||
from tensorflow.contrib.mpi_collectives.mpi_ops import init
|
||||
|
||||
|
||||
def allreduce(tensor, average=True):
|
||||
"""Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices.
|
||||
|
||||
Arguments:
|
||||
tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce.
|
||||
The shape of the input must be identical across all ranks.
|
||||
average: If True, computes the average over all ranks.
|
||||
Otherwise, computes the sum over all ranks.
|
||||
|
||||
This function performs a bandwidth-optimal ring allreduce on the input
|
||||
tensor. If the input is an tf.IndexedSlices, the function instead does an
|
||||
allgather on the values and the indices, effectively doing an allreduce on
|
||||
the represented tensor.
|
||||
"""
|
||||
if isinstance(tensor, tf.IndexedSlices):
|
||||
# For IndexedSlices, do two allgathers intead of an allreduce.
|
||||
mpi_size = tf.cast(size(), tensor.values.dtype)
|
||||
values = allgather(tensor.values)
|
||||
indices = allgather(tensor.indices)
|
||||
|
||||
# To make this operation into an average, divide all gathered values by
|
||||
# the MPI size.
|
||||
new_values = tf.div(values, mpi_size) if average else values
|
||||
return tf.IndexedSlices(new_values, indices,
|
||||
dense_shape=tensor.dense_shape)
|
||||
else:
|
||||
mpi_size = tf.cast(size(), tensor.dtype)
|
||||
summed_tensor = _allreduce(tensor)
|
||||
new_tensor = (tf.div(summed_tensor, mpi_size)
|
||||
if average else summed_tensor)
|
||||
return new_tensor
|
||||
|
||||
|
||||
class DistributedOptimizer(tf.train.Optimizer):
|
||||
"""An optimizer that wraps another tf.Optimizer, using an MPI allreduce to
|
||||
average gradient values before applying gradients to model weights."""
|
||||
|
||||
def __init__(self, optimizer, name=None, use_locking=False):
|
||||
"""Construct a new DistributedOptimizer, which uses another optimizer
|
||||
under the hood for computing single-process gradient values and
|
||||
applying gradient updates after the gradient values have been averaged
|
||||
across all the MPI ranks.
|
||||
|
||||
Args:
|
||||
optimizer: Optimizer to use for computing gradients and applying updates.
|
||||
name: Optional name prefix for the operations created when applying
|
||||
gradients. Defaults to "Distributed" followed by the provided
|
||||
optimizer type.
|
||||
use_locking: Whether to use locking when updating variables. See
|
||||
Optimizer.__init__ for more info.
|
||||
"""
|
||||
if name is None:
|
||||
name = "Distributed{}".format(type(optimizer).__name__)
|
||||
|
||||
self._optimizer = optimizer
|
||||
super(DistributedOptimizer, self).__init__(
|
||||
name=name, use_locking=use_locking)
|
||||
|
||||
def compute_gradients(self, *args, **kwargs):
|
||||
"""Compute gradients of all trainable variables.
|
||||
|
||||
See Optimizer.compute_gradients() for more info.
|
||||
|
||||
In DistributedOptimizer, compute_gradients() is overriden to also
|
||||
allreduce the gradients before returning them.
|
||||
"""
|
||||
gradients = (super(DistributedOptimizer, self)
|
||||
.compute_gradients(*args, **kwargs))
|
||||
return [(allreduce(gradient), var) for (gradient, var) in gradients]
|
||||
|
||||
def _apply_dense(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._apply_dense(*args, **kwargs)
|
||||
|
||||
def _apply_sparse(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._apply_sparse(*args, **kwargs)
|
||||
|
||||
def _apply_sparse_duplicate_indices(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._apply_sparse_duplicate_indices(*args,
|
||||
**kwargs)
|
||||
|
||||
def _prepare(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._prepare(*args, **kwargs)
|
||||
|
||||
def _create_slots(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._create_slots(*args, **kwargs)
|
||||
|
||||
def _valid_dtypes(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._valid_dtypes(*args, **kwargs)
|
||||
|
||||
def _finish(self, *args, **kwargs):
|
||||
"""Calls this same method on the underlying optimizer."""
|
||||
return self._optimizer._finish(*args, **kwargs)
|
||||
|
||||
|
||||
class Session(tf.Session):
|
||||
"""A class for running TensorFlow operations, with copies of the same graph
|
||||
running distributed across different MPI nodes.
|
||||
|
||||
The primary difference between `tf.Session` and
|
||||
`tf.contrib.mpi_collectives.Session` is that the MPI `Session` ensures that
|
||||
the `Session` options are correct for use with `tf.contrib.mpi`, and
|
||||
initializes MPI immediately upon the start of the session.
|
||||
"""
|
||||
|
||||
def __init__(self, target='', graph=None, config=None):
|
||||
"""Creates a new TensorFlow MPI session.
|
||||
|
||||
Unlike a normal `tf.Session`, an MPI Session may only use a single GPU,
|
||||
which must be specified in advance before the session is initialized.
|
||||
In addition, it only uses a single graph evaluation thread, and
|
||||
initializes MPI immediately upon starting.
|
||||
|
||||
If no `graph` argument is specified when constructing the session,
|
||||
the default graph will be launched in the session. If you are
|
||||
using more than one graph (created with `tf.Graph()` in the same
|
||||
process, you will have to use different sessions for each graph,
|
||||
but each graph can be used in multiple sessions. In this case, it
|
||||
is often clearer to pass the graph to be launched explicitly to
|
||||
the session constructor.
|
||||
|
||||
Args:
|
||||
target: (Optional.) The execution engine to connect to.
|
||||
graph: (Optional.) The `Graph` to be launched (described above).
|
||||
config: (Optional.) A `ConfigProto` protocol buffer with configuration
|
||||
options for the session.
|
||||
"""
|
||||
super(Session, self).__init__(target, graph, config=config)
|
||||
|
||||
# Initialize MPI on the relevant device.
|
||||
# TODO: Move this to library load and eliminate mpi.Session()
|
||||
if graph is None:
|
||||
graph = tf.get_default_graph()
|
||||
with graph.as_default():
|
||||
self.run(init())
|
96
tensorflow/contrib/mpi_collectives/mpi_allgather_test.py
Normal file
96
tensorflow/contrib/mpi_collectives/mpi_allgather_test.py
Normal file
@ -0,0 +1,96 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.mpi_collectives as mpi
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
average_allgather = False
|
||||
|
||||
|
||||
class AllgatherTest(test.TestCase):
|
||||
def checkAllgather(self, num_ranks, all_gathered, local_gathered):
|
||||
# Ensure that indices match.
|
||||
all_gat_ind = np.sort(all_gathered.indices)
|
||||
loc_gat_ind = np.sort(local_gathered.indices)
|
||||
assert(len(loc_gat_ind) == len(all_gat_ind))
|
||||
for i in range(len(loc_gat_ind)):
|
||||
assert(loc_gat_ind[i] == all_gat_ind[i])
|
||||
|
||||
# For each index, verify same values.
|
||||
local_checked = []
|
||||
for i in range(len(local_gathered.indices)):
|
||||
local_checked.append(False)
|
||||
for i in range(len(all_gathered.indices)):
|
||||
all_index = all_gathered.indices[i]
|
||||
# TODO(jthestness): Make this lookup quicker using sorting.
|
||||
loc_index = -1
|
||||
for j in range(len(local_gathered.indices)):
|
||||
if local_gathered.indices[j] == all_index and not local_checked[j]:
|
||||
loc_index = j
|
||||
local_checked[j] = True
|
||||
break
|
||||
assert(loc_index >= 0)
|
||||
correct_output = local_gathered.values[loc_index][0]
|
||||
if average_allgather:
|
||||
correct_output = correct_output / float(num_ranks)
|
||||
assert(all_gathered.values[i][0] == correct_output)
|
||||
|
||||
|
||||
def test_mpi_allgather(self):
|
||||
# Get MPI rank
|
||||
my_rank = int(os.environ['PMI_RANK'])
|
||||
num_ranks = int(os.environ['PMI_SIZE'])
|
||||
|
||||
indices_per_rank = 100
|
||||
tensor_width = 10
|
||||
|
||||
# Create IndexedSlices for each rank, some with overlapping indices.
|
||||
to_gather_indices = []
|
||||
to_gather_values = []
|
||||
to_gather = []
|
||||
for rank_id in range(num_ranks):
|
||||
indices = []
|
||||
values = []
|
||||
my_multiple = rank_id + 1
|
||||
current_index = my_multiple
|
||||
for i in range(indices_per_rank):
|
||||
indices.append(current_index)
|
||||
ones_tensor = tf.ones([tensor_width])
|
||||
values.append(tf.multiply(ones_tensor,
|
||||
tf.fill(ones_tensor.get_shape(),
|
||||
float(current_index))))
|
||||
current_index += my_multiple
|
||||
concat_ind = tf.stack(indices)
|
||||
concat_vals = tf.stack(values)
|
||||
to_gather_indices.append(concat_ind)
|
||||
to_gather_values.append(concat_vals)
|
||||
to_gather.append(tf.IndexedSlices(concat_vals, concat_ind))
|
||||
|
||||
# Collect the local IndexedSlices (indices and values) to create
|
||||
# correct IndexedSlices output.
|
||||
correct_gather_indices = tf.concat(to_gather_indices, 0)
|
||||
correct_gather_values = tf.concat(to_gather_values, 0)
|
||||
correct_gather = tf.IndexedSlices(correct_gather_values,
|
||||
correct_gather_indices)
|
||||
|
||||
all_gather = mpi.allreduce(to_gather[my_rank], average_allgather)
|
||||
|
||||
# NOTE: This assumes that device IDs are numbered the same as ranks.
|
||||
gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
|
||||
config = tf.ConfigProto(gpu_options=gpu_options)
|
||||
|
||||
# MPI Session to test allgather.
|
||||
with mpi.Session(config=config) as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
all_gathered, local_gathered = sess.run([all_gather, correct_gather])
|
||||
|
||||
# Compare all_gathered with local_gathered.
|
||||
self.checkAllgather(num_ranks, all_gathered, local_gathered)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
136
tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py
Normal file
136
tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py
Normal file
@ -0,0 +1,136 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.mpi_collectives as mpi
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
average_allreduce = False
|
||||
max_wrong_count = -1
|
||||
|
||||
|
||||
class AllreduceTest(test.TestCase):
|
||||
def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red,
|
||||
our_correct):
|
||||
# Find reduced/allreduced indices that are wrong and print all the
|
||||
# values from output, slices, reduced, allreduced, so we can debug
|
||||
# which is incorrect:
|
||||
wrong_count = 0
|
||||
red_dims = out_loc_red.shape
|
||||
assert(len(red_dims) == 2)
|
||||
for i in range(red_dims[0]):
|
||||
for j in range(red_dims[1]):
|
||||
suffix = ""
|
||||
if out_loc_red[i][j] != my_correct[i][j] or \
|
||||
out_all_red[i][j] != our_correct[i][j]:
|
||||
suffix = "WRONG"
|
||||
wrong_count += 1
|
||||
print("{}\t{}\t{}\t{}\t{}\t{}"
|
||||
.format(my_rank, i, j, out_loc_red[i][j],
|
||||
out_all_red[i][j], suffix), flush=True)
|
||||
if max_wrong_count > 0 and wrong_count >= max_wrong_count:
|
||||
return
|
||||
|
||||
def test_mpi_allreduce(self):
|
||||
# Get MPI rank
|
||||
my_rank = int(os.environ['PMI_RANK'])
|
||||
num_ranks = int(os.environ['PMI_SIZE'])
|
||||
|
||||
stages = 13
|
||||
batch_size = 1331
|
||||
hidden_size = batch_size
|
||||
out_size = batch_size
|
||||
|
||||
# Input placeholder (batch_size x hidden) - init to 1s
|
||||
inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size),
|
||||
name="Input")
|
||||
|
||||
# Large matrices (hidden x out_dim) - init random
|
||||
weights = []
|
||||
for i in range(stages):
|
||||
initer = tf.constant_initializer(pow(2.0, i + 1.0))
|
||||
weights.append(tf.get_variable("weights_{}".format(i),
|
||||
shape=(hidden_size, out_size),
|
||||
dtype=tf.float32,
|
||||
initializer=initer))
|
||||
|
||||
# Calculate output through dependent allreduces
|
||||
stage_input = inputs
|
||||
for i in range(stages):
|
||||
inter_output = tf.add(stage_input, weights[i],
|
||||
name="add_red_{}".format(i))
|
||||
stage_input = mpi.allreduce(inter_output,
|
||||
average=average_allreduce)
|
||||
|
||||
all_reduced = stage_input
|
||||
|
||||
# Local reduced output for verification
|
||||
local_input = inputs
|
||||
for i in range(stages):
|
||||
inter_output = tf.add(local_input, weights[i],
|
||||
name="addin_loc_{}".format(i))
|
||||
my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)),
|
||||
dtype=tf.float32, name="loc_redr_{}".format(i))
|
||||
for r in range(num_ranks):
|
||||
my_reducer = tf.add(my_reducer, inter_output,
|
||||
name="add_loc_{}_{}".format(i, r))
|
||||
if average_allreduce:
|
||||
local_input = tf.div(my_reducer, num_ranks,
|
||||
name="div_loc_{}".format(i))
|
||||
else:
|
||||
local_input = my_reducer
|
||||
|
||||
local_reduced = local_input
|
||||
|
||||
# NOTE: This assumes that device IDs are numbered the same as ranks
|
||||
gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
|
||||
config = tf.ConfigProto(gpu_options=gpu_options)
|
||||
|
||||
# MPI Session to test allreduce
|
||||
with mpi.Session(config=config) as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
input_feed = np.ones((batch_size, hidden_size), dtype=np.float32)
|
||||
our_output = input_feed[0][0]
|
||||
spread_var = 100
|
||||
input_feed = input_feed + my_rank * spread_var
|
||||
my_output = input_feed[0][0]
|
||||
for i in range(stages):
|
||||
curr_feed = my_output + pow(2.0, i + 1.0)
|
||||
my_output = curr_feed * num_ranks + 1
|
||||
curr_our_feed = our_output + pow(2.0, i + 1.0)
|
||||
if i == 0:
|
||||
sum_ranks = num_ranks * (num_ranks - 1) / 2
|
||||
our_output = curr_our_feed * num_ranks + \
|
||||
spread_var * sum_ranks
|
||||
else:
|
||||
our_output = curr_our_feed * num_ranks
|
||||
|
||||
print("rank {}: My output is {}".format(my_rank, my_output))
|
||||
my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
|
||||
my_correct = my_correct + my_output
|
||||
print("rank {}: Our output is {}".format(my_rank, our_output))
|
||||
our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
|
||||
our_correct = our_correct + our_output
|
||||
|
||||
for i in range(1000):
|
||||
if i % 100 == 0:
|
||||
print("{}: iter {}".format(my_rank, i), flush=True)
|
||||
feed_dict = {inputs: input_feed}
|
||||
out_all_red, out_loc_red \
|
||||
= sess.run([all_reduced, local_reduced],
|
||||
feed_dict=feed_dict)
|
||||
|
||||
if not np.allclose(out_loc_red, my_correct) or \
|
||||
not np.allclose(out_all_red, our_correct):
|
||||
print("Test incorrect on iter {}".format(i), flush=True)
|
||||
self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red,
|
||||
our_correct)
|
||||
assert(np.allclose(out_loc_red, my_correct) and
|
||||
np.allclose(out_all_red, our_correct))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
49
tensorflow/contrib/mpi_collectives/mpi_message.proto
Normal file
49
tensorflow/contrib/mpi_collectives/mpi_message.proto
Normal file
@ -0,0 +1,49 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.contrib.mpi;
|
||||
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
// An MPIRequest is a message sent from a rank greater than zero to the
|
||||
// coordinator (rank zero), informing the coordinator of an operation that
|
||||
// the rank wants to do and the tensor that it wants to apply the operation to.
|
||||
message MPIRequest {
|
||||
enum RequestType {
|
||||
ALLREDUCE = 0;
|
||||
ALLGATHER = 1;
|
||||
}
|
||||
|
||||
// The request rank is necessary to create a consistent ordering of results,
|
||||
// for example in the allgather where the order of outputs should be sorted
|
||||
// by rank.
|
||||
int32 request_rank = 1;
|
||||
RequestType request_type = 2;
|
||||
DataType tensor_type = 3;
|
||||
string tensor_name = 4;
|
||||
TensorShapeProto tensor_shape = 5;
|
||||
};
|
||||
|
||||
// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
|
||||
// greater than zero, informing the rank of an operation should be performed
|
||||
// now. If the operation requested would result in an error (for example, due
|
||||
// to a type or shape mismatch), then the MPIResponse can contain an error and
|
||||
// an error message instead. Finally, an MPIResponse can be a DONE message (if
|
||||
// there are no more tensors to reduce on this tick of the background loop) or
|
||||
// SHUTDOWN if all MPI processes should shut down.
|
||||
message MPIResponse {
|
||||
enum ResponseType {
|
||||
ALLREDUCE = 0;
|
||||
ALLGATHER = 1;
|
||||
ERROR = 2;
|
||||
DONE = 3;
|
||||
SHUTDOWN = 4;
|
||||
}
|
||||
|
||||
// Empty if the type is DONE or SHUTDOWN.
|
||||
ResponseType response_type = 1;
|
||||
string tensor_name = 2;
|
||||
|
||||
// Empty unless response_type is ERROR.
|
||||
string error_message = 3;
|
||||
};
|
1241
tensorflow/contrib/mpi_collectives/mpi_ops.cc
Normal file
1241
tensorflow/contrib/mpi_collectives/mpi_ops.cc
Normal file
File diff suppressed because it is too large
Load Diff
165
tensorflow/contrib/mpi_collectives/mpi_ops.py
Normal file
165
tensorflow/contrib/mpi_collectives/mpi_ops.py
Normal file
@ -0,0 +1,165 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Inter-process communication using MPI."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
def _load_library(name, op_list=None):
|
||||
"""Loads a .so file containing the specified operators.
|
||||
|
||||
Args:
|
||||
name: The name of the .so file to load.
|
||||
op_list: A list of names of operators that the library should have. If None
|
||||
then the .so file's contents will not be verified.
|
||||
|
||||
Raises:
|
||||
NameError if one of the required ops is missing.
|
||||
"""
|
||||
try:
|
||||
filename = resource_loader.get_path_to_datafile(name)
|
||||
library = load_library.load_op_library(filename)
|
||||
for expected_op in (op_list or []):
|
||||
for lib_op in library.OP_LIST.op:
|
||||
if lib_op.name == expected_op:
|
||||
break
|
||||
else:
|
||||
raise NameError(
|
||||
'Could not find operator %s in dynamic library %s' %
|
||||
(expected_op, name))
|
||||
return library
|
||||
except errors.NotFoundError:
|
||||
logging.warning('%s file could not be loaded.', name)
|
||||
|
||||
|
||||
MPI_LIB = _load_library('mpi_collectives.so', ['MPISize', 'MPIRank',
|
||||
'MPILocalRank', 'MPIAllgather',
|
||||
'MPIAllreduce'])
|
||||
|
||||
|
||||
def size(name=None):
|
||||
"""An op which returns the number of MPI processes.
|
||||
|
||||
This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
|
||||
size of the global communicator.
|
||||
|
||||
Returns:
|
||||
An integer scalar containing the number of MPI processes.
|
||||
"""
|
||||
return MPI_LIB.mpi_size(name=name)
|
||||
|
||||
|
||||
ops.NotDifferentiable('MPISize')
|
||||
|
||||
|
||||
def rank(name=None):
|
||||
"""An op which returns the MPI rank of the calling process.
|
||||
|
||||
This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
|
||||
rank of the current process in the global communicator.
|
||||
|
||||
Returns:
|
||||
An integer scalar with the MPI rank of the calling process.
|
||||
"""
|
||||
return MPI_LIB.mpi_rank(name=name)
|
||||
|
||||
|
||||
ops.NotDifferentiable('MPIRank')
|
||||
|
||||
|
||||
def init(name=None):
|
||||
"""An op which initializes MPI on the device on which it is run.
|
||||
|
||||
All future MPI ops must be run on the same device that the `init` op was run
|
||||
on.
|
||||
"""
|
||||
return MPI_LIB.mpi_init(name=name)
|
||||
|
||||
|
||||
ops.NotDifferentiable('MPIInit')
|
||||
|
||||
|
||||
def local_rank(name=None):
|
||||
"""An op which returns the local MPI rank of the calling process, within the
|
||||
node that it is running on. For example, if there are seven processes running
|
||||
on a node, their local ranks will be zero through six, inclusive.
|
||||
|
||||
This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
|
||||
which only includes processes on the same node.
|
||||
|
||||
Returns:
|
||||
An integer scalar with the local MPI rank of the calling process.
|
||||
"""
|
||||
return MPI_LIB.mpi_local_rank(name=name)
|
||||
|
||||
|
||||
ops.NotDifferentiable('MPILocalRank')
|
||||
|
||||
|
||||
def _allreduce(tensor, name=None):
|
||||
"""An op which sums an input tensor over all the MPI processes.
|
||||
|
||||
The reduction operation is keyed by the name of the op. The tensor type and
|
||||
shape must be the same on all MPI processes for a given name. The reduction
|
||||
will not start until all processes are ready to send and receive the tensor.
|
||||
|
||||
Returns:
|
||||
A tensor of the same shape and type as `tensor`, summed across all
|
||||
processes.
|
||||
"""
|
||||
return MPI_LIB.mpi_allreduce(tensor, name=name)
|
||||
|
||||
|
||||
ops.NotDifferentiable('MPIAllreduce')
|
||||
|
||||
|
||||
def allgather(tensor, name=None):
|
||||
"""An op which concatenates the input tensor with the same input tensor on
|
||||
all other MPI processes.
|
||||
|
||||
The concatenation is done on the first dimension, so the input tensors on the
|
||||
different processes must have the same rank and shape, except for the first
|
||||
dimension, which is allowed to be different.
|
||||
|
||||
Returns:
|
||||
A tensor of the same type as `tensor`, concatenated on dimension zero
|
||||
across all processes. The shape is identical to the input shape, except for
|
||||
the first dimension, which may be greater and is the sum of all first
|
||||
dimensions of the tensors in different MPI processes.
|
||||
"""
|
||||
# Specify that first allgather is to collect the tensor gather sizes,
|
||||
# indicated by passing in a scalar (0-D tensor) of value 0
|
||||
sizes_flag = tf.constant(0, dtype=tf.int64, name="size_flag_const")
|
||||
my_size = tf.slice(tf.shape(tensor, out_type=tf.int64), [0], [1], name="size_slice")
|
||||
if name is None:
|
||||
name = "allgather"
|
||||
sizing_name = "{}_sizing".format(name)
|
||||
sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name)
|
||||
return MPI_LIB.mpi_allgather(tensor, sizes, name=name)
|
||||
|
||||
|
||||
ops.NotDifferentiable('MPIAllgather')
|
||||
|
||||
|
296
tensorflow/contrib/mpi_collectives/mpi_ops_test.py
Normal file
296
tensorflow/contrib/mpi_collectives/mpi_ops_test.py
Normal file
@ -0,0 +1,296 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Tests for tensorflow.contrib.mpi_collectives.mpi_ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import itertools
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import tensorflow.contrib.mpi_collectives as mpi
|
||||
|
||||
|
||||
def mpi_env_rank_and_size():
|
||||
"""Get MPI rank and size from environment variables and return them as a
|
||||
tuple of integers.
|
||||
|
||||
Most MPI implementations have an `mpirun` or `mpiexec` command that will
|
||||
run an MPI executable and set up all communication necessary between the
|
||||
different processors. As part of that set up, they will set environment
|
||||
variables that contain the rank and size of the MPI_COMM_WORLD
|
||||
communicator. We can read those environment variables from Python in order
|
||||
to ensure that `mpi.rank()` and `mpi.size()` return the expected values.
|
||||
|
||||
Since MPI is just a standard, not an implementation, implementations
|
||||
typically choose their own environment variable names. This function tries
|
||||
to support several different implementation, but really it only needs to
|
||||
support whatever implementation we want to use for the TensorFlow test
|
||||
suite.
|
||||
|
||||
If this is not running under MPI, then defaults of rank zero and size one
|
||||
are returned. (This is appropriate because when you call MPI_Init in an
|
||||
application not started with mpirun, it will create a new independent
|
||||
communicator with only one process in it.)
|
||||
"""
|
||||
rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split()
|
||||
size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split()
|
||||
|
||||
for rank_var, size_var in zip(rank_env, size_env):
|
||||
rank = os.environ.get(rank_var)
|
||||
size = os.environ.get(size_var)
|
||||
if rank is not None and size is not None:
|
||||
return int(rank), int(size)
|
||||
|
||||
# Default to rank zero and size one if there are no environment variables
|
||||
return 0, 1
|
||||
|
||||
|
||||
class MPITests(tf.test.TestCase):
|
||||
"""
|
||||
Tests for MPI ops in tensorflow.contrib.mpi_collectives.
|
||||
"""
|
||||
|
||||
def test_mpi_rank(self):
|
||||
"""Test that the rank returned by mpi.rank() is correct."""
|
||||
true_rank, _ = mpi_env_rank_and_size()
|
||||
with self.test_session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
self.assertEqual(true_rank, rank)
|
||||
|
||||
def test_mpi_size(self):
|
||||
"""Test that the size returned by mpi.size() is correct."""
|
||||
_, true_size = mpi_env_rank_and_size()
|
||||
with self.test_session() as session:
|
||||
size = session.run(mpi.size())
|
||||
self.assertEqual(true_size, size)
|
||||
|
||||
def test_mpi_allreduce_cpu(self):
|
||||
"""Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors."""
|
||||
with self.test_session() as session:
|
||||
size = session.run(mpi.size())
|
||||
|
||||
dtypes = [tf.int32, tf.float32]
|
||||
dims = [1, 2, 3]
|
||||
for dtype, dim in itertools.product(dtypes, dims):
|
||||
tf.set_random_seed(1234)
|
||||
tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
|
||||
summed = mpi.allreduce(tensor, average=False)
|
||||
multiplied = tensor * size
|
||||
max_difference = tf.reduce_max(tf.abs(summed - multiplied))
|
||||
|
||||
# Threshold for floating point equality depends on number of
|
||||
# ranks, since we're comparing against precise multiplication.
|
||||
if size <= 3:
|
||||
threshold = 0
|
||||
elif size < 10:
|
||||
threshold = 1e-4
|
||||
elif size < 15:
|
||||
threshold = 5e-4
|
||||
else:
|
||||
break
|
||||
|
||||
diff = session.run(max_difference)
|
||||
self.assertTrue(diff <= threshold,
|
||||
"mpi.allreduce produces incorrect results")
|
||||
|
||||
def test_mpi_allreduce_gpu(self):
|
||||
"""Test that the allreduce works on GPUs.
|
||||
|
||||
This test will crash badly if used with an MPI implementation that does
|
||||
not support GPU memory transfers directly, as it will call MPI_Send on
|
||||
a GPU data pointer."""
|
||||
# Only do this test if there are GPUs available.
|
||||
if not tf.test.is_gpu_available(cuda_only=True):
|
||||
return
|
||||
|
||||
no_gpus = tf.GPUOptions(visible_device_list="")
|
||||
cpu_config = tf.ConfigProto(gpu_options=no_gpus)
|
||||
with self.test_session(config=cpu_config) as session:
|
||||
local_rank = session.run(mpi.local_rank())
|
||||
|
||||
one_gpu = tf.GPUOptions(visible_device_list=str(local_rank))
|
||||
gpu_config = tf.ConfigProto(gpu_options=one_gpu)
|
||||
with self.test_session(config=gpu_config) as session:
|
||||
size = session.run(mpi.size())
|
||||
|
||||
dtype = tf.float32
|
||||
dim = 3
|
||||
with tf.device("/gpu:0"):
|
||||
tf.set_random_seed(1234)
|
||||
tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype)
|
||||
summed = mpi.allreduce(tensor, average=False)
|
||||
multiplied = tensor * size
|
||||
max_difference = tf.reduce_max(tf.abs(summed - multiplied))
|
||||
|
||||
# Threshold for floating point equality depends on number of
|
||||
# ranks, since we're comparing against precise multiplication.
|
||||
if size <= 3:
|
||||
threshold = 0
|
||||
elif size < 10:
|
||||
threshold = 1e-4
|
||||
elif size < 15:
|
||||
threshold = 5e-4
|
||||
else:
|
||||
return
|
||||
|
||||
diff = session.run(max_difference)
|
||||
self.assertTrue(diff <= threshold,
|
||||
"mpi.allreduce on GPU produces incorrect results")
|
||||
|
||||
def test_mpi_allreduce_error(self):
|
||||
"""Test that the allreduce raises an error if different ranks try to
|
||||
send tensors of different rank or dimension."""
|
||||
with self.test_session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
size = session.run(mpi.size())
|
||||
|
||||
# This test does not apply if there is only one worker.
|
||||
if size == 1:
|
||||
return
|
||||
|
||||
# Same rank, different dimension
|
||||
tf.set_random_seed(1234)
|
||||
dims = [17 + rank] * 3
|
||||
tensor = tf.random_uniform(dims, -1.0, 1.0)
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
session.run(mpi.allreduce(tensor))
|
||||
|
||||
# Same number of elements, different rank
|
||||
tf.set_random_seed(1234)
|
||||
if rank == 0:
|
||||
dims = [17, 23 * 57]
|
||||
else:
|
||||
dims = [17, 23, 57]
|
||||
tensor = tf.random_uniform(dims, -1.0, 1.0)
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
session.run(mpi.allreduce(tensor))
|
||||
|
||||
def test_mpi_allreduce_type_error(self):
|
||||
"""Test that the allreduce raises an error if different ranks try to
|
||||
send tensors of different type."""
|
||||
with self.test_session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
size = session.run(mpi.size())
|
||||
|
||||
# This test does not apply if there is only one worker.
|
||||
if size == 1:
|
||||
return
|
||||
|
||||
# Same rank, different dimension
|
||||
dims = [17] * 3
|
||||
tensor = tf.ones(dims, dtype=tf.int32 if rank % 2 == 0 else tf.float32)
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
session.run(mpi.allreduce(tensor))
|
||||
|
||||
def test_mpi_allgather(self):
|
||||
"""Test that the allgather correctly gathers 1D, 2D, 3D tensors."""
|
||||
with self.test_session() as session:
|
||||
size = session.run(mpi.size())
|
||||
rank = session.run(mpi.rank())
|
||||
|
||||
dtypes = tf.int32, tf.float32
|
||||
dims = 1, 2, 3
|
||||
for dtype, dim in itertools.product(dtypes, dims):
|
||||
tensor = tf.ones([17] * dim, dtype=dtype) * rank
|
||||
gathered = mpi.allgather(tensor)
|
||||
|
||||
gathered_tensor = session.run(gathered)
|
||||
self.assertEqual(list(gathered_tensor.shape),
|
||||
[17 * size] + [17] * (dim - 1))
|
||||
|
||||
for i in range(size):
|
||||
rank_tensor = tf.slice(gathered_tensor, [i * 17] + [0] * (dim - 1),
|
||||
[17] + [-1] * (dim - 1))
|
||||
self.assertEqual(list(rank_tensor.shape), [17] * dim)
|
||||
self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
|
||||
"mpi.allgather produces incorrect gathered tensor")
|
||||
|
||||
def test_mpi_allgather_variable_size(self):
|
||||
"""Test that the allgather correctly gathers 1D, 2D, 3D tensors,
|
||||
even if those tensors have different sizes along the first dim."""
|
||||
with self.test_session() as session:
|
||||
size = session.run(mpi.size())
|
||||
rank = session.run(mpi.rank())
|
||||
|
||||
dtypes = tf.int32, tf.float32
|
||||
dims = 1, 2, 3
|
||||
for dtype, dim in itertools.product(dtypes, dims):
|
||||
# Support tests up to MPI Size of 35
|
||||
if size > 35:
|
||||
break
|
||||
|
||||
tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5
|
||||
tensor_sizes = tensor_sizes[:size]
|
||||
|
||||
tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1),
|
||||
dtype=dtype) * rank
|
||||
gathered = mpi.allgather(tensor)
|
||||
|
||||
gathered_tensor = session.run(gathered)
|
||||
expected_size = sum(tensor_sizes)
|
||||
self.assertEqual(list(gathered_tensor.shape),
|
||||
[expected_size] + [17] * (dim - 1))
|
||||
|
||||
for i in range(size):
|
||||
rank_size = [tensor_sizes[i]] + [17] * (dim - 1)
|
||||
rank_tensor = tf.slice(gathered,
|
||||
[sum(tensor_sizes[:i])] + [0] * (dim - 1),
|
||||
rank_size)
|
||||
self.assertEqual(list(rank_tensor.shape), rank_size)
|
||||
self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))),
|
||||
"mpi.allgather produces incorrect gathered tensor")
|
||||
|
||||
def test_mpi_allgather_error(self):
|
||||
"""Test that the allgather returns an error if any dimension besides
|
||||
the first is different among the tensors being gathered."""
|
||||
with self.test_session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
size = session.run(mpi.size())
|
||||
|
||||
# This test does not apply if there is only one worker.
|
||||
if size == 1:
|
||||
return
|
||||
|
||||
tensor_size = [17] * 3
|
||||
tensor_size[1] = 10 * (rank + 1)
|
||||
tensor = tf.ones(tensor_size, dtype=tf.float32) * rank
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
session.run(mpi.allgather(tensor))
|
||||
|
||||
def test_mpi_allgather_type_error(self):
|
||||
"""Test that the allgather returns an error if the types being gathered
|
||||
differ among the processes"""
|
||||
with self.test_session() as session:
|
||||
rank = session.run(mpi.rank())
|
||||
size = session.run(mpi.size())
|
||||
|
||||
# This test does not apply if there is only one worker.
|
||||
if size == 1:
|
||||
return
|
||||
|
||||
tensor_size = [17] * 3
|
||||
dtype = tf.int32 if rank % 2 == 0 else tf.float32
|
||||
tensor = tf.ones(tensor_size, dtype=dtype) * rank
|
||||
with self.assertRaises(tf.errors.FailedPreconditionError):
|
||||
session.run(mpi.allgather(tensor))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
59
tensorflow/contrib/mpi_collectives/ring.cc
Normal file
59
tensorflow/contrib/mpi_collectives/ring.cc
Normal file
@ -0,0 +1,59 @@
|
||||
#ifdef TENSORFLOW_USE_MPI
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/contrib/mpi_collectives/ring.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace contrib {
|
||||
namespace mpi {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
extern template MPI_Datatype MPIType<float>();
|
||||
extern template MPI_Datatype MPIType<int>();
|
||||
extern template MPI_Datatype MPIType<long long>();
|
||||
extern template DataType TensorFlowDataType<float>();
|
||||
extern template DataType TensorFlowDataType<int>();
|
||||
extern template DataType TensorFlowDataType<long long>();
|
||||
|
||||
|
||||
// Generate all necessary specializations for RingAllreduce.
|
||||
template Status RingAllreduce<CPUDevice, int>(
|
||||
OpKernelContext*, const Tensor*, Tensor*, Tensor*);
|
||||
template Status RingAllreduce<CPUDevice, long long>(
|
||||
OpKernelContext*, const Tensor*, Tensor*, Tensor*);
|
||||
template Status RingAllreduce<CPUDevice, float>(
|
||||
OpKernelContext*, const Tensor*, Tensor*, Tensor*);
|
||||
|
||||
// Generate all necessary specializations for RingAllgather.
|
||||
template Status RingAllgather<CPUDevice, int>(
|
||||
OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
|
||||
template Status RingAllgather<CPUDevice, long long>(
|
||||
OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
|
||||
template Status RingAllgather<CPUDevice, float>(
|
||||
OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
|
||||
|
||||
// Copy data on a CPU using a straight-forward memcpy.
|
||||
template<> void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
|
||||
std::memcpy(dst, src, size);
|
||||
};
|
||||
|
||||
// Accumulate values on a CPU.
|
||||
#define GENERATE_ACCUMULATE(type) \
|
||||
template<> void AccumulateTensorData<CPUDevice, type>( \
|
||||
type* dst, type* src, size_t size) { \
|
||||
for (unsigned int i = 0; i < size; i++) { \
|
||||
dst[i] += src[i]; \
|
||||
} \
|
||||
};
|
||||
GENERATE_ACCUMULATE(int);
|
||||
GENERATE_ACCUMULATE(long long);
|
||||
GENERATE_ACCUMULATE(float);
|
||||
#undef GENERATE_ACCUMULATE
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_USE_MPI
|
78
tensorflow/contrib/mpi_collectives/ring.cu.cc
Normal file
78
tensorflow/contrib/mpi_collectives/ring.cu.cc
Normal file
@ -0,0 +1,78 @@
|
||||
#ifdef TENSORFLOW_USE_MPI
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/contrib/mpi_collectives/ring.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace contrib {
|
||||
namespace mpi {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
template<> MPI_Datatype MPIType<float>() { return MPI_FLOAT; };
|
||||
template<> MPI_Datatype MPIType<int>() { return MPI_INT; };
|
||||
template<> MPI_Datatype MPIType<long long>() { return MPI_LONG_LONG; };
|
||||
|
||||
template<> DataType TensorFlowDataType<float>() { return DT_FLOAT; };
|
||||
template<> DataType TensorFlowDataType<int>() { return DT_INT32; };
|
||||
template<> DataType TensorFlowDataType<long long>() { return DT_INT64; };
|
||||
|
||||
// Generate all necessary specializations for RingAllreduce.
|
||||
template Status RingAllreduce<GPUDevice, int>(
|
||||
OpKernelContext*, const Tensor*, Tensor*, Tensor*);
|
||||
template Status RingAllreduce<GPUDevice, long long>(
|
||||
OpKernelContext*, const Tensor*, Tensor*, Tensor*);
|
||||
template Status RingAllreduce<GPUDevice, float>(
|
||||
OpKernelContext*, const Tensor*, Tensor*, Tensor*);
|
||||
|
||||
// Generate all necessary specializations for RingAllgather.
|
||||
template Status RingAllgather<GPUDevice, int>(
|
||||
OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
|
||||
template Status RingAllgather<GPUDevice, long long>(
|
||||
OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
|
||||
template Status RingAllgather<GPUDevice, float>(
|
||||
OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
|
||||
|
||||
// Synchronously copy data on the GPU, using a different stream than the default
|
||||
// and than TensorFlow to avoid synchronizing on operations unrelated to the
|
||||
// allreduce.
|
||||
template<> void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
|
||||
auto stream = CudaStreamForMPI();
|
||||
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
};
|
||||
|
||||
// Elementwise accumulation kernel for GPU.
|
||||
template <typename T>
|
||||
__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
i < N;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
out[i] += in[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronously accumulate tensors on the GPU, using a different stream than
|
||||
// the default and than TensorFlow to avoid synchronizing on operations
|
||||
// unrelated to the allreduce.
|
||||
#define GENERATE_ACCUMULATE(type) \
|
||||
template<> void AccumulateTensorData<GPUDevice, type>( \
|
||||
type* dst, type* src, size_t size) { \
|
||||
auto stream = CudaStreamForMPI(); \
|
||||
elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size); \
|
||||
cudaStreamSynchronize(stream); \
|
||||
};
|
||||
GENERATE_ACCUMULATE(int);
|
||||
GENERATE_ACCUMULATE(long long);
|
||||
GENERATE_ACCUMULATE(float);
|
||||
#undef GENERATE_ACCUMULATE
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // TENSORFLOW_USE_MPI
|
312
tensorflow/contrib/mpi_collectives/ring.h
Normal file
312
tensorflow/contrib/mpi_collectives/ring.h
Normal file
@ -0,0 +1,312 @@
|
||||
#ifndef TENSORFLOW_CONTRIB_MPI_H_
|
||||
#define TENSORFLOW_CONTRIB_MPI_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_MPI
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "cuda_runtime.h"
|
||||
#endif
|
||||
|
||||
// Needed to avoid header issues with C++-supporting MPI implementations
|
||||
#define OMPI_SKIP_MPICXX
|
||||
#include "third_party/mpi/mpi.h"
|
||||
|
||||
#define TAG_TENSOR 12
|
||||
|
||||
namespace tensorflow {
|
||||
namespace contrib {
|
||||
namespace mpi {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
// Convert from templated types to values we can pass to MPI.
|
||||
template<typename T>
|
||||
MPI_Datatype MPIType();
|
||||
|
||||
// Convert from templated types to TensorFlow data types.
|
||||
template<typename T>
|
||||
DataType TensorFlowDataType();
|
||||
|
||||
#define MPI_REQUIRES_OK(MPI_STATUS) \
|
||||
if ((MPI_STATUS) != MPI_SUCCESS) { \
|
||||
return errors::Unknown("MPI operation failed unexpectedly."); \
|
||||
}
|
||||
|
||||
// Copy data from one tensor to another tensor.
|
||||
// This uses a custom CUDA stream on GPU, which is necessary to overlay the
|
||||
// backpropagation computations with the allreduce.
|
||||
template <typename Device>
|
||||
void CopyTensorData(void* destination, void* source, size_t size);
|
||||
|
||||
// Add a tensor into another tensor, accumulating in place.
|
||||
// This uses a custom CUDA stream on GPU, which is necessary to overlay the
|
||||
// backpropagation computations with the allreduce.
|
||||
template <typename Device, typename T>
|
||||
void AccumulateTensorData(T* destination, T* source, size_t size);
|
||||
|
||||
// We need to get the right stream for doing CUDA memory transfers and
|
||||
// operations, which is possibly different from the standard TensorFlow stream.
|
||||
#if GOOGLE_CUDA
|
||||
cudaStream_t CudaStreamForMPI();
|
||||
#endif
|
||||
|
||||
/* Perform a ring allreduce on the data. Allocate the necessary output tensor and
|
||||
* store it in the output parameter.
|
||||
*
|
||||
* Assumes that all MPI processes are doing an allreduce of the same tensor,
|
||||
* with the same dimensions.
|
||||
*
|
||||
* A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the allreduce,
|
||||
* the nodes involved are arranged in a ring:
|
||||
*
|
||||
* .--0--.
|
||||
* / \
|
||||
* 3 1
|
||||
* \ /
|
||||
* *--2--*
|
||||
*
|
||||
* Each node always sends to the next clockwise node in the ring, and receives
|
||||
* from the previous one.
|
||||
*
|
||||
* The allreduce is done in two parts: a scatter-reduce and an allgather. In
|
||||
* the scatter reduce, a reduction is done, so that each node ends up with a
|
||||
* chunk of the final output tensor which has contributions from all other
|
||||
* nodes. In the allgather, those chunks are distributed among all the nodes,
|
||||
* so that all nodes have the entire output tensor.
|
||||
*
|
||||
* Both of these operations are done by dividing the input tensor into N
|
||||
* evenly sized chunks (where N is the number of nodes in the ring).
|
||||
*
|
||||
* The scatter-reduce is done in N-1 steps. In the ith step, node j will send
|
||||
* the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to
|
||||
* its existing data for that chunk. For example, in the first iteration with
|
||||
* the ring depicted above, you will have the following transfers:
|
||||
*
|
||||
* Segment 0: Node 0 --> Node 1
|
||||
* Segment 1: Node 1 --> Node 2
|
||||
* Segment 2: Node 2 --> Node 3
|
||||
* Segment 3: Node 3 --> Node 0
|
||||
*
|
||||
* In the second iteration, you'll have the following transfers:
|
||||
*
|
||||
* Segment 0: Node 1 --> Node 2
|
||||
* Segment 1: Node 2 --> Node 3
|
||||
* Segment 2: Node 3 --> Node 0
|
||||
* Segment 3: Node 0 --> Node 1
|
||||
*
|
||||
* After this iteration, Node 2 has 3 of the four contributions to Segment 0.
|
||||
* The last iteration has the following transfers:
|
||||
*
|
||||
* Segment 0: Node 2 --> Node 3
|
||||
* Segment 1: Node 3 --> Node 0
|
||||
* Segment 2: Node 0 --> Node 1
|
||||
* Segment 3: Node 1 --> Node 2
|
||||
*
|
||||
* After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
|
||||
* has the fully accumulated Segment 1; and so on. The scatter-reduce is complete.
|
||||
*
|
||||
* Next, the allgather distributes these fully accumululated chunks across all nodes.
|
||||
* Communication proceeds in the same ring, once again in N-1 steps. At the ith step,
|
||||
* node j will send chunk (j - i + 1) and receive chunk (j - i). For example, at the
|
||||
* first iteration, the following transfers will occur:
|
||||
*
|
||||
* Segment 0: Node 3 --> Node 0
|
||||
* Segment 1: Node 0 --> Node 1
|
||||
* Segment 2: Node 1 --> Node 2
|
||||
* Segment 3: Node 2 --> Node 3
|
||||
*
|
||||
* After the first iteration, Node 0 will have a fully accumulated Segment 0
|
||||
* (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
|
||||
* just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
|
||||
* After this has continued for N - 1 iterations, all nodes will have a the fully
|
||||
* accumulated tensor.
|
||||
*
|
||||
* Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the allgather.
|
||||
* Each send will contain K / N bytes, if there are K bytes in the original tensor on every node.
|
||||
* Thus, each node sends and receives 2K(N - 1)/N bytes of data, and the performance of the allreduce
|
||||
* (assuming no latency in connections) is constrained by the slowest interconnect between the nodes.
|
||||
*
|
||||
*/
|
||||
template<typename Device, typename T>
|
||||
Status RingAllreduce(OpKernelContext* context, const Tensor* input,
|
||||
Tensor* temp, Tensor* output) {
|
||||
// Acquire MPI size and rank
|
||||
int n, r;
|
||||
MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
|
||||
MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
|
||||
|
||||
T* buffer = (T*) output->tensor_data().data();
|
||||
|
||||
CopyTensorData<Device>((void*) buffer,
|
||||
(void*) input->tensor_data().data(),
|
||||
output->tensor_data().size());
|
||||
|
||||
// Calculate segment sizes and segment ends
|
||||
const size_t elements_to_reduce = input->NumElements();
|
||||
const size_t segment_size = elements_to_reduce / n;
|
||||
std::vector<size_t> segment_sizes(n, segment_size);
|
||||
|
||||
const size_t residual = elements_to_reduce % n;
|
||||
for (size_t i = 0; i < residual; ++i) {
|
||||
segment_sizes[i]++;
|
||||
}
|
||||
|
||||
std::vector<size_t> segment_starts(n);
|
||||
segment_starts[0] = 0;
|
||||
for (size_t i = 1; i < segment_starts.size(); ++i) {
|
||||
segment_starts[i] = segment_starts[i-1] + segment_sizes[i-1];
|
||||
}
|
||||
|
||||
assert(segment_starts[n-1] + segment_sizes[n-1] == elements_to_reduce);
|
||||
|
||||
T* segment_recv = (T*) temp->tensor_data().data();
|
||||
|
||||
// Receive from your left neighbor with wrap-around
|
||||
const size_t recv_from = ((r - 1) + n) % n;
|
||||
|
||||
// Send to your right neighbor with wrap-around
|
||||
const size_t send_to = (r + 1) % n;
|
||||
|
||||
MPI_Status recv_status;
|
||||
MPI_Request recv_req;
|
||||
|
||||
// Now start ring. At every step, for every rank, we iterate through
|
||||
// segments with wraparound and send and recv from our neighbors and reduce
|
||||
// locally. At the i'th iteration, rank r, sends segment (r-i) and receives
|
||||
// segment (r-i-1).
|
||||
for (int i = 0; i < n - 1; i++) {
|
||||
const size_t send_seg_id = ((r-i) + n) % n;
|
||||
const size_t recv_seg_id = ((r-i-1) + n) % n;
|
||||
|
||||
T* segment_send = &(buffer[segment_starts[send_seg_id]]);
|
||||
|
||||
MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
|
||||
MPIType<T>(), recv_from, TAG_TENSOR,
|
||||
MPI_COMM_WORLD, &recv_req));
|
||||
|
||||
MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
|
||||
MPIType<T>(), send_to, TAG_TENSOR,
|
||||
MPI_COMM_WORLD));
|
||||
|
||||
T *segment_update = &(buffer[segment_starts[recv_seg_id]]);
|
||||
|
||||
// Wait for recv to complete before reduction
|
||||
MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));
|
||||
|
||||
const size_t recv_seg_size = segment_sizes[recv_seg_id];
|
||||
AccumulateTensorData<Device, T>(
|
||||
segment_update, segment_recv, recv_seg_size);
|
||||
}
|
||||
|
||||
// Now start pipelined ring allgather. At every step, for every rank, we
|
||||
// iterate through segments with wraparound and send and recv from our
|
||||
// neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and
|
||||
// receives segment (r-i).
|
||||
for (size_t i = 0; i < n - 1; ++i) {
|
||||
const size_t send_seg_id = ((r-i+1) + n) % n;
|
||||
const size_t recv_seg_id = ((r-i) + n) % n;
|
||||
|
||||
// Segment to send - at every iteration we send segment (r-i+1)
|
||||
T* segment_send = &(buffer[segment_starts[send_seg_id]]);
|
||||
|
||||
// Segment to recv - at every iteration we receive segment (r-i)
|
||||
T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);
|
||||
|
||||
MPI_REQUIRES_OK(MPI_Sendrecv(segment_send, segment_sizes[send_seg_id],
|
||||
MPIType<T>(), send_to, TAG_TENSOR, segment_recv,
|
||||
segment_sizes[recv_seg_id], MPIType<T>(), recv_from,
|
||||
TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform a ring allgather on a Tensor. Other ranks may allgather with a
|
||||
// tensor which differs in the first dimension only; all other dimensions must
|
||||
// be the same.
|
||||
//
|
||||
// For more information on the ring allgather, read the documentation for the
|
||||
// ring allreduce, which includes a ring allgather.
|
||||
template<typename Device, typename T>
|
||||
Status RingAllgather(OpKernelContext* context, const Tensor* input,
|
||||
const std::vector<size_t>& sizes, Tensor* output) {
|
||||
// Acquire MPI size and rank
|
||||
int n, r;
|
||||
MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
|
||||
MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
|
||||
|
||||
assert(sizes.size() == n);
|
||||
assert(input->dim_size(0) == sizes[r]);
|
||||
|
||||
// Compute number of elements in every "row". We can't compute number of
|
||||
// elements in every chunks, because those chunks are variable length.
|
||||
size_t elements_per_row = 1;
|
||||
for (int i = 1; i < input->shape().dims(); i++) {
|
||||
elements_per_row *= input->dim_size(i);
|
||||
}
|
||||
|
||||
// Copy data from input tensor to correct place in output tensor.
|
||||
std::vector<size_t> segment_starts(n);
|
||||
segment_starts[0] = 0;
|
||||
for (int i = 1; i < n; i++) {
|
||||
segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
|
||||
}
|
||||
size_t offset = segment_starts[r];
|
||||
|
||||
// Copy data to the right offset for this rank.
|
||||
T* buffer = (T*) output->tensor_data().data();
|
||||
CopyTensorData<Device>((void*) (buffer + offset),
|
||||
(void*) input->tensor_data().data(),
|
||||
elements_per_row * sizes[r] * sizeof(T));
|
||||
|
||||
// Receive from your left neighbor with wrap-around
|
||||
const size_t recv_from = ((r - 1) + n) % n;
|
||||
|
||||
// Send to your right neighbor with wrap-around
|
||||
const size_t send_to = (r + 1) % n;
|
||||
|
||||
// Perform a ring allgather. At every step, for every rank, we iterate
|
||||
// through segments with wraparound and send and recv from our neighbors.
|
||||
// At the i'th iteration, rank r, sends segment (r-i) and receives segment
|
||||
// (r-1-i).
|
||||
MPI_Status recv_status;
|
||||
for (size_t i = 0; i < n - 1; ++i) {
|
||||
const size_t send_seg_id = ((r-i) + n) % n;
|
||||
const size_t recv_seg_id = ((r-i-1) + n) % n;
|
||||
|
||||
// Segment to send - at every iteration we send segment (r-i)
|
||||
size_t offset_send = segment_starts[send_seg_id];
|
||||
size_t rows_send = sizes[send_seg_id];
|
||||
T* segment_send = &(buffer[offset_send]);
|
||||
|
||||
// Segment to recv - at every iteration we receive segment (r-1-i)
|
||||
size_t offset_recv = segment_starts[recv_seg_id];
|
||||
size_t rows_recv = sizes[recv_seg_id];
|
||||
T* segment_recv = &(buffer[offset_recv]);
|
||||
|
||||
MPI_REQUIRES_OK(MPI_Sendrecv(
|
||||
segment_send, elements_per_row * rows_send,
|
||||
MPIType<T>(), send_to, TAG_TENSOR, segment_recv,
|
||||
elements_per_row * rows_recv, MPIType<T>(), recv_from,
|
||||
TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_USE_MPI
|
||||
|
||||
#undef TENSORFLOW_CONTRIB_MPI_H_
|
||||
#endif // TENSORFLOW_CONTRIB_MPI_H_
|
24
tensorflow/tools/ci_build/Dockerfile.cpu.mpi
Normal file
24
tensorflow/tools/ci_build/Dockerfile.cpu.mpi
Normal file
@ -0,0 +1,24 @@
|
||||
FROM ubuntu:14.04
|
||||
|
||||
LABEL authors="Andrew Gibiansky <andrew.gibiansky@gmail.com>, Joel Hestness <jthestness@gmail.com>"
|
||||
|
||||
# Copy and run the install scripts.
|
||||
COPY install/*.sh /install/
|
||||
RUN /install/install_bootstrap_deb_packages.sh
|
||||
RUN add-apt-repository -y ppa:openjdk-r/ppa && \
|
||||
add-apt-repository -y ppa:mc3man/trusty-media && \
|
||||
add-apt-repository -y ppa:george-edison55/cmake-3.x
|
||||
RUN /install/install_deb_packages.sh
|
||||
RUN /install/install_pip_packages.sh
|
||||
RUN /install/install_bazel.sh
|
||||
RUN /install/install_proto3.sh
|
||||
RUN /install/install_buildifier.sh
|
||||
RUN /install/install_mpi.sh
|
||||
|
||||
# Set up bazelrc.
|
||||
COPY install/.bazelrc /root/.bazelrc
|
||||
ENV BAZELRC /root/.bazelrc
|
||||
|
||||
# Set up MPI
|
||||
ENV TF_NEED_MPI 1
|
||||
ENV MPI_HOME /usr/lib/openmpi
|
23
tensorflow/tools/ci_build/install/install_mpi.sh
Executable file
23
tensorflow/tools/ci_build/install/install_mpi.sh
Executable file
@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
set +e
|
||||
mpiexec=$(which mpiexec)
|
||||
if [[ -z "$mpiexec_location" ]]; then
|
||||
# Install dependencies from ubuntu deb repository.
|
||||
apt-get update
|
||||
apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
|
||||
fi
|
3
third_party/mpi/.gitignore
vendored
Normal file
3
third_party/mpi/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
*.h
|
||||
*.dylib
|
||||
*.so
|
26
third_party/mpi_collectives/BUILD
vendored
Normal file
26
third_party/mpi_collectives/BUILD
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE.txt"])
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mpi",
|
||||
srcs = select({
|
||||
"//tensorflow:darwin": ["libmpi.dylib"],
|
||||
"//conditions:default": ["libmpi.so"],
|
||||
}),
|
||||
hdrs = ["mpi.h", "mpi_portable_platform.h"],
|
||||
)
|
Loading…
Reference in New Issue
Block a user