Merge pull request #24864 from byronyi:fix-18232

PiperOrigin-RevId: 230809690
This commit is contained in:
TensorFlower Gardener 2019-01-24 16:22:41 -08:00
commit 6e6706d45d
8 changed files with 316 additions and 8 deletions

View File

@ -100,12 +100,33 @@ cc_library(
],
)
# Collective executor manager that routes cross-worker tensor transfers
# over RDMA (via :gdr_memory_manager) instead of plain gRPC. Extends the
# stock RPC implementation in //tensorflow/core/distributed_runtime.
cc_library(
name = "gdr_collective_executor_mgr",
srcs = ["gdr_collective_executor_mgr.cc"],
hdrs = ["gdr_collective_executor_mgr.h"],
deps = [
":gdr_memory_manager",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:cancellable_call",
"//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
"//tensorflow/core/distributed_runtime:device_resolver_distributed",
"//tensorflow/core/distributed_runtime:request_id",
"//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
],
)
cc_library(
name = "gdr_server_lib",
srcs = ["gdr_server_lib.cc"],
hdrs = ["gdr_server_lib.h"],
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
deps = [
":gdr_collective_executor_mgr",
":gdr_memory_manager",
":gdr_rendezvous_mgr",
":gdr_worker",

View File

@ -0,0 +1,160 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
class WorkerCacheInterface;
namespace {
// Cancellable wrapper around a single WorkerInterface::RecvBufAsync RPC.
// The request proto is fully populated in the constructor; IssueCall then
// fires the RPC against the peer worker resolved by CancellableCall.
class RecvBufCall : public CancellableCall {
public:
// Builds the RecvBufRequest for one tensor transfer.
// `to_tensor` must already be allocated: its size and base address are
// advertised to the sender so the bytes can be written directly (RDMA).
RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
const string& key, Device* to_device,
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
const DeviceLocality& server_locality,
CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
: CancellableCall(cancel_mgr, peer_task, wc) {
req_.set_step_id(step_id);
req_.set_buf_rendezvous_key(key);
*req_.mutable_client_locality() = client_locality;
*req_.mutable_server_locality() = server_locality;
req_.set_num_bytes(to_tensor->TotalBytes());
// Raw destination address, smuggled through the proto as an int64 so the
// remote side can target it directly.
req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
req_.set_src_device(peer_device);
req_.set_dst_device(to_device->name());
// Unique id lets the server-side RecentRequestIds dedupe gRPC retries.
req_.set_request_id(GetUniqueRequestId());
}
~RecvBufCall() override {}
// Issues the RPC; `done` is invoked exactly once with the final status.
void IssueCall(const StatusCallback& done) override {
wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
}
RecvBufRequest req_;
RecvBufResponse resp_;
};
// Extends CollectiveRemoteAccessLocal so that RecvFromPeer can also fetch
// tensors from remote workers, using RecvBufAsync for the handshake and the
// RemoteMemoryManager to move the payload over RDMA.
class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
 public:
  // `worker_cache` and `remote_memory_manager` are not owned and must
  // outlive this object.
  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
                                    DeviceResolverInterface* dev_resolver,
                                    WorkerCacheInterface* worker_cache,
                                    int64 step_id,
                                    RemoteMemoryManager* remote_memory_manager)
      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
        worker_cache_(worker_cache),
        remote_memory_manager_(remote_memory_manager) {}

  ~CollectiveRemoteAccessDistributed() override {}

  // Receives `key` from `peer_device` into `to_tensor`. Local peers are
  // delegated to the base class; remote peers go through a RecvBufCall whose
  // transport options are handed to the RemoteMemoryManager. `done` is
  // invoked exactly once on every path.
  void RecvFromPeer(const string& peer_device, const string& peer_task,
                    bool peer_is_local, const string& key, Device* to_device,
                    DeviceContext* to_device_ctx,
                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
                    const DeviceLocality& client_locality,
                    int dev_to_dev_stream_index,
                    const StatusCallback& done) override {
    if (peer_is_local) {
      CollectiveRemoteAccessLocal::RecvFromPeer(
          peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
          to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
          done);
      return;
    }

    // State that needs to be threaded through a couple of async calls
    // in order to make this function completely non-blocking.
    struct State {
      DeviceLocality server_locality;
      std::unique_ptr<RecvBufCall> call;
    };
    State* state = new State;

    // Logic to be executed on the RecvBufAsync callback.
    auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
                              to_device_ctx, to_tensor, dev_to_dev_stream_index,
                              done](const Status& s) {
      if (s.ok()) {
        // Hand the transport options to the RDMA layer, which completes the
        // transfer asynchronously and then invokes `done`.
        remote_memory_manager_->TensorFromTransportOptions(
            to_tensor, state->call->resp_.transport_options(), to_device,
            to_device_ctx, to_alloc_attr.on_host(), done);
      } else {
        // FailedPrecondition means the peer restarted: drop its cached
        // device attributes so the next attempt re-resolves them.
        if (errors::IsFailedPrecondition(s)) {
          dev_resolver_->ClearTask(peer_task);
        }
        // BUG FIX: the original only invoked `done` on success, so a failed
        // RPC left the collective op waiting forever. Report the error.
        done(s);
      }
      delete state;
    };

    // Logic to execute once we have the device locality for the server-side
    // device.
    auto dev_locality_callback = [this, state, peer_device, peer_task, key,
                                  to_device, to_device_ctx, to_alloc_attr,
                                  to_tensor, client_locality,
                                  recv_buf_callback](const Status& s) {
      if (!s.ok()) {
        recv_buf_callback(s);
      } else {
        state->call.reset(new RecvBufCall(
            step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
            to_alloc_attr, to_tensor, client_locality, state->server_locality,
            &cancel_mgr_, worker_cache_));
        state->call->Start(recv_buf_callback);
      }
    };

    dev_resolver_->GetLocalityAsync(
        peer_device, peer_task, &state->server_locality, dev_locality_callback);
  }

  // Aborts local work and cancels any outstanding RecvBufCalls.
  void StartAbort(const Status& s) override {
    CollectiveRemoteAccessLocal::StartAbort(s);
    cancel_mgr_.StartCancel();
  }

 protected:
  WorkerCacheInterface* worker_cache_;  // Not owned
  CancellationManager cancel_mgr_;
  RemoteMemoryManager* remote_memory_manager_;  // Not owned
};
} // namespace
// Builds the per-step CollectiveExecutor, wiring its remote-access path
// through the RDMA-capable CollectiveRemoteAccessDistributed.
CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
  auto* remote_access = new CollectiveRemoteAccessDistributed(
      dev_mgr_, dev_resolver_.get(), worker_cache_, step_id,
      remote_memory_manager_);
  return new BaseCollectiveExecutor(this, remote_access, step_id, dev_mgr_,
                                    &gpu_ring_order_);
}
} // namespace tensorflow

View File

@ -0,0 +1,56 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/framework/collective.h"
namespace tensorflow {
class CollectiveParamResolverDistributed;
class ConfigProto;
class DeviceMgr;
class DeviceResolverDistributed;
class WorkerCacheInterface;
class StepSequenceRequest;
class StepSequenceResponse;
// An implementation of CollectiveExecutorMgr for a distributed environment
// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA.
class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr {
 public:
  // `remote_memory_manager` is not owned and must outlive this object; it is
  // threaded into every executor created by Create() to serve RDMA transfers.
  GdrCollectiveExecutorMgr(
      const ConfigProto& config, const DeviceMgr* dev_mgr,
      std::unique_ptr<DeviceResolverDistributed> dev_resolver,
      std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
      WorkerCacheInterface* worker_cache, const string& task_name,
      RemoteMemoryManager* remote_memory_manager)
      : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                                 std::move(param_resolver), worker_cache,
                                 task_name),
        remote_memory_manager_(remote_memory_manager) {}

  ~GdrCollectiveExecutorMgr() override {}

 protected:
  // Returns an executor whose RecvFromPeer path uses RecvBufAsync + RDMA.
  // NOTE: `override` already implies virtual, so the redundant `virtual`
  // keyword present before has been dropped (style-guide: use exactly one).
  CollectiveExecutor* Create(int64 step_id) override;

 private:
  RemoteMemoryManager* remote_memory_manager_;  // Not owned.
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_

View File

@ -73,7 +73,10 @@ int TryToReadNumaNode(ibv_device* device) {
std::ifstream ifs(filename.c_str());
string content;
CHECK(std::getline(ifs, content));
const auto& ret = std::getline(ifs, content);
if (!ret) {
return port::kNUMANoAffinity;
}
int32 value;
if (strings::safe_strto32(content, &value)) {

View File

@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/contrib/gdr/gdr_server_lib.h"
#include "grpc/support/alloc.h"
#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
#include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
#include "tensorflow/contrib/gdr/gdr_worker.h"
#include "grpc/support/alloc.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
namespace tensorflow {
@ -57,10 +59,31 @@ Status GdrServer::Init() {
return std::unique_ptr<GdrWorker>(
new GdrWorker(env, config, remote_memory_manager_.get()));
};
CollectiveMgrCreationFunction collective_mgr_func =
[this](const ConfigProto& config, const WorkerEnv* env,
WorkerCacheInterface* worker_cache) {
string unused;
string default_worker_name;
DeviceNameUtils::SplitDeviceName(
env->device_mgr->ListDevices()[0]->name(), &default_worker_name,
&unused);
std::unique_ptr<DeviceResolverDistributed> dev_resolver(
new DeviceResolverDistributed(env->device_mgr, worker_cache,
default_worker_name));
std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
new CollectiveParamResolverDistributed(
config, env->device_mgr, dev_resolver.get(), worker_cache,
default_worker_name));
return new GdrCollectiveExecutorMgr(
config, env->device_mgr, std::move(dev_resolver),
std::move(param_resolver), worker_cache, default_worker_name,
remote_memory_manager_.get());
};
TF_RETURN_IF_ERROR(remote_memory_manager_->Init());
return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func);
return GrpcServer::Init(nullptr, rendezvous_mgr_func, collective_mgr_func,
worker_func);
}
Status GdrServer::Start() {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/gdr/gdr_worker.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@ -29,10 +30,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
@ -40,13 +44,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config,
RemoteMemoryManager* remote_memory_manager)
: GrpcWorker(worker_env, config),
remote_memory_manager_(remote_memory_manager),
recv_tensor_recent_request_ids_(100000) {}
recent_request_ids_(100000) {}
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
Status s = recv_tensor_recent_request_ids_.TrackUnique(
Status s = recent_request_ids_.TrackUnique(
request->request_id(), "RecvTensor (GdrWorker)", *request);
if (!s.ok()) {
done(s);
@ -145,4 +149,41 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
});
}
// Serves one RecvBuf RPC: looks up the produced tensor in the step's
// BufRendezvous and, instead of copying bytes into the response, attaches
// RDMA transport options so the client can pull the buffer directly.
// `done` is invoked exactly once on every path.
void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                             RecvBufResponse* response, StatusCallback done) {
  // This is an RDMA enabled implementation augmenting grpc.
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RecvBuf (GdrWorker)", *request);
  if (!s.ok()) {
    done(s);
    return;
  }
  CollectiveExecutor::Handle ce_handle(
      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
  rma->buf_rendezvous()->ConsumeBuf(
      request->buf_rendezvous_key(),
      [this, request, response, done](const Status& status,
                                      BufRendezvous::Hook* hook) {
        Status s = status;
        if (s.ok()) {
          if (!DMAHelper::CanUseDMA(hook->prod_value)) {
            s = errors::Internal("Tensor value for key ",
                                 request->buf_rendezvous_key(),
                                 " is not of a type supported by RecvBuf");
          }
        }
        if (s.ok()) {
          remote_memory_manager_->TransportOptionsFromTensor(
              response->mutable_transport_options(), *hook->prod_value,
              hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(),
              [this, response, done, hook](const Status& s) {
                response->set_send_start_micros(env_->env->NowMicros());
                done(s);
                // Unblocks the producer waiting on this hook.
                BufRendezvous::DoneWithHook(hook);
              });
        } else {
          // BUG FIX: the original returned without calling `done` on failure
          // (hanging the client) and never released a non-null hook (leaking
          // it and stalling the producer). Report the error and clean up.
          done(s);
          if (hook != nullptr) {
            BufRendezvous::DoneWithHook(hook);
          }
        }
      });
}
} // namespace tensorflow

View File

@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker {
::grpc::ByteBuffer* response,
StatusCallback done) override;
virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response,
StatusCallback done) override;
private:
RemoteMemoryManager* remote_memory_manager_; // Not owned
RecentRequestIds recv_tensor_recent_request_ids_;
RecentRequestIds recent_request_ids_;
};
} // namespace tensorflow

View File

@ -56,7 +56,7 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
void RetireStepId(int64 graph_key, int64 step_id) override;
protected:
CollectiveExecutor* Create(int64 step_id) override;
virtual CollectiveExecutor* Create(int64 step_id) override;
WorkerCacheInterface* const worker_cache_; // Not owned.
const string task_name_;