Make NcclManager part of CollectiveExecutorMgr
Collective NCCL will then use the NcclManager owned by CollectiveExecutorMgr; non-collective NCCL (nccl_ops) is not affected. We're adding the ability to abort a NcclManager, and after an abort the NcclManager cannot be used again. Having a NcclManager per CollectiveExecutorMgr allows us to "reset" it by resetting the EagerContext, which aligns with RING collectives. For now this is mostly useful in tests, since we don't expose the ability to reset the EagerContext in the public API.

This change adds an abstraction (NcclCommunicatorInterface) to bridge collectives and NcclManager. With this we don't need to link NcclManager into core_cpu, which has proven difficult and awkward.

PiperOrigin-RevId: 332315169
Change-Id: I936f6489b43c9de0b3f6a3cf5a14f89e79689d5e
parent 7cfe00cf3a
commit 57f8b71009
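Before the diff itself, here is a self-contained toy sketch of the ownership model this change introduces. The Status, CollectiveContext, and executor-manager types below are simplified stand-ins, not the real TensorFlow classes; only the shape of NcclCommunicatorInterface (Enqueue/StartAbort) and the idea that the executor manager owns the communicator mirror the actual diff.

// toy_sketch.cc -- illustrative only; types are simplified stand-ins.
#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct Status {
  std::string msg;
  bool ok() const { return msg.empty(); }
};
using StatusCallback = std::function<void(const Status&)>;
struct CollectiveContext {};  // The real type carries op context, params, tensors.

// The bridge between collectives and NcclManager introduced by this change.
class NcclCommunicatorInterface {
 public:
  virtual ~NcclCommunicatorInterface() = default;
  virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                       StatusCallback done) = 0;
  virtual void StartAbort(const Status& s) = 0;
};

// Toy communicator: once aborted it fails every subsequent collective, which
// mirrors the behavior described in the commit message.
class ToyNcclCommunicator : public NcclCommunicatorInterface {
 public:
  void Enqueue(std::shared_ptr<CollectiveContext>, StatusCallback done) override {
    done(aborted_ ? abort_status_ : Status{});
  }
  void StartAbort(const Status& s) override {
    aborted_ = true;
    abort_status_ = s;
  }
 private:
  bool aborted_ = false;
  Status abort_status_;
};

// The executor manager owns the communicator, so creating a new manager
// (e.g. by resetting the EagerContext) yields a fresh, un-aborted one.
class ToyCollectiveExecutorMgr {
 public:
  explicit ToyCollectiveExecutorMgr(std::unique_ptr<NcclCommunicatorInterface> c)
      : nccl_communicator_(std::move(c)) {}
  NcclCommunicatorInterface* GetNcclCommunicator() const {
    return nccl_communicator_.get();
  }
 private:
  std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
};

int main() {
  ToyCollectiveExecutorMgr mgr(std::make_unique<ToyNcclCommunicator>());
  mgr.GetNcclCommunicator()->StartAbort(Status{"peer failure"});
  mgr.GetNcclCommunicator()->Enqueue(std::make_shared<CollectiveContext>(),
                                     [](const Status& s) {
                                       std::cout << (s.ok() ? "ok" : s.msg) << "\n";
                                     });  // prints "peer failure"
}

Because the communicator is owned by the executor manager, constructing a fresh manager (which resetting the EagerContext does) discards the aborted communicator and yields a usable one again.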
@@ -231,10 +231,10 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
 TF_Status* status);
 
 // Aborts all ongoing collectives with the specified status. After abortion,
-// subsequent collectives will error with this status immediately.
+// subsequent collectives will error with this status immediately. To reset the
+// collectives, create a new EagerContext.
 //
-// This is intended to be used when a peer failure is detected. There's yet no
-// way to reset the collectives other than restarting the program.
+// This is intended to be used when a peer failure is detected.
 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 TF_Status* status);
 
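A plausible usage sketch of the C API documented above follows. TF_NewStatus, TF_SetStatus, and TF_DeleteStatus are existing TensorFlow C API calls, but the wrapper function, the choice of TF_UNAVAILABLE, and the assumption that `ctx` already has collective ops enabled are illustrative, not part of this change.

// Sketch: abort all collectives on an existing eager context after a peer
// failure is detected. `ctx` is assumed to be a TFE_Context on which
// TFE_EnableCollectiveOps has already been called.
void AbortCollectivesOnPeerFailure(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  // Illustrative choice of error code; any TF error code can be used.
  TF_SetStatus(status, TF_UNAVAILABLE, "peer failure detected");
  // Ongoing collectives fail with this status; later ones fail immediately.
  TFE_AbortCollectiveOps(ctx, status);
  TF_DeleteStatus(status);
  // Per the updated comment, create a new EagerContext to run collectives again.
}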
@@ -1164,7 +1164,7 @@ cc_library(
 )
 
 # Test support library needed for higher-level (TensorFlow-specific) tests
-cc_library(
+tf_cuda_library(
 name = "testlib",
 testonly = 1,
 srcs = [
@@ -1295,6 +1295,7 @@ filegroup(
 "//tensorflow/core/graph:mobile_srcs_only_runtime",
 "//tensorflow/core/kernels:mobile_srcs",
 "//tensorflow/core/lib/io:mobile_srcs_only_runtime",
+"//tensorflow/core/nccl:mobile_srcs",
 "//tensorflow/core/profiler:mobile_srcs",
 "//tensorflow/core/public:mobile_srcs_only_runtime",
 "//tensorflow/core/util/sparse:mobile_srcs_only_runtime",
@@ -1739,6 +1739,7 @@ tf_cuda_library(
 "//tensorflow/core:protos_all_cc",
 "//tensorflow/core/debug:debug_graph_utils",
 "//tensorflow/core/kernels:function_ops",
+"//tensorflow/core/nccl:collective_communicator",
 "//tensorflow/core/profiler/lib:connected_traceme",
 "//tensorflow/core/profiler/lib:profiler_backends",
 "//tensorflow/core/profiler/lib:profiler_session",
@@ -1868,6 +1869,7 @@ tf_cc_tests(
 "//tensorflow/core:test_main",
 "//tensorflow/core:testlib",
 "//tensorflow/core/kernels:ops_util",
+"//tensorflow/core/nccl:collective_communicator",
 "//tensorflow/core/platform:regexp",
 "//tensorflow/core/util:protos_test_cc",
 "//third_party/eigen3",
@@ -270,8 +270,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
 }
 core::ScopedUnref unref(col_impl);
 auto col_ctx = std::make_shared<CollectiveContext>(
-this, dev_mgr_, ctx, CtxParams(ctx), col_params, exec_key, step_id_,
-input, output);
+this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
+col_params, exec_key, step_id_, input, output);
 status = col_impl->InitializeCollectiveContext(col_ctx);
 if (!status.ok()) {
 done_safe(status);
@@ -26,12 +26,14 @@ namespace tensorflow {
 CollectiveExecutorMgr::CollectiveExecutorMgr(
 const ConfigProto& config, const DeviceMgr* dev_mgr,
 std::unique_ptr<DeviceResolverInterface> dev_resolver,
-std::unique_ptr<ParamResolverInterface> param_resolver)
+std::unique_ptr<ParamResolverInterface> param_resolver,
+std::unique_ptr<NcclCommunicatorInterface> nccl_communicator)
 : dev_mgr_(dev_mgr),
 dev_resolver_(std::move(dev_resolver)),
 param_resolver_(std::move(param_resolver)),
 gpu_ring_order_(
 config.gpu_options().experimental().collective_ring_order()),
+nccl_communicator_(std::move(nccl_communicator)),
 work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
 "collective_ops")) {}
 
@@ -22,12 +22,15 @@ limitations under the License.
 namespace tensorflow {
 class ConfigProto;
 class DeviceMgr;
+class NcclManager;
 
 class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
 public:
-CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
-std::unique_ptr<DeviceResolverInterface> dev_resolver,
-std::unique_ptr<ParamResolverInterface> param_resolver);
+CollectiveExecutorMgr(
+const ConfigProto& config, const DeviceMgr* dev_mgr,
+std::unique_ptr<DeviceResolverInterface> dev_resolver,
+std::unique_ptr<ParamResolverInterface> param_resolver,
+std::unique_ptr<NcclCommunicatorInterface> nccl_communicator);
 
 virtual ~CollectiveExecutorMgr();
 
@@ -43,6 +46,10 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
 return dev_resolver_.get();
 }
 
+NcclCommunicatorInterface* GetNcclCommunicator() const override {
+return nccl_communicator_.get();
+}
+
 void GetStepSequenceAsync(const GetStepSequenceRequest* request,
 GetStepSequenceResponse* response,
 const StatusCallback& done) override;
@@ -64,6 +71,7 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
 std::unique_ptr<DeviceResolverInterface> dev_resolver_;
 std::unique_ptr<ParamResolverInterface> param_resolver_;
 string gpu_ring_order_;
+std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
 // Unbounded work queue for scheduling potentially-blocking work during
 // collective op execution. Ownership is shared between `this` and
 // `CollectiveRemoteAccessLocal`.
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/nccl/collective_communicator.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session_options.h"
 
@@ -47,7 +48,8 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
 new CollectiveParamResolverLocal(cp, device_mgr_.get(), drl.get(),
 task_name));
 cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
-std::move(prl)));
+std::move(prl),
+MaybeCreateNcclCommunicator()));
 }
 
 std::unique_ptr<CollectiveExecutorMgr> cme_;
@@ -65,6 +65,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/nccl/collective_communicator.h"
 #include "tensorflow/core/platform/byte_order.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/logging.h"
@@ -554,7 +555,8 @@ Status DirectSession::RunInternal(
 drl.get(),
 "/job:localhost/replica:0/task:0"));
 collective_executor_mgr_.reset(new CollectiveExecutorMgr(
-options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
+options_.config, device_mgr_.get(), std::move(drl), std::move(cprl),
+MaybeCreateNcclCommunicator()));
 }
 run_state.collective_executor.reset(new CollectiveExecutor::Handle(
 collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
@@ -82,6 +82,7 @@ tf_cuda_library(
 "//tensorflow/c/eager:immediate_execution_operation",
 "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
 "//tensorflow/core/distributed_runtime:worker_env",
+"//tensorflow/core/nccl:collective_communicator",
 ] + select({
 "//tensorflow:android": [
 "//tensorflow/core:portable_tensorflow_lib_lite",
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/nccl/collective_communicator.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
@@ -120,7 +121,7 @@ EagerContext::EagerContext(
 "/job:localhost/replica:0/task:0"));
 collective_executor_mgr_.Reset(
 new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
-std::move(cprl)),
+std::move(cprl), MaybeCreateNcclCommunicator()),
 /*owned=*/true);
 }
 
@@ -674,8 +674,9 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 new HierarchicalTreeBroadcaster;
 core::ScopedUnref unref(broadcaster);
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
-col_params_, exec_key, kStepId, input_tensor_ptr, output_tensor_ptr);
+parent_->col_exec_, /*nccl_communicator*/ nullptr,
+parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
+kStepId, input_tensor_ptr, output_tensor_ptr);
 TF_CHECK_OK(broadcaster->InitializeCollectiveContext(col_ctx));
 
 // Run the broadcast.
@@ -388,8 +388,9 @@ class PermuterTest : public ::testing::Test {
 Permuter* permuter = new Permuter;
 core::ScopedUnref unref(permuter);
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
-col_params_, exec_key, kStepId, &tensor_input_, &tensor_output_);
+parent_->col_exec_, /*nccl_communicator*/ nullptr,
+parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
+kStepId, &tensor_input_, &tensor_output_);
 TF_CHECK_OK(permuter->InitializeCollectiveContext(col_ctx));
 Notification note;
 // Run the permute.
@@ -480,8 +480,9 @@ class RingGathererTest : public ::testing::Test {
 RingGatherer* gatherer = new RingGatherer;
 core::ScopedUnref unref(gatherer);
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
-col_params_, exec_key, kStepId, &input_tensor_, output_tensor_ptr);
+parent_->col_exec_, /*nccl_communicator*/ nullptr,
+parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
+kStepId, &input_tensor_, output_tensor_ptr);
 TF_CHECK_OK(gatherer->InitializeCollectiveContext(col_ctx));
 
 // Run the all-gather.
@@ -510,8 +510,9 @@ class RingReducerTest : public ::testing::Test {
 RingReducer* reducer = new RingReducer;
 core::ScopedUnref unref(reducer);
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(), &ctx, &op_params,
-col_params_, exec_key, kStepId, &tensor_, &tensor_);
+parent_->col_exec_, /*nccl_communicator*/ nullptr,
+parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
+kStepId, &tensor_, &tensor_);
 TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
 
 // Run the all-reduce.
@@ -101,6 +101,11 @@ class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
 return nullptr;
 }
 
+NcclCommunicatorInterface* GetNcclCommunicator() const override {
+LOG(FATAL) << "Unimplemented";  // Crash OK
+return nullptr;
+}
+
 void GetStepSequenceAsync(const GetStepSequenceRequest* request,
 GetStepSequenceResponse* response,
 const StatusCallback& done) override {
@@ -522,6 +522,7 @@ tf_cc_test(
 "//tensorflow/core:test",
 "//tensorflow/core:test_main",
 "//tensorflow/core:worker_proto_cc",
+"//tensorflow/core/nccl:collective_communicator",
 ],
 )
 
@@ -316,6 +316,7 @@ cc_library(
 ":grpc_worker_cache",
 ":grpc_worker_service",
 ":rpc_rendezvous_mgr",
+"//tensorflow/core/nccl:collective_communicator",
 "//tensorflow/core:core_cpu",
 "//tensorflow/core:core_cpu_internal",
 "//tensorflow/core:framework",
@@ -47,6 +47,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/nccl/collective_communicator.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mem.h"
@@ -286,7 +287,8 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
 default_worker_name));
 worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
 config, worker_env_.device_mgr, std::move(dev_resolver),
-std::move(param_resolver), worker_cache, default_worker_name));
+std::move(param_resolver), MaybeCreateNcclCommunicator(), worker_cache,
+default_worker_name));
 }
 
 // Set up worker environment.
@@ -454,8 +456,8 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
 dev_resolver.get(), worker_cache, default_worker_name));
 worker_env_.collective_executor_mgr.reset(new RpcCollectiveExecutorMgr(
 server_def_.default_session_config(), worker_env_.device_mgr,
-std::move(dev_resolver), std::move(param_resolver), worker_cache,
-default_worker_name));
+std::move(dev_resolver), std::move(param_resolver),
+MaybeCreateNcclCommunicator(), worker_cache, default_worker_name));
 
 master_env_.worker_cache = worker_cache;
 master_env_.collective_executor_mgr =
@@ -29,9 +29,11 @@ RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
 const ConfigProto& config, const DeviceMgr* dev_mgr,
 std::unique_ptr<DeviceResolverDistributed> dev_resolver,
 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
 WorkerCacheInterface* worker_cache, const string& task_name)
 : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
-std::move(param_resolver)),
+std::move(param_resolver),
+std::move(nccl_communicator)),
 worker_cache_(worker_cache),
 task_name_(task_name) {
 group_leader_ = (task_name == config.experimental().collective_group_leader())
@@ -38,6 +38,7 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
 const ConfigProto& config, const DeviceMgr* dev_mgr,
 std::unique_ptr<DeviceResolverDistributed> dev_resolver,
 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
 WorkerCacheInterface* worker_cache, const string& task_name);
 
 virtual ~RpcCollectiveExecutorMgr();
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+
 #include <stdlib.h>
+
 #include <string>
 #include <vector>
 
@@ -21,10 +24,10 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
-#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/nccl/collective_communicator.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
 #include "tensorflow/core/public/session_options.h"
@@ -52,9 +55,9 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
 device_mgr_.get(), dr.get(),
 worker_cache, task_name));
 // This CME is the group leader.
-cme_.reset(new RpcCollectiveExecutorMgr(options.config, device_mgr_.get(),
-std::move(dr), std::move(cpr),
-worker_cache, task_name));
+cme_.reset(new RpcCollectiveExecutorMgr(
+options.config, device_mgr_.get(), std::move(dr), std::move(cpr),
+MaybeCreateNcclCommunicator(), worker_cache, task_name));
 }
 
 std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
@@ -158,14 +158,13 @@ string CollectiveParams::ToString() const {
 return ctx->params_;
 }
 
-CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
-const DeviceMgr* dev_mgr,
-OpKernelContext* ctx,
-OpKernelContext::Params* op_params,
-const CollectiveParams& col_params,
-const string& exec_key, int64 step_id,
-const Tensor* input, Tensor* output)
+CollectiveContext::CollectiveContext(
+CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
+const DeviceMgr* dev_mgr, OpKernelContext* ctx,
+OpKernelContext::Params* op_params, const CollectiveParams& col_params,
+const string& exec_key, int64 step_id, const Tensor* input, Tensor* output)
 : col_exec(col_exec),
+nccl_communicator(nccl_communicator),
 dev_mgr(dev_mgr),
 op_ctx(ctx),
 op_params(op_params),
@@ -36,6 +36,7 @@ class Device;
 class DeviceMgr;
 class GetStepSequenceRequest;
 class GetStepSequenceResponse;
+class NcclManager;
 class Tensor;
 
 // Types of supported collective operations.
@@ -229,6 +230,8 @@ class StepSequenceInterface {
 virtual void RetireStepId(int64 graph_key, int64 step_id) = 0;
 };
 
+class NcclCommunicatorInterface;
+
 // Interface that provides access to per-step CollectiveExecutor
 // instances and various distributed resolution capabilities.
 class CollectiveExecutorMgrInterface : public StepSequenceInterface {
@@ -246,6 +249,8 @@ class CollectiveExecutorMgrInterface : public StepSequenceInterface {
 virtual ParamResolverInterface* GetParamResolver() const = 0;
 
 virtual DeviceResolverInterface* GetDeviceResolver() const = 0;
+
+virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
 };
 
 // Interface that a Collective Op implementation uses to exchange data
@@ -354,19 +359,12 @@ class CollectiveExecutor : public core::RefCounted {
 TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
 };
 
-class CollectiveContext {
-public:
-CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
-OpKernelContext* ctx, OpKernelContext::Params* op_params,
-const CollectiveParams& col_params, const string& exec_key,
-int64 step_id, const Tensor* input, Tensor* output);
-
-virtual ~CollectiveContext() = default;
-
-CollectiveExecutor* col_exec; // Not owned
-const DeviceMgr* dev_mgr; // Not owned
-OpKernelContext* op_ctx; // Not owned
-OpKernelContext::Params* op_params; // Not owned
+struct CollectiveContext {
+CollectiveExecutor* col_exec; // Not owned
+NcclCommunicatorInterface* nccl_communicator; // Not owned
+const DeviceMgr* dev_mgr; // Not owned
+OpKernelContext* op_ctx; // Not owned
+OpKernelContext::Params* op_params; // Not owned
 const CollectiveParams& col_params;
 const string exec_key;
 const int64 step_id;
@@ -375,6 +373,23 @@ class CollectiveContext {
 Device* device; // The device for which this instance labors
 const string device_name;
 DeviceLocality device_locality;
+
+CollectiveContext(CollectiveExecutor* col_exec,
+NcclCommunicatorInterface* nccl_communicator,
+const DeviceMgr* dev_mgr, OpKernelContext* ctx,
+OpKernelContext::Params* op_params,
+const CollectiveParams& col_params, const string& exec_key,
+int64 step_id, const Tensor* input, Tensor* output);
 };
 
+class NcclCommunicatorInterface {
+public:
+virtual ~NcclCommunicatorInterface() = default;
+
+virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
+StatusCallback done) = 0;
+
+virtual void StartAbort(const Status& s) = 0;
+};
+
 // Interface of a Collective Op implementation. Each specific CollectiveOp will
@@ -200,17 +200,6 @@ tf_cc_test(
 ],
 )
 
-# virtual targets since nested select statements not possible
-tf_kernel_library(
-name = "virtual_nccl",
-deps = if_cuda(["@local_config_nccl//:nccl"]),
-)
-
-tf_kernel_library(
-name = "virtual_rccl",
-deps = if_rocm(["@local_config_rocm//rocm:rccl"]),
-)
-
 tf_kernel_library(
 name = "collective_ops",
 srcs = if_nccl([
@@ -228,11 +217,10 @@ tf_kernel_library(
 "//tensorflow/core:framework",
 "//tensorflow/core:lib",
 "//tensorflow/core:protos_all_cc",
+"//tensorflow/core:core_cpu",
 "//tensorflow/core/profiler/lib:traceme",
 ] + if_nccl([
-":virtual_nccl",
-":virtual_rccl",
-"//tensorflow/core/nccl:nccl_lib",
+"//tensorflow/core/nccl:collective_communicator",
 ]),
 )
 
@@ -255,6 +243,7 @@ tf_cuda_cc_test(
 "//tensorflow/core:test",
 "//tensorflow/core:test_main",
 "//tensorflow/core:testlib",
+"//tensorflow/core/nccl:collective_communicator",
 ],
 )
 
@@ -74,10 +74,6 @@ Status NcclBase::InitializeCollectiveGroupRuntimeDetails(
 return Status::OK();
 }
 
-const string NcclBase::NcclCollectiveKey(const string& exec_key, int step_id) {
-return strings::StrCat(exec_key, ":", step_id);
-}
-
 } // namespace tensorflow
 
 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -37,8 +37,6 @@ class NcclBase : public CollectiveImplementationInterface {
 CollGroupRuntimeDetails* col_group_runtime_details) override;
 
 protected:
-const string NcclCollectiveKey(const string& exec_key, int step_id);
-
 const CollectiveType type_;
 const string name_;
 std::shared_ptr<CollectiveContext> col_ctx_;
@@ -24,55 +24,7 @@ limitations under the License.
 namespace tensorflow {
 
 void NcclBroadcaster::Run(StatusCallback done) {
-auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
-auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
-const int num_global_devices = col_params_->group.group_size;
-const int num_local_devices = col_params_->instance.num_devices_per_task.at(
-col_params_->instance.task_names[col_params_->default_rank]);
-string nccl_collective_key =
-NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
-auto participant = absl::make_unique<NcclManager::Participant>(
-compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
-col_ctx_->output, col_params_->default_rank, std::move(done));
-VLOG(1)
-<< "NcclBroadcast calling NcclManager::AddBroadcastSend/Recv num_tasks "
-<< col_params_->group.num_tasks << " current task "
-<< col_params_->instance.task_names[col_params_->default_rank]
-<< " num local devices " << num_local_devices << " num global devices "
-<< num_global_devices << " rank " << col_params_->default_rank
-<< " device " << col_ctx_->device_name << " instance "
-<< col_params_->instance.instance_key << " source "
-<< col_params_->is_source;
-if (col_params_->is_source) {
-NcclManager::instance()->AddBroadcastSend(
-std::move(participant),
-{std::move(nccl_collective_key), num_local_devices, num_global_devices,
-col_params_->group.runtime_details.communicator_key,
-col_params_->source_rank});
-} else {
-NcclManager::instance()->AddBroadcastRecv(
-std::move(participant),
-{std::move(nccl_collective_key), num_local_devices, num_global_devices,
-col_params_->group.runtime_details.communicator_key,
-col_params_->source_rank});
-}
-{
-// `WaitForDependencies` may block if the collective instances on which this
-// op depends have not yet launched. When this function returns, this op is
-// ready to go.
-profiler::TraceMe activity("WaitForDependencies",
-profiler::TraceMeLevel::kInfo);
-col_ctx_->col_exec->WaitForDependencies(*col_params_);
-NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
-}
-{
-// When all devices at this worker have called `SignalMultiNodeReady`, the
-// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
-// implementation of `UnblockDependencies` keeps track of the number of
-// devices that have launched.
-profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
-col_ctx_->col_exec->UnblockDependencies(*col_params_);
-}
+col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done));
 }
 
 REGISTER_COLLECTIVE(NcclBroadcast, NcclBroadcaster);
@@ -24,45 +24,7 @@ limitations under the License.
 namespace tensorflow {
 
 void NcclGatherer::Run(StatusCallback done) {
-auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
-auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
-const int num_global_devices = col_params_->group.group_size;
-const int num_local_devices = col_params_->instance.num_devices_per_task.at(
-col_params_->instance.task_names[col_params_->default_rank]);
-string nccl_collective_key =
-NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
-auto participant = absl::make_unique<NcclManager::Participant>(
-compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
-col_ctx_->output, col_params_->default_rank, std::move(done));
-VLOG(1) << "NcclGatherer calling NcclManager::AddToAllGather num_tasks "
-<< col_params_->group.num_tasks << " current task "
-<< col_params_->instance.task_names[col_params_->default_rank]
-<< " num local devices " << num_local_devices
-<< " num global devices " << num_global_devices << " rank "
-<< col_params_->default_rank << " device " << col_ctx_->device_name
-<< " instance " << col_params_->instance.instance_key;
-NcclManager::instance()->AddToAllGather(
-std::move(participant),
-{std::move(nccl_collective_key), num_local_devices, num_global_devices,
-col_params_->group.runtime_details.communicator_key,
-/*source_rank=*/-1});
-{
-// `WaitForDependencies` may block if the collective instances on which this
-// op depends have not yet launched. When this function returns, this op is
-// ready to go.
-profiler::TraceMe activity("WaitForDependencies",
-profiler::TraceMeLevel::kInfo);
-col_ctx_->col_exec->WaitForDependencies(*col_params_);
-NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
-}
-{
-// When all devices at this worker have called `SignalMultiNodeReady`, the
-// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
-// implementation of `UnblockDependencies` keeps track of the number of
-// devices that have launched.
-profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
-col_ctx_->col_exec->UnblockDependencies(*col_params_);
-}
+col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done));
 }
 
 REGISTER_COLLECTIVE(NcclGather, NcclGatherer);
@@ -23,36 +23,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-namespace {
-Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
-if (merge_op == "Add") {
-*reduction_op = ncclSum;
-return Status::OK();
-} else if (merge_op == "Mul") {
-*reduction_op = ncclProd;
-return Status::OK();
-} else if (merge_op == "Maximum") {
-*reduction_op = ncclMax;
-return Status::OK();
-} else if (merge_op == "Minimum") {
-*reduction_op = ncclMin;
-return Status::OK();
-} else {
-return errors::Internal(
-"Expected merge_op to be in [Add, Mul, Maximum, Minimum], found ",
-merge_op);
-}
-}
-} // namespace
-
 void NcclReducer::Run(StatusCallback done) {
-ncclRedOp_t reduction_op;
-Status s = ReductionOp(col_params_->merge_op->type_string(), &reduction_op);
-if (!s.ok()) {
-done(s);
-return;
-}
-
 Tensor group_size;
 std::unique_ptr<Notification> group_size_ready;
 Status group_size_status;
@@ -117,73 +88,7 @@ void NcclReducer::Run(StatusCallback done) {
 } else {
 done_callback = std::move(done);
 }
-auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
-auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
-// `AddToAllReduce` performs consistency checks for the NCCL call and enqueues
-// the `Participant` struct locally. When all local participants with this
-// `nccl_collective_key` have called `AddToAllReduce` and
-// `SignalMultiNodeReady`, all devices at this worker are ready to process
-// this NCCL op.
-//
-// The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
-// point, it synchronizes the NCCL stream with the compute stream, and then
-// enqueues the NCCL kernel on the NCCL stream.
-const int num_global_devices = col_params_->group.group_size;
-const int num_local_devices = col_params_->instance.num_devices_per_task.at(
-col_params_->instance.task_names[col_params_->default_rank]);
-const string nccl_collective_key =
-NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
-auto participant = absl::make_unique<NcclManager::Participant>(
-compute_stream->parent(), compute_stream, gpu_info, col_ctx_->input,
-col_ctx_->output, col_params_->default_rank, std::move(done_callback));
-VLOG(1) << "NcclReducer calling NcclManager::AddToAllReduce num_tasks "
-<< col_params_->group.num_tasks << " current task "
-<< col_params_->instance.task_names[col_params_->default_rank]
-<< " num local devices " << num_local_devices
-<< " num global devices " << num_global_devices << " device "
-<< col_ctx_->device_name << " instance "
-<< col_params_->instance.instance_key;
-NcclManager::instance()->AddToAllReduce(
-std::move(participant),
-{nccl_collective_key, num_local_devices, num_global_devices,
-col_params_->group.runtime_details.communicator_key, /*source_rank=*/-1},
-reduction_op);
-
-// NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
-// deadlocks. In the current implementation, we define a deterministic
-// sequential launch order between potentially concurrent collective instances
-// by introducing control information during static graph analysis in
-// graph/collective_order.cc. This can be either in the form of explicit
-// control edges or via `wait_for` attribute on the collective op.
-//
-// The other end of the design spectrum would have a distinguished node
-// dynamically signal the next collective to launch to all other participants.
-// This has higher degree of runtime coordination, but it may be able to
-// achieve better performance if the (arbitrary) static execution order
-// assigned in the first approach turns out to not be good from a scheduling
-// perspective. e.g. consider a graph in which c1, c2, and c3 are three
-// concurrent collective instances, and the static ordering assigns c1 -> c2
-// -> c3. In practice, it could turn out that c3 is always ready to execute
-// before c1 or c2.
-{
-// `WaitForDependencies` may block if the collective instances on which this
-// op depends have not yet launched. When this function returns, this op is
-// ready to go.
-profiler::TraceMe activity("WaitForDependencies",
-profiler::TraceMeLevel::kInfo);
-// TODO(b/80529858): make this entirely non-blocking by converting
-// `WaitForDependencies` to async function.
-col_ctx_->col_exec->WaitForDependencies(*col_params_);
-NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
-}
-{
-// When all devices at this worker have called `SignalMultiNodeReady`, the
-// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
-// implementation of `UnblockDependencies` keeps track of the number of
-// devices that have launched.
-profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
-col_ctx_->col_exec->UnblockDependencies(*col_params_);
-}
+col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done_callback));
 
 // If no final_op, then this OpKernel is non-blocking.
 if (!col_params_->final_op) {
@@ -39,6 +39,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/nccl/collective_communicator.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/unbounded_work_queue.h"
@@ -84,6 +85,7 @@ class NcclTestBase : public ::testing::Test {
 NcclTestBase(CollectiveType collective_type, const string& collective_name)
 : collective_type_(collective_type),
 collective_name_(collective_name),
+nccl_communicator_(MaybeCreateNcclCommunicator()),
 work_queue_(std::make_shared<UnboundedWorkQueue>(
 Env::Default(), "collective_executor")),
 col_exec_(nullptr) {}
@@ -318,7 +320,8 @@ class NcclTestBase : public ::testing::Test {
 strings::StrCat(col_params_.instance.instance_key, ":0:0");
 auto* reducer = new NcclReducer();
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(),
+parent_->col_exec_, parent_->nccl_communicator_.get(),
+parent_->dev_mgr_.get(),
 /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
 /*input=*/&input_, /*output=*/&input_);
 TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));
@@ -349,7 +352,8 @@ class NcclTestBase : public ::testing::Test {
 strings::StrCat(col_params_.instance.instance_key, ":0:0");
 auto* broadcaster = new NcclBroadcaster();
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(),
+parent_->col_exec_, parent_->nccl_communicator_.get(),
+parent_->dev_mgr_.get(),
 /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
 /*input=*/col_params_.is_source ? &input_ : nullptr,
 /*output=*/&input_);
@@ -389,7 +393,8 @@ class NcclTestBase : public ::testing::Test {
 strings::StrCat(col_params_.instance.instance_key, ":0:0");
 auto* gatherer = new NcclGatherer();
 auto col_ctx = std::make_shared<CollectiveContext>(
-parent_->col_exec_, parent_->dev_mgr_.get(),
+parent_->col_exec_, parent_->nccl_communicator_.get(),
+parent_->dev_mgr_.get(),
 /*OpKernelContext=*/&ctx, &op_params, col_params_, exec_key, kStepId,
 /*input=*/&input_,
 /*output=*/&output_);
@@ -419,6 +424,7 @@ class NcclTestBase : public ::testing::Test {
 const string collective_name_;
 std::vector<std::unique_ptr<tensorflow::Device>> gpus_;
 TestCollectiveExecutorMgr col_exec_mgr_;
+std::unique_ptr<NcclCommunicatorInterface> nccl_communicator_;
 std::shared_ptr<UnboundedWorkQueue> work_queue_;
 CollectiveExecutor* col_exec_;
 std::unique_ptr<DeviceMgr> dev_mgr_;
@@ -3,15 +3,17 @@
 # APIs are meant to change over time.
 
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
-load("//tensorflow:tensorflow.bzl", "tf_copts")
+load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm", "tf_copts")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
-load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm")
 load(
 "//tensorflow/core/platform:build_config_root.bzl",
 "tf_cuda_tests_tags",
 )
 
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "if_nccl")
+
 package(
 default_visibility = ["//tensorflow:__subpackages__"],
 licenses = ["notice"],  # Apache 2.0
@@ -77,3 +79,28 @@ tf_cuda_cc_test(
 "//tensorflow/core/common_runtime/gpu:rocm",
 ]),
 )
+
+cc_library(
+name = "collective_communicator",
+srcs = ["collective_communicator.cc"],
+hdrs = ["collective_communicator.h"],
+copts = tf_copts() + if_nccl(["-DTENSORFLOW_USE_NCCL=1"]),
+visibility = [
+"//learning/brain/runtime:__subpackages__",
+"//tensorflow:__subpackages__",
+],
+deps =
+["//tensorflow/core:framework"] + if_nccl([
+":nccl_lib",
+"@com_google_absl//absl/memory",
+"//tensorflow/core/profiler/lib:traceme",
+]),
+)
+
+filegroup(
+name = "mobile_srcs",
+srcs = [
+"collective_communicator.cc",
+"collective_communicator.h",
+],
+)
tensorflow/core/nccl/collective_communicator.cc (new file, 178 lines)
@@ -0,0 +1,178 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/nccl/collective_communicator.h"
+
+#if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/nccl/nccl_manager.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+
+class NcclCommunicator : public NcclCommunicatorInterface {
+public:
+void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
+StatusCallback done) override;
+
+void StartAbort(const Status& s) override;
+
+private:
+NcclManager nccl_manager_;
+};
+
+namespace {
+Status ReductionOp(const string& merge_op, ncclRedOp_t* reduction_op) {
+if (merge_op == "Add") {
+*reduction_op = ncclSum;
+return Status::OK();
+} else if (merge_op == "Mul") {
+*reduction_op = ncclProd;
+return Status::OK();
+} else if (merge_op == "Maximum") {
+*reduction_op = ncclMax;
+return Status::OK();
+} else if (merge_op == "Minimum") {
+*reduction_op = ncclMin;
+return Status::OK();
+} else {
+return errors::Internal(
+"Expected merge_op to be in [Add, Mul, Maximum, Minimum], found ",
+merge_op);
+}
+}
+
+string NcclCollectiveKey(const string& exec_key, int step_id) {
+return strings::StrCat(exec_key, ":", step_id);
+}
+} // namespace
+
+std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() {
+return absl::make_unique<NcclCommunicator>();
+}
+
+void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
+StatusCallback done) {
+const CollectiveParams& col_params = col_ctx->col_params;
+const int num_global_devices = col_params.group.group_size;
+const int num_local_devices = col_params.instance.num_devices_per_task.at(
+col_params.instance.task_names[col_params.default_rank]);
+const string nccl_collective_key =
+NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
+auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
+auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
+auto participant = absl::make_unique<NcclManager::Participant>(
+compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
+col_ctx->output, col_ctx->col_params.default_rank, std::move(done));
+NcclManager::Context context(
+nccl_collective_key, num_local_devices, num_global_devices,
+col_params.group.runtime_details.communicator_key,
+col_params.source_rank);
+VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type
+<< " num_tasks " << col_params.group.num_tasks << " current task "
+<< col_params.instance.task_names[col_params.default_rank]
+<< " num local devices " << num_local_devices
+<< " num global devices " << num_global_devices << " device "
+<< col_ctx->device_name << " instance "
+<< col_params.instance.instance_key;
+// `AddTo*` performs consistency checks for the NCCL call and enqueues the
+// `Participant` struct locally. When all local participants with this
+// `nccl_collective_key` have called `AddToAllReduce` and
+// `SignalMultiNodeReady`, all devices at this worker are ready to process
+// this NCCL op.
+//
+// The `NcclManager` uses a dedicated CUDA stream for NCCL kernels. At this
+// point, it synchronizes the NCCL stream with the compute stream, and then
+// enqueues the NCCL kernel on the NCCL stream.
+switch (col_params.instance.type) {
+case REDUCTION_COLLECTIVE: {
+ncclRedOp_t reduction_op;
+Status s = ReductionOp(col_params.merge_op->type_string(), &reduction_op);
+if (!s.ok()) {
+participant->done_callback(s);
+return;
+}
+nccl_manager_.AddToAllReduce(std::move(participant), context,
+reduction_op);
+break;
+}
+case GATHER_COLLECTIVE: {
+nccl_manager_.AddToAllGather(std::move(participant), context);
+break;
+}
+case BROADCAST_COLLECTIVE: {
+if (col_params.is_source) {
+nccl_manager_.AddBroadcastSend(std::move(participant), context);
+} else {
+nccl_manager_.AddBroadcastRecv(std::move(participant), context);
+}
+break;
+}
+default: {
+participant->done_callback(errors::Internal("Unexpected CollectiveType ",
+col_params.instance.type));
+return;
+}
+}
+// NOTE(ayushd): We need to synchronize NCCL launches across nodes to prevent
+// deadlocks. In the current implementation, we define a deterministic
+// sequential launch order between potentially concurrent collective instances
+// by introducing control information during static graph analysis in
+// graph/collective_order.cc. This can be either in the form of explicit
+// control edges or via `wait_for` attribute on the collective op.
+//
+// The other end of the design spectrum would have a distinguished node
+// dynamically signal the next collective to launch to all other participants.
+// This has higher degree of runtime coordination, but it may be able to
+// achieve better performance if the (arbitrary) static execution order
+// assigned in the first approach turns out to not be good from a scheduling
+// perspective. e.g. consider a graph in which c1, c2, and c3 are three
+// concurrent collective instances, and the static ordering assigns c1 -> c2
+// -> c3. In practice, it could turn out that c3 is always ready to execute
+// before c1 or c2.
+{
+// `WaitForDependencies` may block if the collective instances on which this
+// op depends have not yet launched. When this function returns, this op is
+// ready to go.
+profiler::TraceMe activity("WaitForDependencies",
+profiler::TraceMeLevel::kInfo);
+col_ctx->col_exec->WaitForDependencies(col_params);
+nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
+}
+{
+// When all devices at this worker have called `SignalMultiNodeReady`, the
+// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
+// implementation of `UnblockDependencies` keeps track of the number of
+// devices that have launched.
+profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
+col_ctx->col_exec->UnblockDependencies(col_params);
+}
+}
+
+void NcclCommunicator::StartAbort(const Status& s) {
+CHECK(false) << "not implemented yet";  // Crash ok.
+}
+
+} // namespace tensorflow
+
+#else
+namespace tensorflow {
+std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() {
+return nullptr;
+}
+} // namespace tensorflow
+#endif // TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
tensorflow/core/nccl/collective_communicator.h (new file, 28 lines)
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_NCCL_COLECTIVE_COMMUNICATOR_H_
+#define TENSORFLOW_CORE_NCCL_COLECTIVE_COMMUNICATOR_H_
+
+#include "tensorflow/core/framework/collective.h"
+
+namespace tensorflow {
+
+// Creates a NcclCommunicator if built with NCCL support, otherwise it returns
+// nullptr.
+std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_NCCL_COLECTIVE_COMMUNICATOR_H_
@@ -747,8 +747,8 @@ class Context(object):
 
 This is intended to be used when a peer failure is detected, which allows
 the user to handle the case instead of hanging. This aborts all on-going
-collectives. After all subsequent collectives error immediately. The only
-way to recovery now is to restart the program.
+collectives. After all subsequent collectives error immediately, and you
+need to reset_context() to use collectives again.
 
 Args:
 code: a `tf.errors` error code.