diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD
index 704be917b36..89314d162b3 100644
--- a/tensorflow/contrib/gdr/BUILD
+++ b/tensorflow/contrib/gdr/BUILD
@@ -100,12 +100,33 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gdr_collective_executor_mgr",
+    srcs = ["gdr_collective_executor_mgr.cc"],
+    hdrs = ["gdr_collective_executor_mgr.h"],
+    deps = [
+        ":gdr_memory_manager",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/distributed_runtime:cancellable_call",
+        "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
+        "//tensorflow/core/distributed_runtime:device_resolver_distributed",
+        "//tensorflow/core/distributed_runtime:request_id",
+        "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
+        "//tensorflow/core/distributed_runtime:worker_cache",
+        "//tensorflow/core/distributed_runtime:worker_env",
+        "//tensorflow/core/distributed_runtime:worker_interface",
+    ],
+)
+
 cc_library(
     name = "gdr_server_lib",
     srcs = ["gdr_server_lib.cc"],
     hdrs = ["gdr_server_lib.h"],
     linkstatic = 1,  # Seems to be needed since alwayslink is broken in bazel
     deps = [
+        ":gdr_collective_executor_mgr",
         ":gdr_memory_manager",
         ":gdr_rendezvous_mgr",
         ":gdr_worker",
diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
new file mode 100644
index 00000000000..b84710d26eb
--- /dev/null
+++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
@@ -0,0 +1,175 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+class WorkerCacheInterface;
+
+namespace {
+
+// A cancellable RecvBuf RPC addressed to `peer_task`. The request carries the
+// rendezvous key, both device localities, and the byte count and base address
+// of the destination tensor.
+class RecvBufCall : public CancellableCall {
+ public:
+  RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
+              const string& key, Device* to_device,
+              DeviceContext* to_device_ctx,
+              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+              const DeviceLocality& client_locality,
+              const DeviceLocality& server_locality,
+              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
+      : CancellableCall(cancel_mgr, peer_task, wc) {
+    req_.set_step_id(step_id);
+    req_.set_buf_rendezvous_key(key);
+    *req_.mutable_client_locality() = client_locality;
+    *req_.mutable_server_locality() = server_locality;
+    req_.set_num_bytes(to_tensor->TotalBytes());
+    req_.set_buf_ptr(reinterpret_cast<uint64>(DMAHelper::base(to_tensor)));
+    req_.set_src_device(peer_device);
+    req_.set_dst_device(to_device->name());
+    req_.set_request_id(GetUniqueRequestId());
+  }
+
+  ~RecvBufCall() override {}
+
+  void IssueCall(const StatusCallback& done) override {
+    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
+  }
+
+  RecvBufRequest req_;
+  RecvBufResponse resp_;
+};
+
+// Remote access that delegates local peers to CollectiveRemoteAccessLocal and
+// reaches remote peers with a RecvBuf call followed by
+// RemoteMemoryManager::TensorFromTransportOptions.
+class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
+ public:
+  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
+                                    DeviceResolverInterface* dev_resolver,
+                                    WorkerCacheInterface* worker_cache,
+                                    int64 step_id,
+                                    RemoteMemoryManager* remote_memory_manager)
+      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+        worker_cache_(worker_cache),
+        remote_memory_manager_(remote_memory_manager) {}
+
+  ~CollectiveRemoteAccessDistributed() override {}
+
+  // Fetches the buffer published under `key` on `peer_device` into
+  // `to_tensor`, then reports the outcome through `done`.
+  void RecvFromPeer(const string& peer_device, const string& peer_task,
+                    bool peer_is_local, const string& key, Device* to_device,
+                    DeviceContext* to_device_ctx,
+                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+                    const DeviceLocality& client_locality,
+                    int dev_to_dev_stream_index,
+                    const StatusCallback& done) override {
+    if (peer_is_local) {
+      CollectiveRemoteAccessLocal::RecvFromPeer(
+          peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+          to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
+          done);
+      return;
+    }
+
+    // State that needs to be threaded through a couple of async calls
+    // in order to make this function completely non-blocking.
+    struct State {
+      DeviceLocality server_locality;
+      std::unique_ptr<RecvBufCall> call;
+    };
+    State* state = new State;
+
+    // Logic to be executed on the RecvBufAsync callback.
+    auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
+                              to_device_ctx, to_tensor, dev_to_dev_stream_index,
+                              done](const Status& s) {
+      if (s.ok()) {
+        // Success: `done` is forwarded to TensorFromTransportOptions.
+        remote_memory_manager_->TensorFromTransportOptions(
+            to_tensor, state->call->resp_.transport_options(), to_device,
+            to_device_ctx, to_alloc_attr.on_host(), done);
+        delete state;
+        return;
+      }
+      if (errors::IsFailedPrecondition(s)) {
+        dev_resolver_->ClearTask(peer_task);
+      }
+      delete state;
+      // Failure: no callee was given `done`, so report the error here;
+      // otherwise the caller would never hear back.
+      done(s);
+    };
+
+    // Logic to execute once we have the device locality for the server-side
+    // device.
+    auto dev_locality_callback = [this, state, peer_device, peer_task, key,
+                                  to_device, to_device_ctx, to_alloc_attr,
+                                  to_tensor, client_locality,
+                                  recv_buf_callback](const Status& s) {
+      if (!s.ok()) {
+        recv_buf_callback(s);
+      } else {
+        state->call.reset(new RecvBufCall(
+            step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
+            to_alloc_attr, to_tensor, client_locality, state->server_locality,
+            &cancel_mgr_, worker_cache_));
+        state->call->Start(recv_buf_callback);
+      }
+    };
+
+    dev_resolver_->GetLocalityAsync(
+        peer_device, peer_task, &state->server_locality, dev_locality_callback);
+  }
+
+  // Aborts the base class, then starts cancellation on cancel_mgr_.
+  void StartAbort(const Status& s) override {
+    CollectiveRemoteAccessLocal::StartAbort(s);
+    cancel_mgr_.StartCancel();
+  }
+
+ protected:
+  WorkerCacheInterface* worker_cache_;  // Not owned
+  CancellationManager cancel_mgr_;
+  RemoteMemoryManager* remote_memory_manager_;  // Not owned
+};
+
+}  // namespace
+
+// Builds the executor for `step_id` on top of the GDR-aware remote access.
+CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
+  CollectiveRemoteAccessDistributed* rma =
+      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+                                            worker_cache_, step_id,
+                                            remote_memory_manager_);
+  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
+                                    &gpu_ring_order_);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h
new file mode 100644
index 00000000000..a713c1c6b0b
---
/dev/null
+++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+
+namespace tensorflow {
+class CollectiveParamResolverDistributed;
+class ConfigProto;
+class DeviceMgr;
+class DeviceResolverDistributed;
+class WorkerCacheInterface;
+class StepSequenceRequest;
+class StepSequenceResponse;
+
+// An implementation of CollectiveExecutorMgr for a distributed environment
+// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA.
+class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr {
+ public:
+  // Takes ownership of `dev_resolver` and `param_resolver` (forwarded to the
+  // base class); `remote_memory_manager` is kept as a non-owning pointer.
+  GdrCollectiveExecutorMgr(
+      const ConfigProto& config, const DeviceMgr* dev_mgr,
+      std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+      std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+      WorkerCacheInterface* worker_cache, const string& task_name,
+      RemoteMemoryManager* remote_memory_manager)
+      : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
+                                 std::move(param_resolver), worker_cache,
+                                 task_name),
+        remote_memory_manager_(remote_memory_manager) {}
+
+  ~GdrCollectiveExecutorMgr() override {}
+
+ protected:
+  // Creates the CollectiveExecutor used for `step_id`.
+  CollectiveExecutor* Create(int64 step_id) override;
+
+ private:
+  RemoteMemoryManager* remote_memory_manager_;  // Not owned.
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index ce187515159..7321e973191 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -73,7 +73,10 @@ int TryToReadNumaNode(ibv_device* device) {
 
   std::ifstream ifs(filename.c_str());
   string content;
-  CHECK(std::getline(ifs, content));
+  const auto& ret = std::getline(ifs, content);
+  if (!ret) {
+    return port::kNUMANoAffinity;
+  }
 
   int32 value;
   if (strings::safe_strto32(content, &value)) {
diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc
index dc0d5d548b8..e4ae1eec231 100644
--- a/tensorflow/contrib/gdr/gdr_server_lib.cc
+++ b/tensorflow/contrib/gdr/gdr_server_lib.cc
@@ -16,11 +16,13 @@ limitations under the License.
 #include "tensorflow/contrib/gdr/gdr_server_lib.h"
 
 #include "grpc/support/alloc.h"
+#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h"
 #include "tensorflow/contrib/gdr/gdr_memory_manager.h"
 #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h"
 #include "tensorflow/contrib/gdr/gdr_worker.h"
-
-#include "grpc/support/alloc.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 
 namespace tensorflow {
 
@@ -57,10 +59,33 @@ Status GdrServer::Init() {
     return std::unique_ptr<GdrWorker>(
         new GdrWorker(env, config, remote_memory_manager_.get()));
   };
+  // Builds the collective executor manager for this server. The worker name
+  // is taken from the first device listed by the device manager.
+  CollectiveMgrCreationFunction collective_mgr_func =
+      [this](const ConfigProto& config, const WorkerEnv* env,
+             WorkerCacheInterface* worker_cache) {
+        string unused;
+        string default_worker_name;
+        DeviceNameUtils::SplitDeviceName(
+            env->device_mgr->ListDevices()[0]->name(), &default_worker_name,
+            &unused);
+        std::unique_ptr<DeviceResolverDistributed> dev_resolver(
+            new DeviceResolverDistributed(env->device_mgr, worker_cache,
+                                          default_worker_name));
+        std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
+            new CollectiveParamResolverDistributed(
+                config, env->device_mgr, dev_resolver.get(), worker_cache,
+                default_worker_name));
+        return new GdrCollectiveExecutorMgr(
+            config, env->device_mgr, std::move(dev_resolver),
+            std::move(param_resolver), worker_cache, default_worker_name,
+            remote_memory_manager_.get());
+      };
 
   TF_RETURN_IF_ERROR(remote_memory_manager_->Init());
 
-  return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func);
+  return GrpcServer::Init(nullptr, rendezvous_mgr_func, collective_mgr_func,
+                          worker_func);
 }
 
 Status GdrServer::Start() {
diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc
index 016e5ea27b3..86897e3c8ef 100644
--- a/tensorflow/contrib/gdr/gdr_worker.cc
+++ b/tensorflow/contrib/gdr/gdr_worker.cc
@@ -15,6 +15,7
@@ limitations under the License.
 
 #include "tensorflow/contrib/gdr/gdr_worker.h"
 
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
@@ -29,10 +30,13 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_session.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
 
 namespace tensorflow {
 
@@ -40,13 +44,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config,
                      RemoteMemoryManager* remote_memory_manager)
     : GrpcWorker(worker_env, config),
       remote_memory_manager_(remote_memory_manager),
-      recv_tensor_recent_request_ids_(100000) {}
+      recent_request_ids_(100000) {}
 
 void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     ::grpc::ByteBuffer* response,
                                     StatusCallback done) {
-  Status s = recv_tensor_recent_request_ids_.TrackUnique(
+  Status s = recent_request_ids_.TrackUnique(
       request->request_id(), "RecvTensor (GdrWorker)", *request);
   if (!s.ok()) {
     done(s);
@@ -145,4 +149,50 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
       });
 }
 
+// Serves a RecvBuf request: consumes the buffer registered under the request's
+// rendezvous key and fills `response` with transport options for it.
+void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                             RecvBufResponse* response, StatusCallback done) {
+  // This is an RDMA enabled implementation augmenting grpc.
+  Status s = recent_request_ids_.TrackUnique(request->request_id(),
+                                             "RecvBuf (GdrWorker)", *request);
+  if (!s.ok()) {
+    done(s);
+    return;
+  }
+  CollectiveExecutor::Handle ce_handle(
+      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
+  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
+  rma->buf_rendezvous()->ConsumeBuf(
+      request->buf_rendezvous_key(),
+      [this, request, response, done](const Status& status,
+                                      BufRendezvous::Hook* hook) {
+        Status s = status;
+        if (s.ok() && !DMAHelper::CanUseDMA(hook->prod_value)) {
+          s = errors::Internal("Tensor value for key ",
+                               request->buf_rendezvous_key(),
+                               " is not of a type supported by RecvBuf");
+        }
+        if (!s.ok()) {
+          // No transfer will be started, so the inner callback below never
+          // runs: report the failure and release the hook here instead.
+          // NOTE(review): hook is assumed to be null when ConsumeBuf itself
+          // fails -- confirm against BufRendezvous.
+          done(s);
+          if (hook != nullptr) {
+            BufRendezvous::DoneWithHook(hook);
+          }
+          return;
+        }
+        remote_memory_manager_->TransportOptionsFromTensor(
+            response->mutable_transport_options(), *hook->prod_value,
+            hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(),
+            [this, response, done, hook](const Status& s) {
+              response->set_send_start_micros(env_->env->NowMicros());
+              done(s);
+              BufRendezvous::DoneWithHook(hook);
+            });
+      });
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h
index 39f11e6bde5..9a85cfd4263 100644
--- a/tensorflow/contrib/gdr/gdr_worker.h
+++ b/tensorflow/contrib/gdr/gdr_worker.h
@@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker {
                                    ::grpc::ByteBuffer* response,
                                    StatusCallback done) override;
 
+  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                            RecvBufResponse* response,
+                            StatusCallback done) override;
+
  private:
   RemoteMemoryManager* remote_memory_manager_;  // Not owned
-  RecentRequestIds recv_tensor_recent_request_ids_;
+  RecentRequestIds recent_request_ids_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
index c9581fa00f3..98eb1467700 100644
---
a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h @@ -56,7 +56,7 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { void RetireStepId(int64 graph_key, int64 step_id) override; protected: - CollectiveExecutor* Create(int64 step_id) override; + virtual CollectiveExecutor* Create(int64 step_id) override; WorkerCacheInterface* const worker_cache_; // Not owned. const string task_name_;