Merge pull request #24864 from byronyi:fix-18232
PiperOrigin-RevId: 230809690
Commit: 6e6706d45d
@@ -100,12 +100,33 @@ cc_library(
    ],
)

cc_library(
    name = "gdr_collective_executor_mgr",
    srcs = ["gdr_collective_executor_mgr.cc"],
    hdrs = ["gdr_collective_executor_mgr.h"],
    deps = [
        ":gdr_memory_manager",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/distributed_runtime:cancellable_call",
        "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
        "//tensorflow/core/distributed_runtime:device_resolver_distributed",
        "//tensorflow/core/distributed_runtime:request_id",
        "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
        "//tensorflow/core/distributed_runtime:worker_cache",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/distributed_runtime:worker_interface",
    ],
)

cc_library(
    name = "gdr_server_lib",
    srcs = ["gdr_server_lib.cc"],
    hdrs = ["gdr_server_lib.h"],
    linkstatic = 1,  # Seems to be needed since alwayslink is broken in bazel
    deps = [
        ":gdr_collective_executor_mgr",
        ":gdr_memory_manager",
        ":gdr_rendezvous_mgr",
        ":gdr_worker",
tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc (new file, 160 lines)
@@ -0,0 +1,160 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

class WorkerCacheInterface;

namespace {

class RecvBufCall : public CancellableCall {
 public:
  RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
              const string& key, Device* to_device,
              DeviceContext* to_device_ctx,
              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
              const DeviceLocality& client_locality,
              const DeviceLocality& server_locality,
              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, peer_task, wc) {
    req_.set_step_id(step_id);
    req_.set_buf_rendezvous_key(key);
    *req_.mutable_client_locality() = client_locality;
    *req_.mutable_server_locality() = server_locality;
    req_.set_num_bytes(to_tensor->TotalBytes());
    req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
    req_.set_src_device(peer_device);
    req_.set_dst_device(to_device->name());
    req_.set_request_id(GetUniqueRequestId());
  }

  ~RecvBufCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
  }

  RecvBufRequest req_;
  RecvBufResponse resp_;
};

class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
 public:
  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
                                    DeviceResolverInterface* dev_resolver,
                                    WorkerCacheInterface* worker_cache,
                                    int64 step_id,
                                    RemoteMemoryManager* remote_memory_manager)
      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
        worker_cache_(worker_cache),
        remote_memory_manager_(remote_memory_manager) {}

  ~CollectiveRemoteAccessDistributed() override {}

  void RecvFromPeer(const string& peer_device, const string& peer_task,
                    bool peer_is_local, const string& key, Device* to_device,
                    DeviceContext* to_device_ctx,
                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
                    const DeviceLocality& client_locality,
                    int dev_to_dev_stream_index,
                    const StatusCallback& done) override {
    if (peer_is_local) {
      CollectiveRemoteAccessLocal::RecvFromPeer(
          peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
          to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
          done);
      return;
    }

    // State that needs to be threaded through a couple of async calls
    // in order to make this function completely non-blocking.
    struct State {
      DeviceLocality server_locality;
      std::unique_ptr<RecvBufCall> call;
    };
    State* state = new State;

    // Logic to be executed on the RecvBufAsync callback.
    auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
                              to_device_ctx, to_tensor, dev_to_dev_stream_index,
                              done](const Status& s) {
      if (s.ok()) {
        remote_memory_manager_->TensorFromTransportOptions(
            to_tensor, state->call->resp_.transport_options(), to_device,
            to_device_ctx, to_alloc_attr.on_host(), done);
      }
      if (!s.ok() && errors::IsFailedPrecondition(s)) {
        dev_resolver_->ClearTask(peer_task);
      }

      delete state;
    };

    // Logic to execute once we have the device locality for the server-side
    // device.
    auto dev_locality_callback = [this, state, peer_device, peer_task, key,
                                  to_device, to_device_ctx, to_alloc_attr,
                                  to_tensor, client_locality,
                                  recv_buf_callback](const Status& s) {
      if (!s.ok()) {
        recv_buf_callback(s);
      } else {
        state->call.reset(new RecvBufCall(
            step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
            to_alloc_attr, to_tensor, client_locality, state->server_locality,
            &cancel_mgr_, worker_cache_));
        state->call->Start(recv_buf_callback);
      }
    };

    dev_resolver_->GetLocalityAsync(
        peer_device, peer_task, &state->server_locality, dev_locality_callback);
  }

  void StartAbort(const Status& s) override {
    CollectiveRemoteAccessLocal::StartAbort(s);
    cancel_mgr_.StartCancel();
  }

 protected:
  WorkerCacheInterface* worker_cache_;  // Not owned
  CancellationManager cancel_mgr_;
  RemoteMemoryManager* remote_memory_manager_;
};

}  // namespace

CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
  CollectiveRemoteAccessDistributed* rma =
      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
                                            worker_cache_, step_id,
                                            remote_memory_manager_);
  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                    &gpu_ring_order_);
}

}  // namespace tensorflow
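Editor's note: RecvFromPeer above stays non-blocking by heap-allocating a small State struct, threading it through two chained callbacks, and deleting it in the last one. Below is a minimal standalone sketch of that pattern; all names are hypothetical stand-ins (plain std::function callbacks instead of the TensorFlow types), so it illustrates the control flow only, not the actual API.

#include <functional>
#include <iostream>
#include <string>

// Hypothetical stand-ins; none of these are TensorFlow types.
struct Status { bool ok = true; };
using Callback = std::function<void(const Status&)>;

// Pretend async steps: resolve peer locality, then issue the receive call.
void GetLocalityAsync(std::string* locality, Callback cb) {
  *locality = "peer-locality";
  cb(Status{});
}
void StartRecvCall(const std::string& locality, Callback cb) { cb(Status{}); }

void RecvFromPeerSketch(Callback done) {
  // State shared by the chained async callbacks; freed by whichever runs last.
  struct State {
    std::string server_locality;
  };
  State* state = new State;

  auto recv_cb = [state, done](const Status& s) {
    done(s);
    delete state;  // final owner of the shared state
  };
  auto locality_cb = [state, recv_cb](const Status& s) {
    if (!s.ok()) {
      recv_cb(s);  // propagate the failure; still exactly one delete
      return;
    }
    StartRecvCall(state->server_locality, recv_cb);
  };
  GetLocalityAsync(&state->server_locality, locality_cb);
}

int main() {
  RecvFromPeerSketch([](const Status& s) {
    std::cout << (s.ok() ? "received" : "failed") << "\n";
  });
}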
tensorflow/contrib/gdr/gdr_collective_executor_mgr.h (new file, 56 lines)
@@ -0,0 +1,56 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/framework/collective.h"

namespace tensorflow {
class CollectiveParamResolverDistributed;
class ConfigProto;
class DeviceMgr;
class DeviceResolverDistributed;
class WorkerCacheInterface;
class StepSequenceRequest;
class StepSequenceResponse;

// An implementation of CollectiveExecutorMgr for a distributed environment
// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA.
class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr {
 public:
  GdrCollectiveExecutorMgr(
      const ConfigProto& config, const DeviceMgr* dev_mgr,
      std::unique_ptr<DeviceResolverDistributed> dev_resolver,
      std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
      WorkerCacheInterface* worker_cache, const string& task_name,
      RemoteMemoryManager* remote_memory_manager)
      : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                                 std::move(param_resolver), worker_cache,
                                 task_name),
        remote_memory_manager_(remote_memory_manager) {}

  ~GdrCollectiveExecutorMgr() override {}

 protected:
  virtual CollectiveExecutor* Create(int64 step_id) override;

 private:
  RemoteMemoryManager* remote_memory_manager_;  // Not owned.
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
@@ -73,7 +73,10 @@ int TryToReadNumaNode(ibv_device* device) {

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));
  const auto& ret = std::getline(ifs, content);
  if (!ret) {
    return port::kNUMANoAffinity;
  }

  int32 value;
  if (strings::safe_strto32(content, &value)) {
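Editor's note: the hunk above replaces a hard CHECK with a graceful fallback — if the sysfs numa_node file cannot be read, the function now reports no NUMA affinity instead of aborting the process. A standalone sketch of the same read-with-fallback idea, using only the standard library (the helper name and -1 constant are hypothetical stand-ins for the TensorFlow helpers and port::kNUMANoAffinity):

#include <fstream>
#include <string>

// Hypothetical helper; -1 stands in for port::kNUMANoAffinity.
int ReadNumaNodeOrDefault(const std::string& filename) {
  constexpr int kNoAffinity = -1;
  std::ifstream ifs(filename);
  std::string content;
  if (!std::getline(ifs, content)) {
    return kNoAffinity;  // file missing or unreadable: degrade, don't crash
  }
  try {
    return std::stoi(content);  // e.g. "0" or "1"
  } catch (...) {
    return kNoAffinity;  // not a parsable integer
  }
}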
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/contrib/gdr/gdr_server_lib.h"

#include "grpc/support/alloc.h"
#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
#include "tensorflow/contrib/gdr/gdr_worker.h"

#include "grpc/support/alloc.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"

namespace tensorflow {

@@ -57,10 +59,31 @@ Status GdrServer::Init() {
    return std::unique_ptr<GdrWorker>(
        new GdrWorker(env, config, remote_memory_manager_.get()));
  };
  CollectiveMgrCreationFunction collective_mgr_func =
      [this](const ConfigProto& config, const WorkerEnv* env,
             WorkerCacheInterface* worker_cache) {
        string unused;
        string default_worker_name;
        DeviceNameUtils::SplitDeviceName(
            env->device_mgr->ListDevices()[0]->name(), &default_worker_name,
            &unused);

        std::unique_ptr<DeviceResolverDistributed> dev_resolver(
            new DeviceResolverDistributed(env->device_mgr, worker_cache,
                                          default_worker_name));
        std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
            new CollectiveParamResolverDistributed(
                config, env->device_mgr, dev_resolver.get(), worker_cache,
                default_worker_name));
        return new GdrCollectiveExecutorMgr(
            config, env->device_mgr, std::move(dev_resolver),
            std::move(param_resolver), worker_cache, default_worker_name,
            remote_memory_manager_.get());
      };
  TF_RETURN_IF_ERROR(remote_memory_manager_->Init());

  return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func);
  return GrpcServer::Init(nullptr, rendezvous_mgr_func, collective_mgr_func,
                          worker_func);
}

Status GdrServer::Start() {
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/contrib/gdr/gdr_worker.h"

#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"

@@ -29,10 +30,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

@@ -40,13 +44,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config,
                     RemoteMemoryManager* remote_memory_manager)
    : GrpcWorker(worker_env, config),
      remote_memory_manager_(remote_memory_manager),
      recv_tensor_recent_request_ids_(100000) {}
      recent_request_ids_(100000) {}

void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                    const RecvTensorRequest* request,
                                    ::grpc::ByteBuffer* response,
                                    StatusCallback done) {
  Status s = recv_tensor_recent_request_ids_.TrackUnique(
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "RecvTensor (GdrWorker)", *request);
  if (!s.ok()) {
    done(s);

@@ -145,4 +149,41 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
      });
}

void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                             RecvBufResponse* response, StatusCallback done) {
  // This is an RDMA enabled implementation augmenting grpc.
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RecvBuf (GdrWorker)", *request);
  if (!s.ok()) {
    done(s);
    return;
  }
  CollectiveExecutor::Handle ce_handle(
      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
  rma->buf_rendezvous()->ConsumeBuf(
      request->buf_rendezvous_key(),
      [this, request, response, done](const Status& status,
                                      BufRendezvous::Hook* hook) {
        Status s = status;
        if (s.ok()) {
          if (!DMAHelper::CanUseDMA(hook->prod_value)) {
            s = errors::Internal("Tensor value for key ",
                                 request->buf_rendezvous_key(),
                                 " is not of a type supported by RecvBuf");
          }
        }
        if (s.ok()) {
          remote_memory_manager_->TransportOptionsFromTensor(
              response->mutable_transport_options(), *hook->prod_value,
              hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(),
              [this, response, done, hook](const Status& s) {
                response->set_send_start_micros(env_->env->NowMicros());
                done(s);
                BufRendezvous::DoneWithHook(hook);
              });
        }
      });
}

}  // namespace tensorflow
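Editor's note: renaming recv_tensor_recent_request_ids_ to recent_request_ids_ reflects that one tracker now guards both GrpcRecvTensorAsync and the new RecvBufAsync, so a retried request id is rejected whichever RPC carries it. Below is a much-simplified sketch of such a duplicate-request guard; it is a hypothetical illustration using an unbounded set, whereas the real RecentRequestIds is constructed with a capacity (100000 here) and bounds its memory.

#include <cstdint>
#include <mutex>
#include <unordered_set>

// Hypothetical, simplified duplicate-request guard (not the TensorFlow class).
class SimpleRequestIdTracker {
 public:
  // Returns true the first time an id is seen, false on a duplicate/retry.
  bool TrackUnique(int64_t request_id) {
    std::lock_guard<std::mutex> lock(mu_);
    return seen_.insert(request_id).second;
  }

 private:
  std::mutex mu_;
  std::unordered_set<int64_t> seen_;
};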
@@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker {
                           ::grpc::ByteBuffer* response,
                           StatusCallback done) override;

  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                            RecvBufResponse* response,
                            StatusCallback done) override;

 private:
  RemoteMemoryManager* remote_memory_manager_;  // Not owned
  RecentRequestIds recv_tensor_recent_request_ids_;
  RecentRequestIds recent_request_ids_;
};

}  // namespace tensorflow
@@ -56,7 +56,7 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
  void RetireStepId(int64 graph_key, int64 step_id) override;

 protected:
  CollectiveExecutor* Create(int64 step_id) override;
  virtual CollectiveExecutor* Create(int64 step_id) override;

  WorkerCacheInterface* const worker_cache_;  // Not owned.
  const string task_name_;
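Editor's note: since Create was already marked override it was already virtual; the added virtual keyword is redundant in C++ but signals that subclasses such as GdrCollectiveExecutorMgr replace this factory. A minimal sketch of that factory-override relationship, with hypothetical class names (not the TensorFlow types):

#include <cstdint>

// Hypothetical stand-ins for CollectiveExecutor and the two manager classes.
class Executor {
 public:
  virtual ~Executor() = default;
};

class RpcExecutorMgrSketch {
 public:
  virtual ~RpcExecutorMgrSketch() = default;
  Executor* NewExecutor(int64_t step_id) { return Create(step_id); }

 protected:
  // Factory hook; subclasses return their own executor flavor.
  virtual Executor* Create(int64_t step_id) { return new Executor; }
};

class GdrExecutorMgrSketch : public RpcExecutorMgrSketch {
 protected:
  Executor* Create(int64_t step_id) override {
    // The real override wires in an RDMA-capable remote-access object.
    return new Executor;
  }
};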