diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 7f740f1633c..1d25c50722b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:__subpackages__"]) +load("//third_party/mpi:mpi.bzl", "if_mpi") + py_library( name = "contrib_py", srcs = glob(["**/*.py"]), @@ -84,7 +86,7 @@ py_library( "//tensorflow/contrib/tpu", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", - ], + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]), ) cc_library( diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD new file mode 100644 index 00000000000..11c5d6e776d --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -0,0 +1,80 @@ +# Ops that communicate with other processes via MPI. + +package(default_visibility = [ + "//tensorflow:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_cc", +) + +tf_proto_library_cc( + name = "mpi_message_proto", + srcs = ["mpi_message.proto"], + cc_api_version = 2, + protodeps = ["//tensorflow/core:protos_all"], + visibility = [ + "//tensorflow:__subpackages__", + ], +) + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_custom_op_library( + name = "mpi_collectives.so", + srcs = [ + "mpi_ops.cc", + "ring.cc", + "ring.h", + ], + gpu_srcs = [ + "ring.cu.cc", + "ring.h", + ], + deps = [ + ":mpi_message_proto_cc", + "//third_party/mpi", + ], +) + +tf_py_test( + name = "mpi_ops_test", + srcs = ["mpi_ops_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + ], + data = [ + ":mpi_collectives.so", + ], + tags = ["manual"], +) + +py_library( + name = "mpi_ops_py", + srcs = [ + "__init__.py", + "mpi_ops.py", + ], + data = [ + ":mpi_collectives.so", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/mpi_collectives/README.md b/tensorflow/contrib/mpi_collectives/README.md new file mode 100644 index 00000000000..c5e1a8c37e3 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/README.md @@ -0,0 +1,5 @@ +# MPI TensorFlow integration + +Tensorflow MPI integration allows communicating between different TensorFlow +processes using MPI. This enables training across multiple nodes and GPUs +using high-speed interconnects. diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py new file mode 100644 index 00000000000..b94f7b0a353 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/__init__.py @@ -0,0 +1,273 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=g-short-docstring-punctuation
+"""## Communicating Between Processes with MPI
+
+TensorFlow natively provides inter-device communication through send and
+receive ops and inter-node communication through Distributed TensorFlow, based
+on the same send and receive abstractions. On HPC clusters where Infiniband or
+other high-speed node interconnects are available, these mechanisms can be
+insufficient for synchronous data-parallel training (as opposed to
+asynchronous gradient descent). This module implements a variety of MPI ops
+which can take advantage of hardware-specific MPI libraries for efficient
+communication.
+
+In order to use this module, TensorFlow must be built with an MPI library,
+which can be provided to the `./configure` script at build time. As a user of
+TensorFlow, you will need to build TensorFlow yourself to select the MPI
+library to use; to do so, follow the [instructions for building TensorFlow from
+source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources).
+
+### Utility Ops
+
+In addition to reductions and gathers, this module provides utility operations
+for detecting the running MPI configuration.
+
+Example:
+
+```python
+from tensorflow.contrib import mpi_collectives as mpi
+
+# Use `mpi.Session` instead of `tf.Session`
+with mpi.Session() as session:
+  rank = session.run(mpi.rank())
+  print("My MPI Rank:", rank)
+
+  if rank == 0:
+    print("MPI Size:", session.run(mpi.size()))
+```
+
+@@rank
+@@size
+
+### Ring Allreduce and Allgather
+
+When summing or averaging tensors across many processes, communication can
+easily become a bottleneck. A naive implementation will send all the tensor
+values to the same process, perform the reduction, and then broadcast the
+values back to all other processes, effectively creating a synchronous
+parameter server in one process. The process responsible for performing the
+reduction has to receive and send an amount of data that scales with both the
+number of processes *and* the number of parameters in the model.
+
+Instead of centralizing the reduction in a single primary reducer, we can
+implement a distributed allreduce or allgather. A bandwidth-optimal allreduce
+sends 2(N - 1) values for every value in the input tensor, and can be
+implemented with a ring allreduce [1]. (Intuitively, a linear reduce requires
+at least (N - 1) sends between the different nodes, and a broadcast of the
+result also requires (N - 1) sends, for a total of 2(N - 1); these two steps
+cannot be combined in a clever way to reduce the number of required sends.)
+This module implements bandwidth-optimal ring allreduce and ring allgather
+operations using MPI; by choosing a hardware-appropriate MPI implementation
+(such as OpenMPI with CUDA-IPC support), you can train large models with
+synchronous gradient descent with minimal communication overhead.
+
+In addition to the `allreduce` and `allgather` functions, a convenience
+`DistributedOptimizer` wrapper is provided to simplify using these functions
+for reducing model gradients.
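+
+For intuition about the ring algorithm described above, the two ring phases
+(a scatter-reduce followed by an allgather) can be sketched with plain NumPy.
+This is a toy, single-process simulation of N ranks for illustration only; it
+is not how the MPI ops in this module are implemented or invoked, and the
+function name is purely hypothetical:
+
+```python
+import numpy as np
+
+def ring_allreduce_sketch(rank_chunks):
+  """Toy model: rank_chunks[r][c] is rank r's local copy of chunk c."""
+  n = len(rank_chunks)
+  state = [[np.array(c, dtype=float) for c in chunks]
+           for chunks in rank_chunks]
+  # Phase 1 (scatter-reduce): after N - 1 steps, rank r holds the fully
+  # summed copy of chunk (r + 1) % n.
+  for step in range(n - 1):
+    snapshot = [[c.copy() for c in chunks] for chunks in state]
+    for r in range(n):
+      c = (r - step) % n
+      state[(r + 1) % n][c] += snapshot[r][c]
+  # Phase 2 (allgather): N - 1 more steps circulate the summed chunks until
+  # every rank holds all of them.
+  for step in range(n - 1):
+    snapshot = [[c.copy() for c in chunks] for chunks in state]
+    for r in range(n):
+      c = (r + 1 - step) % n
+      state[(r + 1) % n][c] = snapshot[r][c]
+  return state
+```
+
+Each rank sends one chunk (roughly 1/N of the tensor) per step, for 2(N - 1)
+chunk sends per rank; summed over all ranks, this is the 2(N - 1) values per
+input element quoted above.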
+
+Example:
+
+```python
+import tensorflow as tf
+from tensorflow.contrib import mpi_collectives as mpi
+
+# Construct a simple linear regression model to optimize
+W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32)
+B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32)
+inputs = tf.placeholder(tf.float32, shape=[None, 20], name="Inputs")
+outputs = tf.placeholder(tf.float32, shape=[None, 1], name="Outputs")
+loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs)
+
+# Training using MPI allreduce with DistributedOptimizer
+optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer())
+train = optimizer.minimize(loss)
+
+# Average loss over all ranks, for printing.
+# Do not pass this to an optimizer!
+avg_loss = mpi.allreduce(loss)
+
+# On different ranks, feed different input data.
+with mpi.Session() as session:
+  rank = session.run(mpi.rank())
+  batch_inputs, batch_outputs = construct_batch_for_rank(rank)
+  feed_dict = {inputs: batch_inputs, outputs: batch_outputs}
+  _, l = session.run([train, avg_loss], feed_dict=feed_dict)
+  print("Average Loss:", l)
+```
+
+[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms
+for Clusters of Workstations".
+
+@@Session
+@@DistributedOptimizer
+@@allreduce
+@@allgather
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.mpi_collectives.mpi_ops import size
+from tensorflow.contrib.mpi_collectives.mpi_ops import rank
+from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank
+from tensorflow.contrib.mpi_collectives.mpi_ops import allgather
+from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce
+from tensorflow.contrib.mpi_collectives.mpi_ops import init
+
+
+def allreduce(tensor, average=True):
+  """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices.
+
+  Arguments:
+    tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce.
+            The shape of the input must be identical across all ranks.
+    average: If True, computes the average over all ranks.
+             Otherwise, computes the sum over all ranks.
+
+  This function performs a bandwidth-optimal ring allreduce on the input
+  tensor. If the input is a tf.IndexedSlices, the function instead does an
+  allgather on the values and the indices, effectively doing an allreduce on
+  the represented tensor.
+  """
+  if isinstance(tensor, tf.IndexedSlices):
+    # For IndexedSlices, do two allgathers instead of an allreduce.
+    mpi_size = tf.cast(size(), tensor.values.dtype)
+    values = allgather(tensor.values)
+    indices = allgather(tensor.indices)
+
+    # To make this operation into an average, divide all gathered values by
+    # the MPI size.
+    new_values = tf.div(values, mpi_size) if average else values
+    return tf.IndexedSlices(new_values, indices,
+                            dense_shape=tensor.dense_shape)
+  else:
+    mpi_size = tf.cast(size(), tensor.dtype)
+    summed_tensor = _allreduce(tensor)
+    new_tensor = (tf.div(summed_tensor, mpi_size)
+                  if average else summed_tensor)
+    return new_tensor
+
+
+class DistributedOptimizer(tf.train.Optimizer):
+  """An optimizer that wraps another tf.Optimizer, using an MPI allreduce to
+  average gradient values before applying gradients to model weights."""
+
+  def __init__(self, optimizer, name=None, use_locking=False):
+    """Construct a new DistributedOptimizer, which uses another optimizer
+    under the hood for computing single-process gradient values and
+    applying gradient updates after the gradient values have been averaged
+    across all the MPI ranks.
+
+    Args:
+      optimizer: Optimizer to use for computing gradients and applying updates.
+      name: Optional name prefix for the operations created when applying
+        gradients. Defaults to "Distributed" followed by the provided
+        optimizer type.
+      use_locking: Whether to use locking when updating variables. See
+        Optimizer.__init__ for more info.
+    """
+    if name is None:
+      name = "Distributed{}".format(type(optimizer).__name__)
+
+    self._optimizer = optimizer
+    super(DistributedOptimizer, self).__init__(
+        name=name, use_locking=use_locking)
+
+  def compute_gradients(self, *args, **kwargs):
+    """Compute gradients of all trainable variables.
+
+    See Optimizer.compute_gradients() for more info.
+
+    In DistributedOptimizer, compute_gradients() is overridden to also
+    allreduce the gradients before returning them.
+    """
+    gradients = (super(DistributedOptimizer, self)
+                 .compute_gradients(*args, **kwargs))
+    # Variables without a gradient produce (None, var) pairs; leave those
+    # untouched rather than passing None to allreduce().
+    return [(allreduce(gradient), var) if gradient is not None
+            else (None, var)
+            for gradient, var in gradients]
+
+  def _apply_dense(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._apply_dense(*args, **kwargs)
+
+  def _apply_sparse(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._apply_sparse(*args, **kwargs)
+
+  def _apply_sparse_duplicate_indices(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._apply_sparse_duplicate_indices(*args,
+                                                           **kwargs)
+
+  def _prepare(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._prepare(*args, **kwargs)
+
+  def _create_slots(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._create_slots(*args, **kwargs)
+
+  def _valid_dtypes(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._valid_dtypes(*args, **kwargs)
+
+  def _finish(self, *args, **kwargs):
+    """Calls this same method on the underlying optimizer."""
+    return self._optimizer._finish(*args, **kwargs)
+
+
+class Session(tf.Session):
+  """A class for running TensorFlow operations, with copies of the same graph
+  running distributed across different MPI nodes.
+
+  The primary difference between `tf.Session` and
+  `tf.contrib.mpi_collectives.Session` is that the MPI `Session` ensures that
+  the `Session` options are correct for use with `tf.contrib.mpi_collectives`,
+  and initializes MPI immediately upon the start of the session.
+  """
+
+  def __init__(self, target='', graph=None, config=None):
+    """Creates a new TensorFlow MPI session.
+
+    Unlike a normal `tf.Session`, an MPI Session may only use a single GPU,
+    which must be specified in advance before the session is initialized.
+    In addition, it only uses a single graph evaluation thread, and
+    initializes MPI immediately upon starting.
+
+    If no `graph` argument is specified when constructing the session,
+    the default graph will be launched in the session. If you are
+    using more than one graph (created with `tf.Graph()`) in the same
+    process, you will have to use different sessions for each graph,
+    but each graph can be used in multiple sessions. In this case, it
+    is often clearer to pass the graph to be launched explicitly to
+    the session constructor.
+
+    Args:
+      target: (Optional.) The execution engine to connect to.
+      graph: (Optional.) The `Graph` to be launched (described above).
+      config: (Optional.)
A `ConfigProto` protocol buffer with configuration + options for the session. + """ + super(Session, self).__init__(target, graph, config=config) + + # Initialize MPI on the relevant device. + # TODO: Move this to library load and eliminate mpi.Session() + if graph is None: + graph = tf.get_default_graph() + with graph.as_default(): + self.run(init()) diff --git a/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py b/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py new file mode 100644 index 00000000000..62fd1c281c1 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py @@ -0,0 +1,96 @@ +from __future__ import print_function + +import os +import numpy as np +import tensorflow as tf +import tensorflow.contrib.mpi_collectives as mpi +from tensorflow.python.platform import test + + +average_allgather = False + + +class AllgatherTest(test.TestCase): + def checkAllgather(self, num_ranks, all_gathered, local_gathered): + # Ensure that indices match. + all_gat_ind = np.sort(all_gathered.indices) + loc_gat_ind = np.sort(local_gathered.indices) + assert(len(loc_gat_ind) == len(all_gat_ind)) + for i in range(len(loc_gat_ind)): + assert(loc_gat_ind[i] == all_gat_ind[i]) + + # For each index, verify same values. + local_checked = [] + for i in range(len(local_gathered.indices)): + local_checked.append(False) + for i in range(len(all_gathered.indices)): + all_index = all_gathered.indices[i] + # TODO(jthestness): Make this lookup quicker using sorting. + loc_index = -1 + for j in range(len(local_gathered.indices)): + if local_gathered.indices[j] == all_index and not local_checked[j]: + loc_index = j + local_checked[j] = True + break + assert(loc_index >= 0) + correct_output = local_gathered.values[loc_index][0] + if average_allgather: + correct_output = correct_output / float(num_ranks) + assert(all_gathered.values[i][0] == correct_output) + + + def test_mpi_allgather(self): + # Get MPI rank + my_rank = int(os.environ['PMI_RANK']) + num_ranks = int(os.environ['PMI_SIZE']) + + indices_per_rank = 100 + tensor_width = 10 + + # Create IndexedSlices for each rank, some with overlapping indices. + to_gather_indices = [] + to_gather_values = [] + to_gather = [] + for rank_id in range(num_ranks): + indices = [] + values = [] + my_multiple = rank_id + 1 + current_index = my_multiple + for i in range(indices_per_rank): + indices.append(current_index) + ones_tensor = tf.ones([tensor_width]) + values.append(tf.multiply(ones_tensor, + tf.fill(ones_tensor.get_shape(), + float(current_index)))) + current_index += my_multiple + concat_ind = tf.stack(indices) + concat_vals = tf.stack(values) + to_gather_indices.append(concat_ind) + to_gather_values.append(concat_vals) + to_gather.append(tf.IndexedSlices(concat_vals, concat_ind)) + + # Collect the local IndexedSlices (indices and values) to create + # correct IndexedSlices output. + correct_gather_indices = tf.concat(to_gather_indices, 0) + correct_gather_values = tf.concat(to_gather_values, 0) + correct_gather = tf.IndexedSlices(correct_gather_values, + correct_gather_indices) + + all_gather = mpi.allreduce(to_gather[my_rank], average_allgather) + + # NOTE: This assumes that device IDs are numbered the same as ranks. + gpu_options = tf.GPUOptions(visible_device_list=str(my_rank)) + config = tf.ConfigProto(gpu_options=gpu_options) + + # MPI Session to test allgather. 
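+    # This test is assumed to be launched by an MPI process manager that
+    # sets PMI_RANK and PMI_SIZE (e.g. an MPICH/Hydra-style mpirun); those
+    # environment variables are read at the top of this test method.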
+ with mpi.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + all_gathered, local_gathered = sess.run([all_gather, correct_gather]) + + # Compare all_gathered with local_gathered. + self.checkAllgather(num_ranks, all_gathered, local_gathered) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py new file mode 100644 index 00000000000..9f4f8b3ff98 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py @@ -0,0 +1,136 @@ +from __future__ import print_function + +import os +import numpy as np +import tensorflow as tf +import tensorflow.contrib.mpi_collectives as mpi +from tensorflow.python.platform import test + + +average_allreduce = False +max_wrong_count = -1 + + +class AllreduceTest(test.TestCase): + def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red, + our_correct): + # Find reduced/allreduced indices that are wrong and print all the + # values from output, slices, reduced, allreduced, so we can debug + # which is incorrect: + wrong_count = 0 + red_dims = out_loc_red.shape + assert(len(red_dims) == 2) + for i in range(red_dims[0]): + for j in range(red_dims[1]): + suffix = "" + if out_loc_red[i][j] != my_correct[i][j] or \ + out_all_red[i][j] != our_correct[i][j]: + suffix = "WRONG" + wrong_count += 1 + print("{}\t{}\t{}\t{}\t{}\t{}" + .format(my_rank, i, j, out_loc_red[i][j], + out_all_red[i][j], suffix), flush=True) + if max_wrong_count > 0 and wrong_count >= max_wrong_count: + return + + def test_mpi_allreduce(self): + # Get MPI rank + my_rank = int(os.environ['PMI_RANK']) + num_ranks = int(os.environ['PMI_SIZE']) + + stages = 13 + batch_size = 1331 + hidden_size = batch_size + out_size = batch_size + + # Input placeholder (batch_size x hidden) - init to 1s + inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size), + name="Input") + + # Large matrices (hidden x out_dim) - init random + weights = [] + for i in range(stages): + initer = tf.constant_initializer(pow(2.0, i + 1.0)) + weights.append(tf.get_variable("weights_{}".format(i), + shape=(hidden_size, out_size), + dtype=tf.float32, + initializer=initer)) + + # Calculate output through dependent allreduces + stage_input = inputs + for i in range(stages): + inter_output = tf.add(stage_input, weights[i], + name="add_red_{}".format(i)) + stage_input = mpi.allreduce(inter_output, + average=average_allreduce) + + all_reduced = stage_input + + # Local reduced output for verification + local_input = inputs + for i in range(stages): + inter_output = tf.add(local_input, weights[i], + name="addin_loc_{}".format(i)) + my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)), + dtype=tf.float32, name="loc_redr_{}".format(i)) + for r in range(num_ranks): + my_reducer = tf.add(my_reducer, inter_output, + name="add_loc_{}_{}".format(i, r)) + if average_allreduce: + local_input = tf.div(my_reducer, num_ranks, + name="div_loc_{}".format(i)) + else: + local_input = my_reducer + + local_reduced = local_input + + # NOTE: This assumes that device IDs are numbered the same as ranks + gpu_options = tf.GPUOptions(visible_device_list=str(my_rank)) + config = tf.ConfigProto(gpu_options=gpu_options) + + # MPI Session to test allreduce + with mpi.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + input_feed = np.ones((batch_size, hidden_size), dtype=np.float32) + our_output = input_feed[0][0] + spread_var = 100 + 
input_feed = input_feed + my_rank * spread_var + my_output = input_feed[0][0] + for i in range(stages): + curr_feed = my_output + pow(2.0, i + 1.0) + my_output = curr_feed * num_ranks + 1 + curr_our_feed = our_output + pow(2.0, i + 1.0) + if i == 0: + sum_ranks = num_ranks * (num_ranks - 1) / 2 + our_output = curr_our_feed * num_ranks + \ + spread_var * sum_ranks + else: + our_output = curr_our_feed * num_ranks + + print("rank {}: My output is {}".format(my_rank, my_output)) + my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32) + my_correct = my_correct + my_output + print("rank {}: Our output is {}".format(my_rank, our_output)) + our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32) + our_correct = our_correct + our_output + + for i in range(1000): + if i % 100 == 0: + print("{}: iter {}".format(my_rank, i), flush=True) + feed_dict = {inputs: input_feed} + out_all_red, out_loc_red \ + = sess.run([all_reduced, local_reduced], + feed_dict=feed_dict) + + if not np.allclose(out_loc_red, my_correct) or \ + not np.allclose(out_all_red, our_correct): + print("Test incorrect on iter {}".format(i), flush=True) + self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red, + our_correct) + assert(np.allclose(out_loc_red, my_correct) and + np.allclose(out_all_red, our_correct)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/mpi_collectives/mpi_message.proto b/tensorflow/contrib/mpi_collectives/mpi_message.proto new file mode 100644 index 00000000000..ce64ce5ced9 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_message.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package tensorflow.contrib.mpi; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// An MPIRequest is a message sent from a rank greater than zero to the +// coordinator (rank zero), informing the coordinator of an operation that +// the rank wants to do and the tensor that it wants to apply the operation to. +message MPIRequest { + enum RequestType { + ALLREDUCE = 0; + ALLGATHER = 1; + } + + // The request rank is necessary to create a consistent ordering of results, + // for example in the allgather where the order of outputs should be sorted + // by rank. + int32 request_rank = 1; + RequestType request_type = 2; + DataType tensor_type = 3; + string tensor_name = 4; + TensorShapeProto tensor_shape = 5; +}; + +// An MPIResponse is a message sent from the coordinator (rank zero) to a rank +// greater than zero, informing the rank of an operation should be performed +// now. If the operation requested would result in an error (for example, due +// to a type or shape mismatch), then the MPIResponse can contain an error and +// an error message instead. Finally, an MPIResponse can be a DONE message (if +// there are no more tensors to reduce on this tick of the background loop) or +// SHUTDOWN if all MPI processes should shut down. +message MPIResponse { + enum ResponseType { + ALLREDUCE = 0; + ALLGATHER = 1; + ERROR = 2; + DONE = 3; + SHUTDOWN = 4; + } + + // Empty if the type is DONE or SHUTDOWN. + ResponseType response_type = 1; + string tensor_name = 2; + + // Empty unless response_type is ERROR. + string error_message = 3; +}; diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/mpi_ops.cc new file mode 100644 index 00000000000..33f1ec03176 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_ops.cc @@ -0,0 +1,1241 @@ +// Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================ + +#ifdef TENSORFLOW_USE_MPI + +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/mutex.h" + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#include +#include "tensorflow/stream_executor/stream.h" +#endif + +#include "tensorflow/stream_executor/lib/statusor.h" + + +#define OMPI_SKIP_MPICXX +#include "third_party/mpi/mpi.h" +#include "tensorflow/contrib/mpi_collectives/ring.h" +#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" + +/* + * MPI Allreduce and Allgather Ops for TensorFlow. + * + * TensorFlow natively provides inter-device communication through send and + * receive ops and inter-node communication through Distributed TensorFlow, + * based on the same send and receive abstractions. These end up being + * insufficient for synchronous data-parallel training on HPC clusters where + * Infiniband or other high-speed interconnects are available. This module + * implements MPI ops for allgather and allreduce, which do bandwidth-optimal + * gathers and reductions and can take advantage of hardware-optimized + * communication libraries through the MPI implementation. + * + * The primary logic of the allreduce and allgather are in RingAllgather() and + * RingAllreduce(). The background thread which facilitates MPI operations is + * run in BackgroundThreadLoop(). The provided MPI ops are: + * – MPIInit: + * Initialize MPI on a given device (CPU or GPU). + * Should only be run on a single device in every process. + * – MPISize: + * Get the number of MPI processes in the global communicator. + * – MPIRank: + * Get the rank of the current MPI process in the global communicator. + * – MPILocalRank: + * Get the local rank of the current MPI process within its node. + * – MPIAllreduce: + * Perform an allreduce on a Tensor, returning the sum + * across all MPI processes in the global communicator. + * – MPIAllgather: + * Perform an allgather on a Tensor, returning the concatenation of + * the tensor on the first dimension across all MPI processes in the + * global communicator. + * + */ + +template +using StatusOr = perftools::gputools::port::StatusOr; + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +namespace tensorflow { +namespace contrib { +namespace mpi { + +// Make sure template specializations are generated in the ring.cu.cc and the +// ring.cc file, not in this file. 
+extern template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +extern template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +extern template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +extern template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +extern template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +extern template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +extern template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); + +namespace { + +// Return true if the templated type is GPUDevice, otherwise false. +template bool IsGPUDevice(); +template<> bool IsGPUDevice() { return true; }; +template<> bool IsGPUDevice() { return false; }; + +// A callback to call after the MPI communication completes. Since the +// allreduce and allgather ops are asynchronous, this callback is what resumes +// computation after the reduction is completed. +typedef std::function)> CommunicationDoneCallback; + +struct CollectiveOpRecord { + // The rank performing this piece of the op + int rank; + + // The name of the op/tensor to be reduced + std::string name; + + // The op's kernel context + OpKernelContext *context; + + // Data type of the op + DataType dtype; + + // The input tensor + const Tensor *in_t; + + // Allgather: Vector of per-rank first-dimension sizes + std::vector sizes_vec; + + // The temp tensor for intermediate results + Tensor temp_t; + + // The output tensor + Tensor *out_t; + + // Whether to run this op on the gpu + bool on_gpu; + + // The callback to call after the op has completed + CommunicationDoneCallback callback; +}; + +// Table storing Tensors to be reduced, keyed by unique name. +// This table contains everything necessary to do the reduction +typedef std::unordered_map TensorTable; + +// Table for storing Tensor metadata on rank zero. This is used for error +// checking and size calculations, as well as determining when a reduction is +// ready to be done (when all nodes are ready to do it). +typedef std::unordered_map > MessageTable; + +// The global state required for the MPI ops. +// +// MPI is a library that stores a lot of global per-program state and often +// requires running on a single thread. As a result, we have to have a single +// background thread responsible for all MPI operations, and communicate with +// that background thread through global state. +struct MPIGlobalState { + // An atomic boolean which is set to true when MPI is initialized. + // This ensures that MPI_Init is never called twice. + std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT; + + // Condition variable to wait for initialization + condition_variable cv; + + // Whether MPI_Init has been completed on the background thread. + bool initialization_done = false; + + // Whether MPI_Init succeeded on the background thread. 
+ Status init_status; + + // A mutex that needs to be used whenever MPI operations touch + // shared structures. + mutex mu; + + // Tensors waiting to be allreduced or allgathered. + TensorTable tensor_table; + + // Queue of MPI requests waiting to be sent to the coordinator node. + std::queue message_queue; + + // Background thread running MPI communication. + std::thread background_thread; + + // Whether the background thread should shutdown. + bool shut_down = false; + + // Only exists on the coordinator node (rank zero). Maintains a count of + // how many nodes are ready to allreduce every tensor (keyed by tensor + // name). + std::unique_ptr message_table; + + // The MPI rank, local rank, and size. + int rank = 0; + int local_rank = 0; + int size = 1; + + // The device that MPI was initialized on. (-1 for no GPU) + int device = -1; + + // The CUDA stream used for data transfers and within-allreduce operations. + // A naive implementation would use the TensorFlow StreamExecutor CUDA + // stream. However, the allreduce and allgather require doing memory copies + // and kernel executions (for accumulation of values on the GPU). However, + // the subsequent operations must wait for those operations to complete, + // otherwise MPI (which uses its own stream internally) will begin the data + // transfers before the CUDA calls are complete. In order to wait for those + // CUDA operations, if we were using the TensorFlow stream, we would have + // to synchronize that stream; however, other TensorFlow threads may be + // submitting more work to that stream, so synchronizing on it can cause + // the allreduce to be delayed, waiting for compute totally unrelated to it + // in other parts of the graph. Overlaying memory transfers and compute + // during backpropagation is crucial for good performance, so we cannot use + // the TensorFlow stream, and must use our own stream. +#if GOOGLE_CUDA + cudaStream_t stream; + std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT; +#endif + + ~MPIGlobalState() { + // Make sure that the destructor of the background thread is safe to + // call. If a thread is still joinable (not detached or complete) its + // destructor cannot be called. + if (background_thread.joinable()) { + shut_down = true; + background_thread.join(); + } + } +}; + +// All the MPI state that must be stored globally per-process. +static MPIGlobalState mpi_global; + +// For clarify in argument lists. +#define RANK_ZERO 0 + +// A tag used for all coordinator messaging. +#define TAG_NOTIFY 1 + +// Store the MPIRequest for a name, and return whether the total count of +// MPIRequests for that tensor is now equal to the MPI size (and thus we are +// ready to reduce the tensor). +bool IncrementTensorCount( + std::unique_ptr& message_table, + MPIRequest msg, int mpi_size) { + auto name = msg.tensor_name(); + auto table_iter = message_table->find(name); + if (table_iter == message_table->end()) { + message_table->emplace(name, std::vector({msg})); + table_iter = message_table->find(name); + } else { + table_iter->second.push_back(msg); + } + + int count = table_iter->second.size(); + return count == mpi_size; +} + +// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse +// instructing all ranks to start the reduction to all ranks. The MPIResponse +// also contains error messages in case the submitted MPIRequests were not +// valid (for example, contained mismatched shapes or types). +// +// Constructing the MPIResponse, thus, requires a whole lot of error checking. 
+MPIResponse ConstructMPIResponse(std::unique_ptr& message_table, + std::string name) { + bool error = false; + auto it = message_table->find(name); + assert(it != message_table->end()); + + std::vector requests = it->second; + assert(requests.size() > 0); + + std::ostringstream error_message_stream; + + // Check that all data types being reduced or gathered are identical + auto data_type = requests[0].tensor_type(); + for (unsigned int i = 1; i < requests.size(); i++) { + auto request_type = requests[i].tensor_type(); + if (data_type != request_type) { + error = true; + error_message_stream + << "Mismatched data types: One rank had type " + << DataType_Name(data_type) + << ", but another rank had type " + << DataType_Name(request_type) + << "."; + break; + } + } + + // Check that all requested operations are the same + auto message_type = requests[0].request_type(); + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + auto request_type = requests[i].request_type(); + if (message_type != request_type) { + error = true; + error_message_stream + << "Mismatched MPI operations: One rank did an " + << message_type + << ", but another rank did an " + << request_type + << "."; + break; + } + } + + // If we are doing an allreduce, check that all tensor shapes + // are identical + if (message_type == MPIRequest::ALLREDUCE) { + TensorShape tensor_shape = requests[0].tensor_shape(); + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + TensorShape request_shape = requests[i].tensor_shape(); + if (tensor_shape != request_shape) { + error = true; + error_message_stream + << "Mismatched allreduce tensor shapes: " + << "One rank reduced a tensor of shape " + << tensor_shape.DebugString() + << ", but another rank sent a tensor of shape " + << request_shape.DebugString() + << "."; + break; + } + } + } + + // If we are doing an allgather, make sure all but the first dimension are + // the same. The first dimension may be different and the output tensor is + // the sum of the first dimension. Collect the sizes by rank. 
+ if (message_type == MPIRequest::ALLGATHER) { + TensorShape tensor_shape = requests[0].tensor_shape(); + + if (tensor_shape.dims() == 0) { + error = true; + error_message_stream + << "Rank zero tried to gather a rank-zero tensor."; + } + + for (unsigned int i = 1; i < requests.size(); i++) { + if (error) { + break; + } + + TensorShape request_shape = requests[i].tensor_shape(); + if (tensor_shape.dims() != request_shape.dims()) { + error = true; + error_message_stream + << "Mismatched allgather tensor shapes: " + << "One rank gathered a tensor of rank " + << tensor_shape.dims() + << ", but another rank sent a tensor of rank " + << request_shape.dims() + << "."; + break; + } + + for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) { + if (tensor_shape.dim_size(dim) != + request_shape.dim_size(dim)) { + error = true; + error_message_stream + << "Mismatched allgather tensor shapes: " + << "One rank gathered a tensor with dimension " + << dim << " equal to " << tensor_shape.dim_size(dim) + << ", but another rank sent a tensor with dimension " + << dim << " equal to " << request_shape.dim_size(dim) + << "."; + break; + } + } + } + } + + MPIResponse response; + response.set_tensor_name(name); + if (error) { + std::string error_message = error_message_stream.str(); + response.set_response_type(MPIResponse::ERROR); + response.set_error_message(error_message); + } else { + auto response_type = MPIResponse::ERROR; + if (message_type == MPIRequest::ALLREDUCE) { + response_type = MPIResponse::ALLREDUCE; + } else { + response_type = MPIResponse::ALLGATHER; + } + response.set_response_type(response_type); + } + + // Clear all queued up requests for this name. They are now taken care of + // by the constructed MPI response. + message_table->erase(it); + + return response; +} + +// Process an MPIResponse by doing a reduction, a gather, or raising an error. +void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) { + OpKernelContext* context; + const Tensor *input_tensor; + std::vector sizes_vec; + Tensor temp_tensor; + Tensor *output_tensor; + CommunicationDoneCallback callback; + bool on_gpu; + { + // Lock on the tensor table. + mutex_lock guard(mpi_global.mu); + + // We should never fail at finding this key in the tensor table. + auto name = response.tensor_name(); + auto iter = tensor_table.find(name); + assert(iter != tensor_table.end()); + + assert(response.response_type() == MPIResponse::ALLREDUCE || + response.response_type() == MPIResponse::ALLGATHER || + response.response_type() == MPIResponse::ERROR); + + CollectiveOpRecord record = iter->second; + context = record.context; + input_tensor = record.in_t; + sizes_vec = record.sizes_vec; + temp_tensor = record.temp_t; + output_tensor = record.out_t; + on_gpu = record.on_gpu; + callback = record.callback; + + // Clear the tensor table of this tensor and its callbacks; the rest of + // this function takes care of it. + tensor_table.erase(iter); + } + + // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't + // link to non-existent symbols. +#if GOOGLE_CUDA +#define GPU_DEVICE_IF_CUDA GPUDevice +#else +#define GPU_DEVICE_IF_CUDA CPUDevice +#endif + + Status status; + auto dtype = input_tensor->dtype(); + if (response.response_type() == MPIResponse::ALLGATHER) { + if (dtype == DT_FLOAT) { + status = on_gpu ? RingAllgather(context, input_tensor, sizes_vec, output_tensor) + : RingAllgather(context, input_tensor, sizes_vec, output_tensor); + } else if (dtype == DT_INT32) { + status = on_gpu ? 
RingAllgather(context, input_tensor, sizes_vec, output_tensor) + : RingAllgather(context, input_tensor, sizes_vec, output_tensor); + } else if (dtype == DT_INT64) { + status = on_gpu ? RingAllgather(context, input_tensor, sizes_vec, output_tensor) + : RingAllgather(context, input_tensor, sizes_vec, output_tensor); + } else { + status = errors::Unknown("Invalid tensor type for MPI allgather."); + } + } else if (response.response_type() == MPIResponse::ALLREDUCE) { + if (dtype == DT_FLOAT) { + status = on_gpu ? RingAllreduce(context, input_tensor, &temp_tensor, output_tensor) + : RingAllreduce(context, input_tensor, &temp_tensor, output_tensor); + } else if (dtype == DT_INT32) { + status = on_gpu ? RingAllreduce(context, input_tensor, &temp_tensor, output_tensor) + : RingAllreduce(context, input_tensor, &temp_tensor, output_tensor); + } else if (dtype == DT_INT64) { + status = on_gpu ? RingAllreduce(context, input_tensor, &temp_tensor, output_tensor) + : RingAllreduce(context, input_tensor, &temp_tensor, output_tensor); + } else { + status = errors::Unknown("Invalid tensor type for MPI allreduce."); + } + } else if (response.response_type() == MPIResponse::ERROR) { + status = errors::FailedPrecondition(response.error_message()); + } + + if (status.ok()) { + callback(StatusOr(*output_tensor)); + } else { + callback(StatusOr(status)); + } +} + +// The MPI background thread loop coordinates all the MPI processes and the +// tensor reductions. The design of the communicator mechanism is limited by a +// few considerations: +// +// 1. Some MPI implementations require all MPI calls to happen from a +// single thread. Since TensorFlow may use several threads for graph +// processing, this means we must have our own dedicated thread for +// dealing with MPI. +// 2. We want to gracefully handle errors, when MPI processes do not +// properly agree upon what should happen (such as mismatched types or +// shapes). To do so requires the MPI processes to know about the shapes +// and types of the relevant tensors on the other processes. +// 3. The MPI reductions and gathers should be able to happen in parallel +// with other ongoing operations. Since MPI uses an internal +// (inaccessible) GPU stream separate from the TF GPUDevice streams, we +// cannot explicitly synchronize memcpys or kernels with it. As a result, +// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper +// ordering of memcpys and kernels with respect to TF streams. +// 4. NOTE: We cannot guarantee that all the MPI processes reduce their +// tensors in the same order. Thus, there must be a way to ensure the +// reduction memcpys and kernels occur for correct tensors across all +// ranks at the same time. We choose to use a coordinator (rank ID 0) to +// gather and trigger the reduction operations that are ready to execute. +// +// The coordinator currently follows a master-worker paradigm. Rank zero acts +// as the master (the "coordinator"), whereas all other ranks are simply +// workers. Each rank runs its own background thread which progresses in ticks. +// In each tick, the following actions happen: +// +// a) The workers send any available MPIRequests to the coordinator. These +// MPIRequests indicate what the worker would like to do (i.e. which +// tensor they would like to gather or reduce, as well as their shape and +// type). They repeat this for every tensor that they would like to +// operate on after that tensor's collective op has executed ComputeAsync. 
+// +// b) The workers send an empty "DONE" message to the coordinator to +// indicate that there are no more tensors they wish to operate on. +// +// c) The coordinator receives the MPIRequests from the workers, as well +// as from its own TensorFlow ops, and stores them in a request table. The +// coordinator continues to receive MPIRequest messages until it has +// received MPI_SIZE number of empty "DONE" messages. +// +// d) The coordinator finds all tensors that are ready to be reduced, +// gathered, or all operations that result in an error. For each of those, +// it sends an MPIResponse to all the workers. When no more MPIResponses +// are available, it sends a "DONE" response to the workers. If the +// process is being shutdown, it instead sends a "SHUTDOWN" response. +// +// e) The workers listen for MPIResponse messages, processing each one by +// doing the required reduce or gather, until they receive a "DONE" +// response from the coordinator. At that point, the tick ends. +// If instead of "DONE" they receive "SHUTDOWN", they exit their +// background loop. +// TODO: Use the global mpi_global state variable instead of a local one +void BackgroundThreadLoop() { +#if GOOGLE_CUDA + // Set the device, so that this thread uses the same GPU context as the + // calling thread. + // TODO: Ensure that this is operating correctly. The background thread + // needs to be able to control all GPUs that the rank has access to, and + // might be more than 1 GPU. Tensors could be resident in any of the + // GPUs, so the background thread's accumulate and copy kernels might need + // to correctly set the device and it might be necessary for the background + // thread to manage multiple streams. + cudaSetDevice(mpi_global.device); + cudaStreamCreate(&mpi_global.stream); +#endif + + // Initialize MPI. This must happen on the background thread, since not all + // MPI implementations support being called from multiple threads. + auto init_result = MPI_Init(NULL, NULL); + if (init_result != MPI_SUCCESS) { + mpi_global.init_status = + errors::Unknown("Could not initialize MPI; MPI_Init() failed."); + mpi_global.initialization_done = true; + mpi_global.cv.notify_all(); + return; + } else { + mpi_global.init_status = Status::OK(); + } + + // Get MPI rank to determine if we are rank zero. + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + bool is_coordinator = rank == 0; + + // Get MPI size to determine how many tensors to wait for before reducing. + int size; + MPI_Comm_size(MPI_COMM_WORLD, &size); + + // Determine local rank by querying the local communicator. + MPI_Comm local_comm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, + MPI_INFO_NULL, &local_comm); + int local_rank; + MPI_Comm_rank(local_comm, &local_rank); + + mpi_global.rank = rank; + mpi_global.local_rank = local_rank; + mpi_global.size = size; + mpi_global.initialization_done = true; + + // Notify calling thread that initialization is complete + mpi_global.cv.notify_all(); + + // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD! + // Initialize the tensor count table. No tensors are available yet. + if (is_coordinator) { + mpi_global.message_table = + std::unique_ptr(new MessageTable()); + } + + // The coordinator sends a SHUTDOWN message to trigger shutdown. + bool should_shut_down = false; + do { + // TODO: Eliminate the need for thread sleep by making all activity + // depend on other activity (e.g. condition or MPI waits). 
+ std::this_thread::sleep_for (std::chrono::milliseconds(1)); + + // Copy the data structures from global state under this lock. + // However, don't keep the lock for the rest of the loop, so that + // enqueued stream callbacks can continue. + std::queue message_queue; + { + mutex_lock guard(mpi_global.mu); + while (!mpi_global.message_queue.empty()) { + MPIRequest message = mpi_global.message_queue.front(); + mpi_global.message_queue.pop(); + message_queue.push(message); + } + } + + // Collect all tensors that are ready to be reduced. Record them in the + // tensor count table (rank zero) or send them to rank zero to be + // recorded (everyone else). + std::vector ready_to_reduce; + while (!message_queue.empty()) { + // Pop the first available message message + MPIRequest message = message_queue.front(); + message_queue.pop(); + + if (is_coordinator) { + bool reduce = IncrementTensorCount(mpi_global.message_table, + message, size); + if (reduce) { + ready_to_reduce.push_back(message.tensor_name()); + } + } else { + std::string encoded_message; + message.SerializeToString(&encoded_message); + MPI_Send(encoded_message.c_str(), encoded_message.length() + 1, + MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); + } + } + + // Rank zero has put all its own tensors in the tensor count table. + // Now, it should count all the tensors that are coming from other + // ranks at this tick. It should keep getting tensors until it gets a + // DONE message from all the other ranks. + if (is_coordinator) { + // Count of DONE messages. Keep receiving messages until the number + // of messages is equal to the number of processes. Initialize to + // one since the coordinator is effectively done. + int completed_ranks = 1; + while (completed_ranks != size) { + MPI_Status status; + MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status); + + // Find number of characters in message (including zero byte). + int source_rank = status.MPI_SOURCE; + int msg_length; + MPI_Get_count(&status, MPI_BYTE, &msg_length); + + // If the length is zero, this is a DONE message. + if (msg_length == 0) { + completed_ranks++; + MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, + MPI_COMM_WORLD, &status); + continue; + } + + // Get tensor name from MPI into an std::string. + char* buffer = new char[msg_length]; + MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, + TAG_NOTIFY, MPI_COMM_WORLD, &status); + std::string received_data(buffer); + delete[] buffer; + + MPIRequest received_message; + received_message.ParseFromString(received_data); + auto received_name = received_message.tensor_name(); + + bool reduce = IncrementTensorCount( + mpi_global.message_table, received_message, size); + if (reduce) { + ready_to_reduce.push_back(received_name); + } + } + + // At this point, rank zero should have a fully updated tensor + // count table and should know all the tensors that need to be + // reduced or gathered, and everyone else should have sent all + // their information to rank zero. We can now do reductions and + // gathers; rank zero will choose which ones and in what order, + // and will notify the other ranks before doing each reduction. 
+ for (int i = 0; i < ready_to_reduce.size(); i++) { + // Notify all nodes which tensor we'd like to reduce now + auto name = ready_to_reduce[i]; + MPIResponse response = ConstructMPIResponse( + mpi_global.message_table, name); + + std::string encoded_response; + response.SerializeToString(&encoded_response); + for (int r = 1; r < size; r++) { + MPI_Send(encoded_response.c_str(), + encoded_response.length() + 1, + MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); + } + + // Perform the reduction. All nodes should end up performing + // the same reduction. + PerformCollectiveOp(mpi_global.tensor_table, response); + } + + // Notify all nodes that we are done with the reductions for this + // tick. + MPIResponse done_response; + should_shut_down = mpi_global.shut_down; + done_response.set_response_type(mpi_global.shut_down ? + MPIResponse::SHUTDOWN : MPIResponse::DONE); + std::string encoded_response; + done_response.SerializeToString(&encoded_response); + for (int r = 1; r < size; r++) { + MPI_Send(encoded_response.c_str(), + encoded_response.length() + 1, + MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); + } + } else { + // Notify the coordinator that this node is done sending messages. + // A DONE message is encoded as a zero-length message. + MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); + + // Receive names for tensors to reduce from rank zero. Once we + // receive a empty DONE message, stop waiting for more names. + while (true) { + MPI_Status status; + MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status); + + // Find number of characters in message (including zero byte). + int msg_length; + MPI_Get_count(&status, MPI_BYTE, &msg_length); + + // Get tensor name from MPI into an std::string. + char* buffer = new char[msg_length]; + MPI_Recv(buffer, msg_length, MPI_BYTE, 0, + TAG_NOTIFY, MPI_COMM_WORLD, &status); + std::string received_message(buffer); + delete[] buffer; + + MPIResponse response; + response.ParseFromString(received_message); + if (response.response_type() == MPIResponse::DONE) { + // No more messages this tick + break; + } else if (response.response_type() == MPIResponse::SHUTDOWN) { + // No more messages this tick, and the background thread + // should shut down + should_shut_down = true; + break; + } else { + // Process the current message + PerformCollectiveOp(mpi_global.tensor_table, response); + } + } + } + } while (!should_shut_down); + + MPI_Finalize(); +} + +// Initialize MPI and start the MPI background thread. Ensure that this is +// only done once no matter how many times this function is called. +Status InitializeMPIOnce(bool gpu) { + // Ensure MPI is only initialized once. + if (mpi_global.initialized_flag.test_and_set()) + return mpi_global.init_status; + + mpi_global.device = -1; +#if GOOGLE_CUDA + if (gpu) { + cudaGetDevice(&mpi_global.device); + } +#endif + + // Start the MPI background thread, which assumes MPI is initialized + // TODO: Change this to a Tensorflow thread + mpi_global.background_thread = std::thread(BackgroundThreadLoop); + + // Wait to ensure that the background thread has finished initializing MPI + mutex_lock guard(mpi_global.mu); + mpi_global.cv.wait(guard); + if (!mpi_global.initialization_done) { + mpi_global.init_status = + errors::Unknown("Failed to wait for MPI initialization."); + } + + return mpi_global.init_status; +} + +// Check that MPI is initialized. 
+Status IsMPIInitialized() { + if (!mpi_global.initialization_done) { + return errors::FailedPrecondition( + "MPI has not been initialized; use tf.contrib.mpi.Session."); + } + return Status::OK(); +} + +// This function (called from the callback set up in MPIAll*Op::ComputeAsync) +// only adds the op's record into the local op queue (to track the op's +// progress), and sends a message to the coordinator indicating that this rank +// is ready to begin. The MPI background thread will handle the MPI message. +void EnqueueTensorCollective(CollectiveOpRecord record, + MPIRequest::RequestType rtype) { + const Tensor *input_tensor = record.in_t; + MPIRequest message; + message.set_request_rank(record.rank); + message.set_tensor_name(record.name); + message.set_tensor_type(record.dtype); + message.set_request_type(rtype); + input_tensor->shape().AsProto(message.mutable_tensor_shape()); + + mutex_lock guard(mpi_global.mu); + mpi_global.tensor_table.emplace(record.name, record); + mpi_global.message_queue.push(message); +} + +} + +#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI() { + return mpi_global.stream; +} +#endif + +// Op to initialize MPI in the current process. The settings used in the +// configuration are the same that must be used for all future MPI ops. +template +class MPIInitOp : public OpKernel { + public: + explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) { + } + + + void Compute(OpKernelContext* context) override { + bool on_gpu = IsGPUDevice(); + OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU), + MPIInitOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU), + MPIInitOp); +#endif + +REGISTER_OP("MPIInit") + .Doc(R"doc( +Initialize MPI for the current process. + +If this is run on a GPU, then that GPU must be used for all future MPI +operations. If it is run on CPU, then all future MPI operations must also +run on CPU. +)doc"); + +// Op to get the current MPI Size. +template +class MPISizeOp : public OpKernel { + public: + explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) { } + + + void Compute(OpKernelContext* context) override { + OP_REQUIRES_OK(context, IsMPIInitialized()); + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, + TensorShape({}), + &output)); + + auto flat = output->flat(); + flat(0) = mpi_global.size; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU), + MPISizeOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"), + MPISizeOp); +#endif + +REGISTER_OP("MPISize") + .Output("size: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the number of running MPI processes. + +More precisely, returns the number of MPI processes in the group associated +with the MPI_COMM_WORLD communicator. + +size: Size of the MPI group. +)doc"); + +// Op to get the current MPI Rank. 
+template +class MPIRankOp : public OpKernel { + public: + explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) { } + + void Compute(OpKernelContext* context) override { + OP_REQUIRES_OK(context, IsMPIInitialized()); + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, + TensorShape({}), + &output)); + + auto flat = output->flat(); + flat(0) = mpi_global.rank; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU), + MPIRankOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"), + MPIRankOp); +#endif + +REGISTER_OP("MPIRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the MPI group. + +More precisely, returns the rank of the calling process in the MPI_COMM_WORLD +communicator. + +rank: Rank of the calling process. +)doc"); + + +// Op to get the current local MPI Rank. +template +class MPILocalRankOp : public OpKernel { + public: + explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES_OK(context, IsMPIInitialized()); + + // Write integer to output tensor + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, + TensorShape({}), + &output)); + + auto flat = output->flat(); + flat(0) = mpi_global.local_rank; + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU), + MPILocalRankOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPILocalRank") + .Device(DEVICE_GPU) + .HostMemory("rank"), + MPILocalRankOp); +#endif + +REGISTER_OP("MPILocalRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the node it is on. + +More precisely, returns the rank of the calling process in communicator that +only spans the MPI processes running on that node. + +rank: Rank of the calling process on the node it is on. +)doc"); + +template +class MPIAllreduceOp : public AsyncOpKernel { + public: + explicit MPIAllreduceOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + // Although this op is handled asynchronously, the ComputeAsync call is + // very inexpensive. It only sets up a CollectiveOpRecord and places it + // in the table for the background thread to handle. Thus, we do not need + // a TF pool thread to perform the op. 
+ bool IsExpensive() override { return false; } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done); + const Tensor *input_tensor = &context->input(0); + Tensor *output_tensor; + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, + input_tensor->shape(), + &output_tensor), + done); + + // Record allocated on stack so op can fail without memory leak + CollectiveOpRecord record; + record.name = name(); + record.context = context; + record.in_t = input_tensor; + record.out_t = output_tensor; + record.on_gpu = IsGPUDevice(); + record.dtype = input_tensor->dtype(); + + const size_t temp_size = + (input_tensor->NumElements() + mpi_global.size - 1) + / mpi_global.size; + TensorShape temp_shape; + temp_shape.AddDim(temp_size); + OP_REQUIRES_OK_ASYNC(context, context->allocate_temp( + input_tensor->dtype(), + temp_shape, &record.temp_t), + done); + + + auto allreduce_done_callback = [done, context](StatusOr status) { + context->SetStatus(status.status()); + done(); + }; + record.callback = allreduce_done_callback; + + auto allreduce_launch_callback = [record] { + EnqueueTensorCollective(record, MPIRequest::ALLREDUCE); + }; + + // If we are on a CPU, our device context will be null and we can't + // get a stream to enqueue this on. On a CPU this op is called when the + // data is already available, so we can just immediately do the + // allreduce; we don't have to wait for the data to get populated. +#if GOOGLE_CUDA + auto device_context = context->op_device_context(); + if (device_context == nullptr) { + allreduce_launch_callback(); + } else { + auto stream = device_context->stream(); + stream->ThenDoHostCallback(allreduce_launch_callback); + } +#else + allreduce_launch_callback(); +#endif + } +}; + +REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU), + MPIAllreduceOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU), + MPIAllreduceOp); +#endif + +REGISTER_OP("MPIAllreduce") + .Attr("T: {int32, int64, float32}") + .Input("tensor: T") + .Output("sum: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allreduce on a tensor. All other processes that do a reduction +on a tensor with the same name must have the same dimension for that tensor. +Tensors are reduced with other tensors that have the same node name for the +allreduce. + +Arguments + tensor: A tensor to reduce. + +Output + sum: A tensor with the same shape as `tensor`, summed across all + MPI processes. +)doc"); + +template +class MPIAllgatherOp : public AsyncOpKernel { + public: + explicit MPIAllgatherOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + // Although this op is handled asynchronously, the ComputeAsync call is + // very inexpensive. It only sets up a CollectiveOpRecord and places it + // in the table for the background thread to handle. Thus, we do not need + // a TF pool thread to perform the op. 
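+  // The `sizes` input lists the first-dimension length contributed by every
+  // rank. The Python `allgather` wrapper builds it with a preliminary
+  // MPIAllgather of each rank's own first dimension (see mpi_ops.py), so the
+  // full output shape is known before the main gather is enqueued.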
+ bool IsExpensive() override { return false; } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done); + const Tensor *input_tensor = &context->input(0); + const Tensor *sizing_tensor = &context->input(1); + + // Record allocated on stack so op can fail without memory leak + CollectiveOpRecord record; + record.name = name(); + record.context = context; + record.in_t = input_tensor; + record.on_gpu = IsGPUDevice(); + + // Construct the output size from the sizing tensor + size_t output_first_dim = 0; + if (sizing_tensor->shape().dims() == 0) { + // 0-dim sizing_tensor implies that the op is just gathering + // a single element from each rank + output_first_dim = mpi_global.size; + for (int i = 0; i < mpi_global.size; i++) { + record.sizes_vec.push_back(1); + } + } else { + // Collect the total output tensor sizing from the sizing tensor + // NOTE: The sizing tensor is forced to be placed on the CPU by + // declaring the input as HostMemory, so it is valid to read it here. + const int64 *sizing_array = + (const int64*)sizing_tensor->tensor_data().data(); + for (int i = 0; i < mpi_global.size; i++) { + record.sizes_vec.push_back(sizing_array[i]); + output_first_dim += sizing_array[i]; + } + } + + TensorShape output_shape; + output_shape.AddDim(output_first_dim); + for (int i = 1; i < input_tensor->shape().dims(); i++) { + output_shape.AddDim(input_tensor->shape().dim_size(i)); + } + + Tensor *output_tensor; + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, + output_shape, + &output_tensor), + done); + + record.out_t = output_tensor; + record.dtype = input_tensor->dtype(); + + auto allgather_done_callback = [done, context](StatusOr status) { + context->SetStatus(status.status()); + done(); + }; + record.callback = allgather_done_callback; + + auto allgather_launch_callback = [record] { + EnqueueTensorCollective(record, MPIRequest::ALLGATHER); + }; + + // If we are on a CPU, our device context will be null and we can't + // get a stream to enqueue this on. On a CPU this op is called when the + // data is already available, so we can just immediately do the + // allgather; we don't have to wait for the data to get populated. +#if GOOGLE_CUDA + auto device_context = context->op_device_context(); + if (device_context == nullptr) { + allgather_launch_callback(); + } else { + auto stream = device_context->stream(); + stream->ThenDoHostCallback(allgather_launch_callback); + } +#else + allgather_launch_callback(); +#endif + } +}; + +REGISTER_OP("MPIAllgather") + .Attr("T: {int32, int64, float32}") + .Attr("S: {int64}") + .Input("tensor: T") + .Input("sizes: S") + .Output("gathered: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle output; + TF_RETURN_IF_ERROR(c->ReplaceDim(c->input(0), 0, c->UnknownDim(), + &output)); + c->set_output(0, output); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allgather on a tensor. All other processes that do a gather on a +tensor with the same name must have the same rank for that tensor, and have the +same dimension on all but the first dimension. + +Arguments + tensor: A tensor to gather. + sizes: A tensor containing the first-dimension sizes of tensors to be + gathered from other ranks + +Output + gathered: A tensor with the same shape as `tensor` except for the first + dimension, which is the sum of dimensions in `sizes`. 
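+
+For example, with three processes contributing tensors of shape [2, 5], [3, 5],
+and [4, 5], `sizes` is [2, 3, 4] on every rank and `gathered` has shape [9, 5].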
+)doc"); + +REGISTER_KERNEL_BUILDER(Name("MPIAllgather") + .Device(DEVICE_CPU) + .HostMemory("sizes"), + MPIAllgatherOp); +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("MPIAllgather") + .Device(DEVICE_GPU) + .HostMemory("sizes"), + MPIAllgatherOp); +#endif + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow + + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py new file mode 100644 index 00000000000..81567cc688a --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_ops.py @@ -0,0 +1,165 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Inter-process communication using MPI.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.framework import errors +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + + +def _load_library(name, op_list=None): + """Loads a .so file containing the specified operators. + + Args: + name: The name of the .so file to load. + op_list: A list of names of operators that the library should have. If None + then the .so file's contents will not be verified. + + Raises: + NameError if one of the required ops is missing. + """ + try: + filename = resource_loader.get_path_to_datafile(name) + library = load_library.load_op_library(filename) + for expected_op in (op_list or []): + for lib_op in library.OP_LIST.op: + if lib_op.name == expected_op: + break + else: + raise NameError( + 'Could not find operator %s in dynamic library %s' % + (expected_op, name)) + return library + except errors.NotFoundError: + logging.warning('%s file could not be loaded.', name) + + +MPI_LIB = _load_library('mpi_collectives.so', ['MPISize', 'MPIRank', + 'MPILocalRank', 'MPIAllgather', + 'MPIAllreduce']) + + +def size(name=None): + """An op which returns the number of MPI processes. + + This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the + size of the global communicator. + + Returns: + An integer scalar containing the number of MPI processes. + """ + return MPI_LIB.mpi_size(name=name) + + +ops.NotDifferentiable('MPISize') + + +def rank(name=None): + """An op which returns the MPI rank of the calling process. + + This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the + rank of the current process in the global communicator. + + Returns: + An integer scalar with the MPI rank of the calling process. + """ + return MPI_LIB.mpi_rank(name=name) + + +ops.NotDifferentiable('MPIRank') + + +def init(name=None): + """An op which initializes MPI on the device on which it is run. 
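+
+  For example, when managing the session by hand instead of relying on
+  `mpi.Session`, the returned op can be run once per process:
+
+    session.run(init())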
+ + All future MPI ops must be run on the same device that the `init` op was run + on. + """ + return MPI_LIB.mpi_init(name=name) + + +ops.NotDifferentiable('MPIInit') + + +def local_rank(name=None): + """An op which returns the local MPI rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + This is equivalent to running `MPI_Comm_rank(...)` on a new communicator + which only includes processes on the same node. + + Returns: + An integer scalar with the local MPI rank of the calling process. + """ + return MPI_LIB.mpi_local_rank(name=name) + + +ops.NotDifferentiable('MPILocalRank') + + +def _allreduce(tensor, name=None): + """An op which sums an input tensor over all the MPI processes. + + The reduction operation is keyed by the name of the op. The tensor type and + shape must be the same on all MPI processes for a given name. The reduction + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, summed across all + processes. + """ + return MPI_LIB.mpi_allreduce(tensor, name=name) + + +ops.NotDifferentiable('MPIAllreduce') + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other MPI processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different MPI processes. + """ + # Specify that first allgather is to collect the tensor gather sizes, + # indicated by passing in a scalar (0-D tensor) of value 0 + sizes_flag = tf.constant(0, dtype=tf.int64, name="size_flag_const") + my_size = tf.slice(tf.shape(tensor, out_type=tf.int64), [0], [1], name="size_slice") + if name is None: + name = "allgather" + sizing_name = "{}_sizing".format(name) + sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name) + return MPI_LIB.mpi_allgather(tensor, sizes, name=name) + + +ops.NotDifferentiable('MPIAllgather') + + diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops_test.py b/tensorflow/contrib/mpi_collectives/mpi_ops_test.py new file mode 100644 index 00000000000..48e5c0a0c70 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_ops_test.py @@ -0,0 +1,296 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +"""Tests for tensorflow.contrib.mpi_collectives.mpi_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import itertools + +import tensorflow as tf + +import tensorflow.contrib.mpi_collectives as mpi + + +def mpi_env_rank_and_size(): + """Get MPI rank and size from environment variables and return them as a + tuple of integers. + + Most MPI implementations have an `mpirun` or `mpiexec` command that will + run an MPI executable and set up all communication necessary between the + different processors. As part of that set up, they will set environment + variables that contain the rank and size of the MPI_COMM_WORLD + communicator. We can read those environment variables from Python in order + to ensure that `mpi.rank()` and `mpi.size()` return the expected values. + + Since MPI is just a standard, not an implementation, implementations + typically choose their own environment variable names. This function tries + to support several different implementation, but really it only needs to + support whatever implementation we want to use for the TensorFlow test + suite. + + If this is not running under MPI, then defaults of rank zero and size one + are returned. (This is appropriate because when you call MPI_Init in an + application not started with mpirun, it will create a new independent + communicator with only one process in it.) + """ + rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split() + size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split() + + for rank_var, size_var in zip(rank_env, size_env): + rank = os.environ.get(rank_var) + size = os.environ.get(size_var) + if rank is not None and size is not None: + return int(rank), int(size) + + # Default to rank zero and size one if there are no environment variables + return 0, 1 + + +class MPITests(tf.test.TestCase): + """ + Tests for MPI ops in tensorflow.contrib.mpi_collectives. + """ + + def test_mpi_rank(self): + """Test that the rank returned by mpi.rank() is correct.""" + true_rank, _ = mpi_env_rank_and_size() + with self.test_session() as session: + rank = session.run(mpi.rank()) + self.assertEqual(true_rank, rank) + + def test_mpi_size(self): + """Test that the size returned by mpi.size() is correct.""" + _, true_size = mpi_env_rank_and_size() + with self.test_session() as session: + size = session.run(mpi.size()) + self.assertEqual(true_size, size) + + def test_mpi_allreduce_cpu(self): + """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" + with self.test_session() as session: + size = session.run(mpi.size()) + + dtypes = [tf.int32, tf.float32] + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + tf.set_random_seed(1234) + tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype) + summed = mpi.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + break + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "mpi.allreduce produces incorrect results") + + def test_mpi_allreduce_gpu(self): + """Test that the allreduce works on GPUs. 
+ + This test will crash badly if used with an MPI implementation that does + not support GPU memory transfers directly, as it will call MPI_Send on + a GPU data pointer.""" + # Only do this test if there are GPUs available. + if not tf.test.is_gpu_available(cuda_only=True): + return + + no_gpus = tf.GPUOptions(visible_device_list="") + cpu_config = tf.ConfigProto(gpu_options=no_gpus) + with self.test_session(config=cpu_config) as session: + local_rank = session.run(mpi.local_rank()) + + one_gpu = tf.GPUOptions(visible_device_list=str(local_rank)) + gpu_config = tf.ConfigProto(gpu_options=one_gpu) + with self.test_session(config=gpu_config) as session: + size = session.run(mpi.size()) + + dtype = tf.float32 + dim = 3 + with tf.device("/gpu:0"): + tf.set_random_seed(1234) + tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype) + summed = mpi.allreduce(tensor, average=False) + multiplied = tensor * size + max_difference = tf.reduce_max(tf.abs(summed - multiplied)) + + # Threshold for floating point equality depends on number of + # ranks, since we're comparing against precise multiplication. + if size <= 3: + threshold = 0 + elif size < 10: + threshold = 1e-4 + elif size < 15: + threshold = 5e-4 + else: + return + + diff = session.run(max_difference) + self.assertTrue(diff <= threshold, + "mpi.allreduce on GPU produces incorrect results") + + def test_mpi_allreduce_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different rank or dimension.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + # Same rank, different dimension + tf.set_random_seed(1234) + dims = [17 + rank] * 3 + tensor = tf.random_uniform(dims, -1.0, 1.0) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + # Same number of elements, different rank + tf.set_random_seed(1234) + if rank == 0: + dims = [17, 23 * 57] + else: + dims = [17, 23, 57] + tensor = tf.random_uniform(dims, -1.0, 1.0) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + def test_mpi_allreduce_type_error(self): + """Test that the allreduce raises an error if different ranks try to + send tensors of different type.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. 
+ if size == 1: + return + + # Same rank, different dimension + dims = [17] * 3 + tensor = tf.ones(dims, dtype=tf.int32 if rank % 2 == 0 else tf.float32) + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allreduce(tensor)) + + def test_mpi_allgather(self): + """Test that the allgather correctly gathers 1D, 2D, 3D tensors.""" + with self.test_session() as session: + size = session.run(mpi.size()) + rank = session.run(mpi.rank()) + + dtypes = tf.int32, tf.float32 + dims = 1, 2, 3 + for dtype, dim in itertools.product(dtypes, dims): + tensor = tf.ones([17] * dim, dtype=dtype) * rank + gathered = mpi.allgather(tensor) + + gathered_tensor = session.run(gathered) + self.assertEqual(list(gathered_tensor.shape), + [17 * size] + [17] * (dim - 1)) + + for i in range(size): + rank_tensor = tf.slice(gathered_tensor, [i * 17] + [0] * (dim - 1), + [17] + [-1] * (dim - 1)) + self.assertEqual(list(rank_tensor.shape), [17] * dim) + self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))), + "mpi.allgather produces incorrect gathered tensor") + + def test_mpi_allgather_variable_size(self): + """Test that the allgather correctly gathers 1D, 2D, 3D tensors, + even if those tensors have different sizes along the first dim.""" + with self.test_session() as session: + size = session.run(mpi.size()) + rank = session.run(mpi.rank()) + + dtypes = tf.int32, tf.float32 + dims = 1, 2, 3 + for dtype, dim in itertools.product(dtypes, dims): + # Support tests up to MPI Size of 35 + if size > 35: + break + + tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5 + tensor_sizes = tensor_sizes[:size] + + tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1), + dtype=dtype) * rank + gathered = mpi.allgather(tensor) + + gathered_tensor = session.run(gathered) + expected_size = sum(tensor_sizes) + self.assertEqual(list(gathered_tensor.shape), + [expected_size] + [17] * (dim - 1)) + + for i in range(size): + rank_size = [tensor_sizes[i]] + [17] * (dim - 1) + rank_tensor = tf.slice(gathered, + [sum(tensor_sizes[:i])] + [0] * (dim - 1), + rank_size) + self.assertEqual(list(rank_tensor.shape), rank_size) + self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))), + "mpi.allgather produces incorrect gathered tensor") + + def test_mpi_allgather_error(self): + """Test that the allgather returns an error if any dimension besides + the first is different among the tensors being gathered.""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. + if size == 1: + return + + tensor_size = [17] * 3 + tensor_size[1] = 10 * (rank + 1) + tensor = tf.ones(tensor_size, dtype=tf.float32) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allgather(tensor)) + + def test_mpi_allgather_type_error(self): + """Test that the allgather returns an error if the types being gathered + differ among the processes""" + with self.test_session() as session: + rank = session.run(mpi.rank()) + size = session.run(mpi.size()) + + # This test does not apply if there is only one worker. 
+ if size == 1: + return + + tensor_size = [17] * 3 + dtype = tf.int32 if rank % 2 == 0 else tf.float32 + tensor = tf.ones(tensor_size, dtype=dtype) * rank + with self.assertRaises(tf.errors.FailedPreconditionError): + session.run(mpi.allgather(tensor)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc new file mode 100644 index 00000000000..5463f16e7cb --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.cc @@ -0,0 +1,59 @@ +#ifdef TENSORFLOW_USE_MPI + +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/mpi_collectives/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); + + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); + +// Copy data on a CPU using a straight-forward memcpy. +template<> void CopyTensorData(void* dst, void* src, size_t size) { + std::memcpy(dst, src, size); +}; + +// Accumulate values on a CPU. +#define GENERATE_ACCUMULATE(type) \ +template<> void AccumulateTensorData( \ + type* dst, type* src, size_t size) { \ + for (unsigned int i = 0; i < size; i++) { \ + dst[i] += src[i]; \ + } \ +}; +GENERATE_ACCUMULATE(int); +GENERATE_ACCUMULATE(long long); +GENERATE_ACCUMULATE(float); +#undef GENERATE_ACCUMULATE + +} +} +} + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc new file mode 100644 index 00000000000..c5dedd6547b --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc @@ -0,0 +1,78 @@ +#ifdef TENSORFLOW_USE_MPI + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/mpi_collectives/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +template<> MPI_Datatype MPIType() { return MPI_FLOAT; }; +template<> MPI_Datatype MPIType() { return MPI_INT; }; +template<> MPI_Datatype MPIType() { return MPI_LONG_LONG; }; + +template<> DataType TensorFlowDataType() { return DT_FLOAT; }; +template<> DataType TensorFlowDataType() { return DT_INT32; }; +template<> DataType TensorFlowDataType() { return DT_INT64; }; + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); +template Status RingAllreduce( + OpKernelContext*, const Tensor*, Tensor*, Tensor*); + +// Generate all necessary specializations for RingAllgather. 
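+// These GPU instantiations live in this .cu.cc file so that the CUDA compiler
+// builds them alongside the elemwise_accum kernel defined below; ring.cc
+// provides the matching CPU instantiations.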
+template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); +template Status RingAllgather( + OpKernelContext*, const Tensor*, const std::vector&, Tensor*); + +// Synchronously copy data on the GPU, using a different stream than the default +// and than TensorFlow to avoid synchronizing on operations unrelated to the +// allreduce. +template<> void CopyTensorData(void* dst, void* src, size_t size) { + auto stream = CudaStreamForMPI(); + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); + cudaStreamSynchronize(stream); +}; + +// Elementwise accumulation kernel for GPU. +template +__global__ void elemwise_accum(T* out, const T* in, const size_t N) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; + i < N; + i += blockDim.x * gridDim.x) { + out[i] += in[i]; + } +} + +// Synchronously accumulate tensors on the GPU, using a different stream than +// the default and than TensorFlow to avoid synchronizing on operations +// unrelated to the allreduce. +#define GENERATE_ACCUMULATE(type) \ +template<> void AccumulateTensorData( \ + type* dst, type* src, size_t size) { \ + auto stream = CudaStreamForMPI(); \ + elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ + cudaStreamSynchronize(stream); \ +}; +GENERATE_ACCUMULATE(int); +GENERATE_ACCUMULATE(long long); +GENERATE_ACCUMULATE(float); +#undef GENERATE_ACCUMULATE + +} +} +} +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h new file mode 100644 index 00000000000..2bd4903615c --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.h @@ -0,0 +1,312 @@ +#ifndef TENSORFLOW_CONTRIB_MPI_H_ +#define TENSORFLOW_CONTRIB_MPI_H_ + +#ifdef TENSORFLOW_USE_MPI + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +#if GOOGLE_CUDA +#include "cuda_runtime.h" +#endif + +// Needed to avoid header issues with C++-supporting MPI implementations +#define OMPI_SKIP_MPICXX +#include "third_party/mpi/mpi.h" + +#define TAG_TENSOR 12 + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +// Convert from templated types to values we can pass to MPI. +template +MPI_Datatype MPIType(); + +// Convert from templated types to TensorFlow data types. +template +DataType TensorFlowDataType(); + +#define MPI_REQUIRES_OK(MPI_STATUS) \ + if ((MPI_STATUS) != MPI_SUCCESS) { \ + return errors::Unknown("MPI operation failed unexpectedly."); \ + } + +// Copy data from one tensor to another tensor. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void CopyTensorData(void* destination, void* source, size_t size); + +// Add a tensor into another tensor, accumulating in place. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void AccumulateTensorData(T* destination, T* source, size_t size); + +// We need to get the right stream for doing CUDA memory transfers and +// operations, which is possibly different from the standard TensorFlow stream. 
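+//
+// Using a dedicated stream keeps the copies and accumulations issued by the
+// MPI background thread from serializing against TensorFlow's own compute
+// stream. A minimal sketch of how such a stream might be set up during
+// initialization (hypothetical; the actual stream is created wherever the
+// global MPI state is initialized):
+//
+//   cudaStream_t stream;
+//   cudaStreamCreate(&stream);
+//   mpi_global.stream = stream;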
+#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI(); +#endif + +/* Perform a ring allreduce on the data. Allocate the necessary output tensor and + * store it in the output parameter. + * + * Assumes that all MPI processes are doing an allreduce of the same tensor, + * with the same dimensions. + * + * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the allreduce, + * the nodes involved are arranged in a ring: + * + * .--0--. + * / \ + * 3 1 + * \ / + * *--2--* + * + * Each node always sends to the next clockwise node in the ring, and receives + * from the previous one. + * + * The allreduce is done in two parts: a scatter-reduce and an allgather. In + * the scatter reduce, a reduction is done, so that each node ends up with a + * chunk of the final output tensor which has contributions from all other + * nodes. In the allgather, those chunks are distributed among all the nodes, + * so that all nodes have the entire output tensor. + * + * Both of these operations are done by dividing the input tensor into N + * evenly sized chunks (where N is the number of nodes in the ring). + * + * The scatter-reduce is done in N-1 steps. In the ith step, node j will send + * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to + * its existing data for that chunk. For example, in the first iteration with + * the ring depicted above, you will have the following transfers: + * + * Segment 0: Node 0 --> Node 1 + * Segment 1: Node 1 --> Node 2 + * Segment 2: Node 2 --> Node 3 + * Segment 3: Node 3 --> Node 0 + * + * In the second iteration, you'll have the following transfers: + * + * Segment 0: Node 1 --> Node 2 + * Segment 1: Node 2 --> Node 3 + * Segment 2: Node 3 --> Node 0 + * Segment 3: Node 0 --> Node 1 + * + * After this iteration, Node 2 has 3 of the four contributions to Segment 0. + * The last iteration has the following transfers: + * + * Segment 0: Node 2 --> Node 3 + * Segment 1: Node 3 --> Node 0 + * Segment 2: Node 0 --> Node 1 + * Segment 3: Node 1 --> Node 2 + * + * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0 + * has the fully accumulated Segment 1; and so on. The scatter-reduce is complete. + * + * Next, the allgather distributes these fully accumululated chunks across all nodes. + * Communication proceeds in the same ring, once again in N-1 steps. At the ith step, + * node j will send chunk (j - i + 1) and receive chunk (j - i). For example, at the + * first iteration, the following transfers will occur: + * + * Segment 0: Node 3 --> Node 0 + * Segment 1: Node 0 --> Node 1 + * Segment 2: Node 1 --> Node 2 + * Segment 3: Node 2 --> Node 3 + * + * After the first iteration, Node 0 will have a fully accumulated Segment 0 + * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its + * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3. + * After this has continued for N - 1 iterations, all nodes will have a the fully + * accumulated tensor. + * + * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the allgather. + * Each send will contain K / N bytes, if there are K bytes in the original tensor on every node. + * Thus, each node sends and receives 2K(N - 1)/N bytes of data, and the performance of the allreduce + * (assuming no latency in connections) is constrained by the slowest interconnect between the nodes. 
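+ *
+ * As a concrete instance of the formula above: with N = 4 nodes and K = 400 MB
+ * of tensor data per node, each node sends and receives
+ * 2K(N - 1)/N = 2 * 400 MB * 3/4 = 600 MB over the course of the allreduce,
+ * independent of which node it is.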
+ * + */ +template +Status RingAllreduce(OpKernelContext* context, const Tensor* input, + Tensor* temp, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + T* buffer = (T*) output->tensor_data().data(); + + CopyTensorData((void*) buffer, + (void*) input->tensor_data().data(), + output->tensor_data().size()); + + // Calculate segment sizes and segment ends + const size_t elements_to_reduce = input->NumElements(); + const size_t segment_size = elements_to_reduce / n; + std::vector segment_sizes(n, segment_size); + + const size_t residual = elements_to_reduce % n; + for (size_t i = 0; i < residual; ++i) { + segment_sizes[i]++; + } + + std::vector segment_starts(n); + segment_starts[0] = 0; + for (size_t i = 1; i < segment_starts.size(); ++i) { + segment_starts[i] = segment_starts[i-1] + segment_sizes[i-1]; + } + + assert(segment_starts[n-1] + segment_sizes[n-1] == elements_to_reduce); + + T* segment_recv = (T*) temp->tensor_data().data(); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + MPI_Status recv_status; + MPI_Request recv_req; + + // Now start ring. At every step, for every rank, we iterate through + // segments with wraparound and send and recv from our neighbors and reduce + // locally. At the i'th iteration, rank r, sends segment (r-i) and receives + // segment (r-i-1). + for (int i = 0; i < n - 1; i++) { + const size_t send_seg_id = ((r-i) + n) % n; + const size_t recv_seg_id = ((r-i-1) + n) % n; + + T* segment_send = &(buffer[segment_starts[send_seg_id]]); + + MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id], + MPIType(), recv_from, TAG_TENSOR, + MPI_COMM_WORLD, &recv_req)); + + MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id], + MPIType(), send_to, TAG_TENSOR, + MPI_COMM_WORLD)); + + T *segment_update = &(buffer[segment_starts[recv_seg_id]]); + + // Wait for recv to complete before reduction + MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status)); + + const size_t recv_seg_size = segment_sizes[recv_seg_id]; + AccumulateTensorData( + segment_update, segment_recv, recv_seg_size); + } + + // Now start pipelined ring allgather. At every step, for every rank, we + // iterate through segments with wraparound and send and recv from our + // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and + // receives segment (r-i). + for (size_t i = 0; i < n - 1; ++i) { + const size_t send_seg_id = ((r-i+1) + n) % n; + const size_t recv_seg_id = ((r-i) + n) % n; + + // Segment to send - at every iteration we send segment (r-i+1) + T* segment_send = &(buffer[segment_starts[send_seg_id]]); + + // Segment to recv - at every iteration we receive segment (r-i) + T* segment_recv = &(buffer[segment_starts[recv_seg_id]]); + + MPI_REQUIRES_OK(MPI_Sendrecv(segment_send, segment_sizes[send_seg_id], + MPIType(), send_to, TAG_TENSOR, segment_recv, + segment_sizes[recv_seg_id], MPIType(), recv_from, + TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +// Perform a ring allgather on a Tensor. Other ranks may allgather with a +// tensor which differs in the first dimension only; all other dimensions must +// be the same. +// +// For more information on the ring allgather, read the documentation for the +// ring allreduce, which includes a ring allgather. 
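+//
+// As a small illustration of the bookkeeping below: with sizes = {2, 3, 1} and
+// rows of 5 elements each, the per-rank segment starts (in elements) are
+// {0, 10, 25}, and the output holds 2 + 3 + 1 = 6 rows laid out in rank order.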
+template +Status RingAllgather(OpKernelContext* context, const Tensor* input, + const std::vector& sizes, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + assert(sizes.size() == n); + assert(input->dim_size(0) == sizes[r]); + + // Compute number of elements in every "row". We can't compute number of + // elements in every chunks, because those chunks are variable length. + size_t elements_per_row = 1; + for (int i = 1; i < input->shape().dims(); i++) { + elements_per_row *= input->dim_size(i); + } + + // Copy data from input tensor to correct place in output tensor. + std::vector segment_starts(n); + segment_starts[0] = 0; + for (int i = 1; i < n; i++) { + segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]; + } + size_t offset = segment_starts[r]; + + // Copy data to the right offset for this rank. + T* buffer = (T*) output->tensor_data().data(); + CopyTensorData((void*) (buffer + offset), + (void*) input->tensor_data().data(), + elements_per_row * sizes[r] * sizeof(T)); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + // Perform a ring allgather. At every step, for every rank, we iterate + // through segments with wraparound and send and recv from our neighbors. + // At the i'th iteration, rank r, sends segment (r-i) and receives segment + // (r-1-i). + MPI_Status recv_status; + for (size_t i = 0; i < n - 1; ++i) { + const size_t send_seg_id = ((r-i) + n) % n; + const size_t recv_seg_id = ((r-i-1) + n) % n; + + // Segment to send - at every iteration we send segment (r-i) + size_t offset_send = segment_starts[send_seg_id]; + size_t rows_send = sizes[send_seg_id]; + T* segment_send = &(buffer[offset_send]); + + // Segment to recv - at every iteration we receive segment (r-1-i) + size_t offset_recv = segment_starts[recv_seg_id]; + size_t rows_recv = sizes[recv_seg_id]; + T* segment_recv = &(buffer[offset_recv]); + + MPI_REQUIRES_OK(MPI_Sendrecv( + segment_send, elements_per_row * rows_send, + MPIType(), send_to, TAG_TENSOR, segment_recv, + elements_per_row * rows_recv, MPIType(), recv_from, + TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +} +} +} + +#endif // TENSORFLOW_USE_MPI + +#undef TENSORFLOW_CONTRIB_MPI_H_ +#endif // TENSORFLOW_CONTRIB_MPI_H_ diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.mpi b/tensorflow/tools/ci_build/Dockerfile.cpu.mpi new file mode 100644 index 00000000000..2bf7fd1d234 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.mpi @@ -0,0 +1,24 @@ +FROM ubuntu:14.04 + +LABEL authors="Andrew Gibiansky , Joel Hestness " + +# Copy and run the install scripts. +COPY install/*.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa && \ + add-apt-repository -y ppa:mc3man/trusty-media && \ + add-apt-repository -y ppa:george-edison55/cmake-3.x +RUN /install/install_deb_packages.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel.sh +RUN /install/install_proto3.sh +RUN /install/install_buildifier.sh +RUN /install/install_mpi.sh + +# Set up bazelrc. 
+COPY install/.bazelrc /root/.bazelrc
+ENV BAZELRC /root/.bazelrc
+
+# Set up MPI
+ENV TF_NEED_MPI 1
+ENV MPI_HOME /usr/lib/openmpi
diff --git a/tensorflow/tools/ci_build/install/install_mpi.sh b/tensorflow/tools/ci_build/install/install_mpi.sh
new file mode 100755
index 00000000000..6ee9d765949
--- /dev/null
+++ b/tensorflow/tools/ci_build/install/install_mpi.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set +e
+mpiexec_location=$(which mpiexec)
+if [[ -z "$mpiexec_location" ]]; then
+  # Install dependencies from the Ubuntu deb repository.
+  apt-get update
+  apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev
+fi
diff --git a/third_party/mpi/.gitignore b/third_party/mpi/.gitignore
new file mode 100644
index 00000000000..ab011617a6d
--- /dev/null
+++ b/third_party/mpi/.gitignore
@@ -0,0 +1,3 @@
+*.h
+*.dylib
+*.so
diff --git a/third_party/mpi_collectives/BUILD b/third_party/mpi_collectives/BUILD
new file mode 100644
index 00000000000..cbae0243126
--- /dev/null
+++ b/third_party/mpi_collectives/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE.txt"])
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+    name = "mpi",
+    srcs = select({
+        "//tensorflow:darwin": ["libmpi.dylib"],
+        "//conditions:default": ["libmpi.so"],
+    }),
+    hdrs = ["mpi.h", "mpi_portable_platform.h"],
+)
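+
+# NOTE: mpi.h, mpi_portable_platform.h, and libmpi.so / libmpi.dylib are not
+# checked into the tree; they are expected to be provided externally (for
+# example, linked in from the system MPI installation when MPI support is
+# enabled at configure time).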