Make NcclManager part of CollectiveExecutorMgr

Collective NCCL will then use the NcclManager from CollectiveExecutorMgr.
Non-collective NCCL (nccl_ops) is not affected.

We're adding the ability to abort NcclManager. Once aborted, a NcclManager
cannot be used again. Having a NcclManager per CollectiveExecutorMgr allows us
to "reset" it by resetting the EagerContext, which aligns with RING collectives.

For now this is mostly useful in tests, since we don't expose the ability to
reset the EagerContext in the public API.

This change adds an abstraction (NcclCommunicatorInterface) to bridge
collectives and NcclManager. With it we don't need to link NcclManager into
core_cpu, which proved to be difficult and awkward.
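
For reference, a simplified sketch of the new bridge interface and of how a
collective implementation uses it (taken from the diff below; surrounding
details are illustrative only):

    // Bridge between the collectives machinery and NcclManager.
    class NcclCommunicatorInterface {
     public:
      virtual ~NcclCommunicatorInterface() = default;
      // Enqueues the NCCL collective described by `col_ctx`; `done` is invoked
      // once the NCCL kernel has been scheduled, or with an error.
      virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                           StatusCallback done) = 0;
      // Aborts pending and subsequent NCCL collectives with status `s`.
      virtual void StartAbort(const Status& s) = 0;
    };

    // A collective implementation such as NcclReducer::Run now simply does:
    //   col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done));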

PiperOrigin-RevId: 332315169
Change-Id: I936f6489b43c9de0b3f6a3cf5a14f89e79689d5e
Ran Chen 2020-09-17 14:21:45 -07:00 committed by TensorFlower Gardener
parent 7cfe00cf3a
commit 57f8b71009
34 changed files with 353 additions and 260 deletions

View File

@ -231,10 +231,10 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Aborts all ongoing collectives with the specified status. After abortion,
// subsequent collectives will error with this status immediately.
// subsequent collectives will error with this status immediately. To reset the
// collectives, create a new EagerContext.
//
// This is intended to be used when a peer failure is detected. There's yet no
// way to reset the collectives other than restarting the program.
// This is intended to be used when a peer failure is detected.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status);
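
A rough usage sketch of this C API (the peer-failure scenario, the `ctx`
variable, and the error message are illustrative; only the declared functions
are existing TensorFlow C API):

    // Abort collectives after detecting a peer failure. Subsequent collective
    // ops on `ctx` fail immediately with this status.
    TF_Status* status = TF_NewStatus();
    TF_SetStatus(status, TF_UNAVAILABLE, "peer failure detected");
    TFE_AbortCollectiveOps(ctx, status);
    TF_DeleteStatus(status);
    // To run collectives again, discard `ctx`, create a fresh TFE_Context and
    // call TFE_EnableCollectiveOps on it.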

View File

@ -1164,7 +1164,7 @@ cc_library(
)
# Test support library needed for higher-level (TensorFlow-specific) tests
cc_library(
tf_cuda_library(
name = "testlib",
testonly = 1,
srcs = [
@ -1295,6 +1295,7 @@ filegroup(
"//tensorflow/core/graph:mobile_srcs_only_runtime",
"//tensorflow/core/kernels:mobile_srcs",
"//tensorflow/core/lib/io:mobile_srcs_only_runtime",
"//tensorflow/core/nccl:mobile_srcs",
"//tensorflow/core/profiler:mobile_srcs",
"//tensorflow/core/public:mobile_srcs_only_runtime",
"//tensorflow/core/util/sparse:mobile_srcs_only_runtime",

View File

@ -1739,6 +1739,7 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/debug:debug_graph_utils",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/nccl:collective_communicator",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:profiler_backends",
"//tensorflow/core/profiler/lib:profiler_session",
@ -1868,6 +1869,7 @@ tf_cc_tests(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/nccl:collective_communicator",
"//tensorflow/core/platform:regexp",
"//tensorflow/core/util:protos_test_cc",
"//third_party/eigen3",

View File

@ -270,8 +270,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
}
core::ScopedUnref unref(col_impl);
auto col_ctx = std::make_shared<CollectiveContext>(
this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
input, output);
this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
col_params, exec_key, step_id_, input, output);
status = col_impl->InitializeCollectiveContext(col_ctx);
if (!status.ok()) {
done_safe(status);

View File

@ -26,12 +26,14 @@ namespace tensorflow {
CollectiveExecutorMgr::CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
std::unique_ptr<DeviceResolverInterface> dev_resolver,
std::unique_ptr<ParamResolverInterface> param_resolver)
std::unique_ptr<ParamResolverInterface> param_resolver,
std::unique_ptr<NcclCommunicatorInterface> nccl_communicator)
: dev_mgr_(dev_mgr),
dev_resolver_(std::move(dev_resolver)),
param_resolver_(std::move(param_resolver)),
gpu_ring_order_(
config.gpu_options().experimental().collective_ring_order()),
nccl_communicator_(std::move(nccl_communicator)),
work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
"collective_ops")) {}

View File

@ -22,12 +22,15 @@ limitations under the License.
namespace tensorflow {
class ConfigProto;
class DeviceMgr;
class NcclManager;
class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
public:
CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
std::unique_ptr<DeviceResolverInterface> dev_resolver,
std::unique_ptr<ParamResolverInterface> param_resolver);
CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
std::unique_ptr<DeviceResolverInterface> dev_resolver,
std::unique_ptr<ParamResolverInterface> param_resolver,
std::unique_ptr<NcclCommunicatorInterface> nccl_communicator);
virtual ~CollectiveExecutorMgr();
@ -43,6 +46,10 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
return dev_resolver_.get();
}
NcclCommunicatorInterface* GetNcclCommunicator() const override {
return nccl_communicator_.get();
}
void GetStepSequenceAsync(const GetStepSequenceRequest* request,
GetStepSequenceResponse* response,
const StatusCallback& done) override;
@ -64,6 +71,7 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
std::unique_ptr<DeviceResolverInterface> dev_resolver_;
std::unique_ptr<ParamResolverInterface> param_resolver_;
string gpu_ring_order_;
std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
// Unbounded work queue for scheduling potentially-blocking work during
// collective op execution. Ownership is shared between `this` and
// `CollectiveRemoteAccessLocal`.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
@ -47,7 +48,8 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
new CollectiveParamResolverLocal(cp, device_mgr_.get(), drl.get(),
task_name));
cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
std::move(prl)));
std::move(prl),
MaybeCreateNcclCommunicator()));
}
std::unique_ptr<CollectiveExecutorMgr> cme_;

View File

@ -65,6 +65,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
@ -554,7 +555,8 @@ Status DirectSession::RunInternal(
drl.get(),
"/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
options_.config, device_mgr_.get(), std::move(drl), std::move(cprl),
MaybeCreateNcclCommunicator()));
}
run_state.collective_executor.reset(new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));

View File

@ -82,6 +82,7 @@ tf_cuda_library(
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/nccl:collective_communicator",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on
@ -120,7 +121,7 @@ EagerContext::EagerContext(
"/job:localhost/replica:0/task:0"));
collective_executor_mgr_.Reset(
new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
std::move(cprl)),
std::move(cprl), MaybeCreateNcclCommunicator()),
/*owned=*/true);
}

View File

@ -674,8 +674,9 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
new HierarchicalTreeBroadcaster;
core::ScopedUnref unref(broadcaster);
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
parent_->col_exec_, /*nccl_communicator*/ nullptr,
parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
kStepId, input_tensor_ptr, output_tensor_ptr);
TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
// Run the broadcast.

View File

@ -388,8 +388,9 @@ class PermuterTest : public ::testing::Test {
Permuter* permuter = new Permuter;
core::ScopedUnref unref(permuter);
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, &tensor_input_, &tensor_output_);
parent_->col_exec_, /*nccl_communicator*/ nullptr,
parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
kStepId, &tensor_input_, &tensor_output_);
TF_CHECK_OK(permuter->InitializeCollectiveContext(col_ctx));
Notification note;
// Run the permute.

View File

@ -480,8 +480,9 @@ class RingGathererTest : public ::testing::Test {
RingGatherer* gatherer = new RingGatherer;
core::ScopedUnref unref(gatherer);
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
parent_->col_exec_, /*nccl_communicator*/ nullptr,
parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
kStepId, &input_tensor_, output_tensor_ptr);
TF_CHECK_OK(gatherer->InitializeCollectiveContext(col_ctx));
// Run the all-gather.

View File

@ -510,8 +510,9 @@ class RingReducerTest : public ::testing::Test {
RingReducer* reducer = new RingReducer;
core::ScopedUnref unref(reducer);
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
col_params_, exec_key, kStepId, &tensor_, &tensor_);
parent_->col_exec_, /*nccl_communicator*/ nullptr,
parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
kStepId, &tensor_, &tensor_);
TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
// Run the all-reduce.

View File

@ -101,6 +101,11 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
return nullptr;
}
NcclCommunicatorInterface* GetNcclCommunicator() const override {
LOG(FATAL) << "Unimplemented"; // Crash OK
return nullptr;
}
void GetStepSequenceAsync(const GetStepSequenceRequest* request,
GetStepSequenceResponse* response,
const StatusCallback& done) override {

View File

@ -522,6 +522,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/nccl:collective_communicator",
],
)

View File

@ -316,6 +316,7 @@ cc_library(
":grpc_worker_cache",
":grpc_worker_service",
":rpc_rendezvous_mgr",
"//tensorflow/core/nccl:collective_communicator",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

View File

@ -47,6 +47,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
@ -286,7 +287,8 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
default_worker_name));
worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
config, worker_env_.device_mgr, std::move(dev_resolver),
std::move(param_resolver), worker_cache, default_worker_name));
std::move(param_resolver), MaybeCreateNcclCommunicator(), worker_cache,
default_worker_name));
}
// Set up worker environment.
@ -454,8 +456,8 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
dev_resolver.get(), worker_cache, default_worker_name));
worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
server_def_.default_session_config(), worker_env_.device_mgr,
std::move(dev_resolver), std::move(param_resolver), worker_cache,
default_worker_name));
std::move(dev_resolver), std::move(param_resolver),
MaybeCreateNcclCommunicator(), worker_cache, default_worker_name));
master_env_.worker_cache = worker_cache;
master_env_.collective_executor_mgr =

View File

@ -29,9 +29,11 @@ RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
std::unique_ptr<DeviceResolverDistributed> dev_resolver,
std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
WorkerCacheInterface* worker_cache, const string& task_name)
: CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
std::move(param_resolver)),
std::move(param_resolver),
std::move(nccl_communicator)),
worker_cache_(worker_cache),
task_name_(task_name) {
group_leader_ = (task_name == config.experimental().collective_group_leader())

View File

@ -38,6 +38,7 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
const ConfigProto& config, const DeviceMgr* dev_mgr,
std::unique_ptr<DeviceResolverDistributed> dev_resolver,
std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
WorkerCacheInterface* worker_cache, const string& task_name);
virtual ~RpcCollectiveExecutorMgr();

View File

@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include <stdlib.h>
#include <string>
#include <vector>
@ -21,10 +24,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
@ -52,9 +55,9 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
device_mgr_.get(), dr.get(),
worker_cache, task_name));
// This CME is the group leader.
cme_.reset(new RpcCollectiveExecutorMgr(options.config, device_mgr_.get(),
std::move(dr), std::move(cpr),
worker_cache, task_name));
cme_.reset(new RpcCollectiveExecutorMgr(
options.config, device_mgr_.get(), std::move(dr), std::move(cpr),
MaybeCreateNcclCommunicator(), worker_cache, task_name));
}
std::unique_ptr<RpcCollectiveExecutorMgr> cme_;

View File

@ -158,14 +158,13 @@ string CollectiveParams::ToString() const {
return ctx->params_;
}
CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
const DeviceMgr* dev_mgr,
OpKernelContext* ctx,
OpKernelContext::Params* op_params,
const CollectiveParams& col_params,
const string& exec_key, int64 step_id,
const Tensor* input, Tensor* output)
CollectiveContext::CollectiveContext(
CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
const DeviceMgr* dev_mgr, OpKernelContext* ctx,
OpKernelContext::Params* op_params, const CollectiveParams& col_params,
const string& exec_key, int64 step_id, const Tensor* input, Tensor* output)
: col_exec(col_exec),
nccl_communicator(nccl_communicator),
dev_mgr(dev_mgr),
op_ctx(ctx),
op_params(op_params),

View File

@ -36,6 +36,7 @@ class Device;
class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class NcclManager;
class Tensor;
// Types of supported collective operations.
@ -229,6 +230,8 @@ class StepSequenceInterface {
virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
};
class NcclCommunicatorInterface;
// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
@ -246,6 +249,8 @@ class CollectiveExecutorMgrInterface : public StepSequenceInterface {
virtual ParamResolverInterface* GetParamResolver() const = 0;
virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
};
// Interface that a Collective Op implementation uses to exchange data
@ -354,19 +359,12 @@ class CollectiveExecutor : public core::RefCounted {
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
};
class CollectiveContext {
public:
CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
OpKernelContext* ctx, OpKernelContext::Params* op_params,
const CollectiveParams& col_params, const string& exec_key,
int64 step_id, const Tensor* input, Tensor* output);
virtual ~CollectiveContext() = default;
CollectiveExecutor* col_exec; // Not owned
const DeviceMgr* dev_mgr; // Not owned
OpKernelContext* op_ctx; // Not owned
OpKernelContext::Params* op_params; // Not owned
struct CollectiveContext {
CollectiveExecutor* col_exec; // Not owned
NcclCommunicatorInterface* nccl_communicator; // Not owned
const DeviceMgr* dev_mgr; // Not owned
OpKernelContext* op_ctx; // Not owned
OpKernelContext::Params* op_params; // Not owned
const CollectiveParams& col_params;
const string exec_key;
const int64 step_id;
@ -375,6 +373,23 @@ class CollectiveContext {
Device* device; // The device for which this instance labors
const string device_name;
DeviceLocality device_locality;
CollectiveContext(CollectiveExecutor* col_exec,
NcclCommunicatorInterface* nccl_communicator,
const DeviceMgr* dev_mgr, OpKernelContext* ctx,
OpKernelContext::Params* op_params,
const CollectiveParams& col_params, const string& exec_key,
int64 step_id, const Tensor* input, Tensor* output);
};
class NcclCommunicatorInterface {
public:
virtual ~NcclCommunicatorInterface() = default;
virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
StatusCallback done) = 0;
virtual void StartAbort(const Status& s) = 0;
};
// Interface of a Collective Op implementation. Each specific CollectiveOp will
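
As a hedged illustration (not part of this change), a minimal stub
implementation of the interface, e.g. for tests or builds without NCCL, could
look like:

    // Hypothetical no-op communicator: every enqueue fails fast.
    class FakeNcclCommunicator : public NcclCommunicatorInterface {
     public:
      void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                   StatusCallback done) override {
        done(errors::Unimplemented("NCCL is not available in this build"));
      }
      void StartAbort(const Status& s) override {}
    };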

View File

@ -200,17 +200,6 @@ tf_cc_test(
],
)
# virtual targets since nested select statements not possible
tf_kernel_library(
name = "virtual_nccl",
deps = if_cuda(["@local_config_nccl//:nccl"]),
)
tf_kernel_library(
name = "virtual_rccl",
deps = if_rocm(["@local_config_rocm//rocm:rccl"]),
)
tf_kernel_library(
name = "collective_ops",
srcs = if_nccl([
@ -228,11 +217,10 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core/profiler/lib:traceme",
] + if_nccl([
":virtual_nccl",
":virtual_rccl",
"//tensorflow/core/nccl:nccl_lib",
"//tensorflow/core/nccl:collective_communicator",
]),
)
@ -255,6 +243,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/nccl:collective_communicator",
],
)

View File

@ -74,10 +74,6 @@ Status NcclBase::InitializeCollectiveGroupRuntimeDetails(
return Status::OK();
}
const string NcclBase::NcclCollectiveKey(const string& exec_key, int step_id) {
return strings::StrCat(exec_key, ":", step_id);
}
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -37,8 +37,6 @@ class NcclBase : public CollectiveImplementationInterface {
CollGroupRuntimeDetails* col_group_runtime_details) override;
protected:
const string NcclCollectiveKey(const string& exec_key, int step_id);
const CollectiveType type_;
const string name_;
std::shared_ptr<CollectiveContext> col_ctx_;

View File

@ -24,55 +24,7 @@ limitations under the License.
namespace tensorflow {
void NcclBroadcaster::Run(StatusCallback done) {
auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
const int num_global_devices = col_params_->group.group_size;
const int num_local_devices = col_params_->instance.num_devices_per_task.at(
col_params_->instance.task_names[col_params_->default_rank]);
string nccl_collective_key =
NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
auto participant = absl::make_unique<NcclManager::Participant>(
compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
col_ctx_->output, col_params_->default_rank, std::move(done));
VLOG(1)
<< "NcclBroadcast calling NcclManager::AddBroadcastSend/Recv num_tasks "
<< col_params_->group.num_tasks << " current task "
<< col_params_->instance.task_names[col_params_->default_rank]
<< " num local devices " << num_local_devices << " num global devices "
<< num_global_devices << " rank " << col_params_->default_rank
<< " device " << col_ctx_->device_name << " instance "
<< col_params_->instance.instance_key << " source "
<< col_params_->is_source;
if (col_params_->is_source) {
NcclManager::instance()->AddBroadcastSend(
std::move(participant),
{std::move(nccl_collective_key), num_local_devices, num_global_devices,
col_params_->group.runtime_details.communicator_key,
col_params_->source_rank});
} else {
NcclManager::instance()->AddBroadcastRecv(
std::move(participant),
{std::move(nccl_collective_key), num_local_devices, num_global_devices,
col_params_->group.runtime_details.communicator_key,
col_params_->source_rank});
}
{
// `WaitForDependencies` may block if the collective instances on which this
// op depends have not yet launched. When this function returns, this op is
// ready to go.
profiler::TraceMe activity("WaitForDependencies",
profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->WaitForDependencies(*col_params_);
NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
}
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->UnblockDependencies(*col_params_);
}
col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done));
}
REGISTER_COLLECTIVE(NcclBroadcast, NcclBroadcaster);

View File

@ -24,45 +24,7 @@ limitations under the License.
namespace tensorflow {
void NcclGatherer::Run(StatusCallback done) {
auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
const int num_global_devices = col_params_->group.group_size;
const int num_local_devices = col_params_->instance.num_devices_per_task.at(
col_params_->instance.task_names[col_params_->default_rank]);
string nccl_collective_key =
NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
auto participant = absl::make_unique<NcclManager::Participant>(
compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
col_ctx_->output, col_params_->default_rank, std::move(done));
VLOG(1) << "NcclGatherer calling NcclManager::AddToAllGather num_tasks "
<< col_params_->group.num_tasks << " current task "
<< col_params_->instance.task_names[col_params_->default_rank]
<< " num local devices " << num_local_devices
<< " num global devices " << num_global_devices << " rank "
<< col_params_->default_rank << " device " << col_ctx_->device_name
<< " instance " << col_params_->instance.instance_key;
NcclManager::instance()->AddToAllGather(
std::move(participant),
{std::move(nccl_collective_key), num_local_devices, num_global_devices,
col_params_->group.runtime_details.communicator_key,
/*source_rank=*/-1});
{
// `WaitForDependencies` may block if the collective instances on which this
// op depends have not yet launched. When this function returns, this op is
// ready to go.
profiler::TraceMe activity("WaitForDependencies",
profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->WaitForDependencies(*col_params_);
NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
}
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->UnblockDependencies(*col_params_);
}
col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done));
}
REGISTER_COLLECTIVE(NcclGather, NcclGatherer);

View File

@ -23,36 +23,7 @@ limitations under the License.
namespace tensorflow {
namespace {
Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
if (merge_op == "Add") {
*reduction_op = ncclSum;
return Status::OK();
} else if (merge_op == "Mul") {
*reduction_op = ncclProd;
return Status::OK();
} else if (merge_op == "Maximum") {
*reduction_op = ncclMax;
return Status::OK();
} else if (merge_op == "Minimum") {
*reduction_op = ncclMin;
return Status::OK();
} else {
return errors::Internal(
"Expected merge_op to be in [Add, Mul, Maximum, Minimum], found ",
merge_op);
}
}
} // namespace
void NcclReducer::Run(StatusCallback done) {
ncclRedOp_t reduction_op;
Status s = ReductionOp(col_params_->merge_op->type_string(), &reduction_op);
if (!s.ok()) {
done(s);
return;
}
Tensor group_size;
std::unique_ptr<Notification> group_size_ready;
Status group_size_status;
@ -117,73 +88,7 @@ void NcclReducer::Run(StatusCallback done) {
} else {
done_callback = std::move(done);
}
auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
// `AddToAllReduce` performs consistency checks for the NCCL call and enqueues
// the `Participant` struct locally. When all local participants with this
// `nccl_collective_key` have called `AddToAllReduce` and
// `SignalMultiNodeReady`, all devices at this worker are ready to process
// this NCCL op.
//
// The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
// point, it synchronizes the NCCL stream with the compute stream, and then
// enqueues the NCCL kernel on the NCCL stream.
const int num_global_devices = col_params_->group.group_size;
const int num_local_devices = col_params_->instance.num_devices_per_task.at(
col_params_->instance.task_names[col_params_->default_rank]);
const string nccl_collective_key =
NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
auto participant = absl::make_unique<NcclManager::Participant>(
compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
col_ctx_->output, col_params_->default_rank, std::move(done_callback));
VLOG(1) << "NcclReducer calling NcclManager::AddToAllReduce num_tasks "
<< col_params_->group.num_tasks << " current task "
<< col_params_->instance.task_names[col_params_->default_rank]
<< " num local devices " << num_local_devices
<< " num global devices " << num_global_devices << " device "
<< col_ctx_->device_name << " instance "
<< col_params_->instance.instance_key;
NcclManager::instance()->AddToAllReduce(
std::move(participant),
{nccl_collective_key, num_local_devices, num_global_devices,
col_params_->group.runtime_details.communicator_key, /*source_rank=*/-1},
reduction_op);
// NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
// deadlocks. In the current implementation, we define a deterministic
// sequential launch order between potentially concurrent collective instances
// by introducing control information during static graph analysis in
// graph/collective_order.cc. This can be either in the form of explicit
// control edges or via `wait_for` attribute on the collective op.
//
// The other end of the design spectrum would have a distinguished node
// dynamically signal the next collective to launch to all other participants.
// This has higher degree of runtime coordination, but it may be able to
// achieve better performance if the (arbitrary) static execution order
// assigned in the first approach turns out to not be good from a scheduling
// perspective. e.g. consider a graph in which c1, c2, and c3 are three
// concurrent collective instances, and the static ordering assigns c1 -> c2
// -> c3. In practice, it could turn out that c3 is always ready to execute
// before c1 or c2.
{
// `WaitForDependencies` may block if the collective instances on which this
// op depends have not yet launched. When this function returns, this op is
// ready to go.
profiler::TraceMe activity("WaitForDependencies",
profiler::TraceMeLevel::kInfo);
// TODO(b/80529858): make this entirely non-blocking by converting
// `WaitForDependencies` to async function.
col_ctx_->col_exec->WaitForDependencies(*col_params_);
NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
}
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->UnblockDependencies(*col_params_);
}
col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done_callback));
// If no final_op, then this OpKernel is non-blocking.
if (!col_params_->final_op) {

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
@ -84,6 +85,7 @@ class NcclTestBase : public ::testing::Test {
NcclTestBase(CollectiveType collective_type, const string& collective_name)
: collective_type_(collective_type),
collective_name_(collective_name),
nccl_communicator_(MaybeCreateNcclCommunicator()),
work_queue_(std::make_shared<UnboundedWorkQueue>(
Env::Default(), "collective_executor")),
col_exec_(nullptr) {}
@ -318,7 +320,8 @@ class NcclTestBase : public ::testing::Test {
strings::StrCat(col_params_.instance.instance_key, ":0:0");
auto* reducer = new NcclReducer();
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(),
parent_->col_exec_, parent_->nccl_communicator_.get(),
parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/&input_, /*output=*/&input_);
TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
@ -349,7 +352,8 @@ class NcclTestBase : public ::testing::Test {
strings::StrCat(col_params_.instance.instance_key, ":0:0");
auto* broadcaster = new NcclBroadcaster();
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(),
parent_->col_exec_, parent_->nccl_communicator_.get(),
parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/col_params_.is_source ? &input_ : nullptr,
/*output=*/&input_);
@ -389,7 +393,8 @@ class NcclTestBase : public ::testing::Test {
strings::StrCat(col_params_.instance.instance_key, ":0:0");
auto* gatherer = new NcclGatherer();
auto col_ctx = std::make_shared<CollectiveContext>(
parent_->col_exec_, parent_->dev_mgr_.get(),
parent_->col_exec_, parent_->nccl_communicator_.get(),
parent_->dev_mgr_.get(),
/*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
/*input=*/&input_,
/*output=*/&output_);
@ -419,6 +424,7 @@ class NcclTestBase : public ::testing::Test {
const string collective_name_;
std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
TestCollectiveExecutorMgr col_exec_mgr_;
std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
CollectiveExecutor* col_exec_;
std::unique_ptr<DeviceMgr> dev_mgr_;

View File

@ -3,15 +3,17 @@
# APIs are meant to change over time.
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm", "tf_copts")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm")
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
)
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "if_nccl")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
@ -77,3 +79,28 @@ tf_cuda_cc_test(
"//tensorflow/core/common_runtime/gpu:rocm",
]),
)
cc_library(
name = "collective_communicator",
srcs = ["collective_communicator.cc"],
hdrs = ["collective_communicator.h"],
copts = tf_copts() + if_nccl(["-DTENSORFLOW_USE_NCCL=1"]),
visibility = [
"//learning/brain/runtime:__subpackages__",
"//tensorflow:__subpackages__",
],
deps =
["//tensorflow/core:framework"] + if_nccl([
":nccl_lib",
"@com_google_absl//absl/memory",
"//tensorflow/core/profiler/lib:traceme",
]),
)
filegroup(
name = "mobile_srcs",
srcs = [
"collective_communicator.cc",
"collective_communicator.h",
],
)

View File

@ -0,0 +1,178 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/nccl/collective_communicator.h"
#if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
#include "absl/memory/memory.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
class NcclCommunicator : public NcclCommunicatorInterface {
public:
void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
StatusCallback done) override;
void StartAbort(const Status& s) override;
private:
NcclManager nccl_manager_;
};
namespace {
Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
if (merge_op == "Add") {
*reduction_op = ncclSum;
return Status::OK();
} else if (merge_op == "Mul") {
*reduction_op = ncclProd;
return Status::OK();
} else if (merge_op == "Maximum") {
*reduction_op = ncclMax;
return Status::OK();
} else if (merge_op == "Minimum") {
*reduction_op = ncclMin;
return Status::OK();
} else {
return errors::Internal(
"Expected merge_op to be in [Add, Mul, Maximum, Minimum], found ",
merge_op);
}
}
string NcclCollectiveKey(const string& exec_key, int step_id) {
return strings::StrCat(exec_key, ":", step_id);
}
} // namespace
std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() {
return absl::make_unique<NcclCommunicator>();
}
void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
StatusCallback done) {
const CollectiveParams& col_params = col_ctx->col_params;
const int num_global_devices = col_params.group.group_size;
const int num_local_devices = col_params.instance.num_devices_per_task.at(
col_params.instance.task_names[col_params.default_rank]);
const string nccl_collective_key =
NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
auto participant = absl::make_unique<NcclManager::Participant>(
compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
col_ctx->output, col_ctx->col_params.default_rank, std::move(done));
NcclManager::Context context(
nccl_collective_key, num_local_devices, num_global_devices,
col_params.group.runtime_details.communicator_key,
col_params.source_rank);
VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type
<< " num_tasks " << col_params.group.num_tasks << " current task "
<< col_params.instance.task_names[col_params.default_rank]
<< " num local devices " << num_local_devices
<< " num global devices " << num_global_devices << " device "
<< col_ctx->device_name << " instance "
<< col_params.instance.instance_key;
// `AddTo*` performs consistency checks for the NCCL call and enqueues the
// `Participant` struct locally. When all local participants with this
// `nccl_collective_key` have called `AddToAllReduce` and
// `SignalMultiNodeReady`, all devices at this worker are ready to process
// this NCCL op.
//
// The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
// point, it synchronizes the NCCL stream with the compute stream, and then
// enqueues the NCCL kernel on the NCCL stream.
switch (col_params.instance.type) {
case REDUCTION_COLLECTIVE: {
ncclRedOp_t reduction_op;
Status s = ReductionOp(col_params.merge_op->type_string(), &reduction_op);
if (!s.ok()) {
participant->done_callback(s);
return;
}
nccl_manager_.AddToAllReduce(std::move(participant), context,
reduction_op);
break;
}
case GATHER_COLLECTIVE: {
nccl_manager_.AddToAllGather(std::move(participant), context);
break;
}
case BROADCAST_COLLECTIVE: {
if (col_params.is_source) {
nccl_manager_.AddBroadcastSend(std::move(participant), context);
} else {
nccl_manager_.AddBroadcastRecv(std::move(participant), context);
}
break;
}
default: {
participant->done_callback(errors::Internal("Unexpected CollectiveType ",
col_params.instance.type));
return;
}
}
// NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
// deadlocks. In the current implementation, we define a deterministic
// sequential launch order between potentially concurrent collective instances
// by introducing control information during static graph analysis in
// graph/collective_order.cc. This can be either in the form of explicit
// control edges or via `wait_for` attribute on the collective op.
//
// The other end of the design spectrum would have a distinguished node
// dynamically signal the next collective to launch to all other participants.
// This has higher degree of runtime coordination, but it may be able to
// achieve better performance if the (arbitrary) static execution order
// assigned in the first approach turns out to not be good from a scheduling
// perspective. e.g. consider a graph in which c1, c2, and c3 are three
// concurrent collective instances, and the static ordering assigns c1 -> c2
// -> c3. In practice, it could turn out that c3 is always ready to execute
// before c1 or c2.
{
// `WaitForDependencies` may block if the collective instances on which this
// op depends have not yet launched. When this function returns, this op is
// ready to go.
profiler::TraceMe activity("WaitForDependencies",
profiler::TraceMeLevel::kInfo);
col_ctx->col_exec->WaitForDependencies(col_params);
nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
}
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx->col_exec->UnblockDependencies(col_params);
}
}
void NcclCommunicator::StartAbort(const Status& s) {
CHECK(false) << "not implemented yet"; // Crash ok.
}
} // namespace tensorflow
#else
namespace tensorflow {
std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() {
return nullptr;
}
} // namespace tensorflow
#endif // TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_NCCL_COLLECTIVE_COMMUNICATOR_H_
#define TENSORFLOW_CORE_NCCL_COLLECTIVE_COMMUNICATOR_H_
#include "tensorflow/core/framework/collective.h"
namespace tensorflow {
// Creates a NcclCommunicator if built with NCCL support, otherwise it returns
// nullptr.
std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator();
} // namespace tensorflow
#endif // TENSORFLOW_CORE_NCCL_COLLECTIVE_COMMUNICATOR_H_

View File

@ -747,8 +747,8 @@ class Context(object):
This is intended to be used when a peer failure is detected, which allows
the user to handle the case instead of hanging. This aborts all on-going
collectives. After all subsequent collectives error immediately. The only
way to recovery now is to restart the program.
collectives. After all subsequent collectives error immediately, and you
need to reset_context() to use collectives again.
Args:
code: a `tf.errors` error code.