diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 0e1fade3e84..03b93cf9a98 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -184,27 +184,27 @@ class SimpleRendezvous : public Rendezvous { public: explicit SimpleRendezvous() {} - Status Send(const string& key, const Args& send_args, const Tensor& val, + Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val, const bool is_dead) override { if (is_dead) { return errors::Internal("Send of a dead tensor"); } - ParsedKey parsed; - TF_RETURN_IF_ERROR(ParseKey(key, &parsed)); mutex_lock l(mu_); - if (table_.count(parsed.edge_name) > 0) { + string edge_name = parsed.edge_name.ToString(); + if (table_.count(edge_name) > 0) { return errors::Internal("Send of an already sent tensor"); } - table_[parsed.edge_name] = val; + table_[edge_name] = val; return Status::OK(); } - void RecvAsync(const string& key, const Args& recv_args, + void RecvAsync(const ParsedKey& parsed, const Args& recv_args, DoneCallback done) override { Tensor tensor; Status status = Status::OK(); { + string key = parsed.edge_name.ToString(); mutex_lock l(mu_); if (table_.count(key) <= 0) { status = errors::Internal("Did not find key ", key); @@ -417,7 +417,14 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, // this node. Don't bother processing the rest of the nodes. return c > 0; } - Status s = rendez->Recv(tensor_name, Rendezvous::Args(), &output, &is_dead); + + string full_key = Rendezvous::CreateKey("/cpu:0", 1, "/cpu:1", tensor_name, + FrameAndIter(0, 0)); + Rendezvous::ParsedKey parsed; + Status s = Rendezvous::ParseKey(full_key, &parsed); + if (s.ok()) { + s = rendez->Recv(parsed, Rendezvous::Args(), &output, &is_dead); + } if (!s.ok() || is_dead) { return c > 0; } diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index c44f040e155..5dc8c33b2a7 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -43,8 +43,7 @@ std::vector<RegistrationInfo>* MutableRegistry() { } // namespace // static -void CopyTensor::ViaDMA(const string& edge_name, - DeviceContext* send_dev_context, +void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, DeviceContext* recv_dev_context, Device* src, Device* dst, const AllocatorAttributes src_alloc_attr, const AllocatorAttributes dst_alloc_attr, diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h index fd8a26be083..a9d684bf110 100644 --- a/tensorflow/core/common_runtime/copy_tensor.h +++ b/tensorflow/core/common_runtime/copy_tensor.h @@ -42,7 +42,7 @@ class CopyTensor { // the type of devices and memory in use, the copy may be performed // synchronously or asynchronously. 'done' will be invoked only // after the copy is actually complete. 
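For context on the constant-folding change above: the folder used to Recv by bare tensor name, but Rendezvous::ParseKey only accepts a full five-field key, so it now fabricates one with Rendezvous::CreateKey before receiving. A minimal standalone sketch of that key shape, inferred from the Split(key, ';') logging code later in this patch (the hex encoding of the incarnation is an assumption, and CreateKeyLike is a hypothetical stand-in, not the TensorFlow API):

```cpp
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

// Builds "src;incarnation;dst;name;frame:iter", the five ';'-separated
// fields that CreateKey produces and ParseKey consumes.
std::string CreateKeyLike(const std::string& src, unsigned long long incarnation,
                          const std::string& dst, const std::string& name) {
  std::ostringstream os;
  os << src << ';' << std::hex << incarnation << std::dec << ';' << dst << ';'
     << name << ";0:0";  // FrameAndIter(0, 0)
  return os.str();
}

int main() {
  const std::string key = CreateKeyLike("/cpu:0", 1, "/cpu:1", "const_out");
  // ParseKey-style split: anything other than exactly five parts is invalid.
  std::vector<std::string> parts;
  std::istringstream is(key);
  for (std::string p; std::getline(is, p, ';');) parts.push_back(p);
  if (parts.size() != 5) return 1;
  std::printf("src=%s name=%s dst=%s\n", parts[0].c_str(), parts[3].c_str(),
              parts[2].c_str());
  return 0;
}
```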
- static void ViaDMA(const string& edge_name, DeviceContext* send_dev_context, + static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, DeviceContext* recv_dev_context, Device* src, Device* dst, const AllocatorAttributes src_alloc_attr, const AllocatorAttributes dst_alloc_attr, diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 7ffe3a7d8fe..820c4370e21 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -24,13 +24,17 @@ limitations under the License. namespace tensorflow { -DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) { +DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) + : name_backing_store_(128) { for (Device* d : devices) { devices_.push_back(d); // Register under both the full name and the local name. - device_map_[d->name()] = d; - device_map_[DeviceNameUtils::LocalName(d->name())] = d; + string full_name = d->name(); + device_map_[CopyToBackingStore(full_name)] = d; + + string lname = DeviceNameUtils::LocalName(d->name()); + device_map_[CopyToBackingStore(lname)] = d; device_type_counts_[d->device_type()]++; } } @@ -39,6 +43,13 @@ DeviceMgr::~DeviceMgr() { for (auto p : devices_) delete p; } +StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { + int n = s.size(); + char* space = name_backing_store_.Alloc(n); + memcpy(space, s.data(), n); + return StringPiece(space, n); +} + void DeviceMgr::ListDeviceAttributes( std::vector<DeviceAttributes>* devices) const { devices->reserve(devices_.size()); @@ -70,7 +81,7 @@ string DeviceMgr::DeviceMappingString() const { return out; } -Status DeviceMgr::LookupDevice(const string& name, Device** device) const { +Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const { Status s; auto iter = device_map_.find(name); if (iter == device_map_.end()) { diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index 8d3ab9c7b78..d41931eed8f 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -22,7 +22,9 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/core/arena.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/macros.h" @@ -49,7 +51,7 @@ class DeviceMgr { // Assigns *device with pointer to Device of the given name. // Accepts either a full device name, or just the replica-local suffix. - Status LookupDevice(const string& name, Device** device) const; + Status LookupDevice(StringPiece name, Device** device) const; // Clears given containers of all devices if 'container' is // non-empty. Otherwise, clears default containers of all devices. 
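The DeviceMgr change here is the supporting half of the StringPiece migration: an unordered_map keyed by StringPiece does not own its key bytes, so each key is first copied once into an arena (CopyToBackingStore) whose lifetime matches the map, and lookups can then be done with zero allocation. A minimal sketch of the same interning pattern, assuming C++17, with std::string_view and a std::deque standing in for core::Arena (DeviceIndex is a hypothetical name):

```cpp
#include <deque>
#include <iostream>
#include <string>
#include <string_view>
#include <unordered_map>

class DeviceIndex {
 public:
  void Insert(std::string_view name, int device_id) {
    index_[CopyToBackingStore(name)] = device_id;
  }
  // Lookup allocates nothing: callers may pass a substring of a larger
  // buffer without materializing a std::string.
  bool Lookup(std::string_view name, int* device_id) const {
    auto it = index_.find(name);
    if (it == index_.end()) return false;
    *device_id = it->second;
    return true;
  }

 private:
  std::string_view CopyToBackingStore(std::string_view s) {
    backing_.emplace_back(s);  // one copy, at a stable address
    return std::string_view(backing_.back());
  }
  std::deque<std::string> backing_;  // deque never relocates elements
  std::unordered_map<std::string_view, int> index_;
};

int main() {
  DeviceIndex idx;
  idx.Insert("/job:w/replica:0/task:0/cpu:0", 0);
  int id = -1;
  std::cout << idx.Lookup("/job:w/replica:0/task:0/cpu:0", &id) << " " << id << "\n";
  return 0;
}
```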
@@ -60,7 +62,11 @@ class DeviceMgr { private: typedef gtl::InlinedVector<Device*, 8> DeviceVec; DeviceVec devices_; - std::unordered_map<string, Device*> device_map_; + + StringPiece CopyToBackingStore(StringPiece s); + + std::unordered_map<StringPiece, Device*, StringPiece::Hasher> device_map_; + core::Arena name_backing_store_; // Storage for keys in device_map_ std::unordered_map<string, int> device_type_counts_; TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr); diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index a44e5ab960a..90d1b0f895f 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -551,6 +551,7 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs, const ExecutorsAndKeys* executors_and_keys, IntraProcessRendezvous* rendez) { Status s; + Rendezvous::ParsedKey parsed; // Insert the input tensors into the local rendezvous by their // rendezvous key. for (const auto& input : inputs) { @@ -560,7 +561,14 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs, "' is not a pre-defined feed!"); } const string& input_key = it->second; - s = rendez->Send(input_key, Rendezvous::Args(), input.second, false); + + s = Rendezvous::ParseKey(input_key, &parsed); + if (!s.ok()) { + rendez->StartAbort(s); + return s; + } + + s = rendez->Send(parsed, Rendezvous::Args(), input.second, false); if (!s.ok()) { rendez->StartAbort(s); return s; @@ -578,6 +586,7 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names, outputs->resize(output_names.size()); } + Rendezvous::ParsedKey parsed; // Get the outputs from the rendezvous for (size_t output_offset = 0; output_offset < output_names.size(); ++output_offset) { @@ -591,14 +600,16 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names, const string& output_key = it->second; Tensor output_tensor; bool is_dead; - - // Fetch data from the Rendezvous. IntraProcessRendezvous* rendez = run_state->rendez; - s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead); - if (is_dead && s.ok()) { - s = errors::InvalidArgument("The tensor returned for ", - output_names[output_offset], - " was not valid."); + + s = Rendezvous::ParseKey(output_key, &parsed); + if (s.ok()) { + // Fetch data from the Rendezvous. + s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead); + if (is_dead && s.ok()) { + s = errors::InvalidArgument("The tensor returned for ", output_name, + " was not valid."); + } } if (!s.ok()) { rendez->StartAbort(s); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 7f0bad128ec..a05735c086d 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -645,6 +645,43 @@ class ExecutorState { } }; + // A drop-in replacement for std::deque<TaggedNode>. We typically don't + // have that many nodes in the ready queue, so we just use a vector and + // don't free up memory from the queue as we consume nodes. 
+ class TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : front_index_(0) {} + + void push_back(TaggedNode node) { ready_.push_back(node); } + TaggedNode front() const { + DCHECK_LT(front_index_, ready_.size()); + return ready_[front_index_]; + } + void pop_front() { + DCHECK_LT(front_index_, ready_.size()); + front_index_++; + if ((front_index_ == ready_.size()) || (front_index_ > 16384)) { + if (front_index_ == ready_.size()) { + ready_.clear(); + } else { + // Lots of unused entries at beginning of vector: move everything down + // to start of vector. + ready_.erase(ready_.begin(), ready_.begin() + front_index_); + } + front_index_ = 0; + } + } + bool empty() const { return ready_.empty(); } + const TaggedNode* begin() const { return ready_.begin() + front_index_; } + const TaggedNode* end() const { return ready_.end(); } + + private: + gtl::InlinedVector<TaggedNode, 16> ready_; + int front_index_; + }; + + struct AsyncState; + typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; typedef gtl::InlinedVector<Entry, 4> EntryVector; @@ -767,15 +804,15 @@ class ExecutorState { // "node" just finishes. Takes ownership of "stats". Returns true if // execution has completed. bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStats* stats, std::deque<TaggedNode>* inline_ready); + NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready); // Call Process() on all nodes in 'inline_ready'. - void ProcessInline(const std::deque<TaggedNode>& inline_ready); + void ProcessInline(const TaggedNodeReadyQueue& inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive // nodes in 'ready' into 'inline_ready'. void ScheduleReady(const TaggedNodeSeq& ready, - std::deque<TaggedNode>* inline_ready); + TaggedNodeReadyQueue* inline_ready); // Provide debugging output about an outstanding node in the executor. void DumpCompletedNodeState(const int node_id, const Entry* input_vector); @@ -905,43 +942,55 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { } } -namespace { +// State kept alive for executing an asynchronous node in another +// thread. NOTE: We need to make a copy of p.input, +// p.input_device_contexts, and p.input_alloc_attrs for asynchronous +// kernels because OpKernelContext methods like input_type(i) need +// the params to point to a valid input type vector. It's not an issue for +// sync kernels because these vectors are kept on the stack. +struct ExecutorState::AsyncState { + AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, + const NodeItem& _item, Entry* _first_input, NodeExecStats* _stats) + : saved_inputs(*p.inputs), + saved_input_device_contexts(*p.input_device_contexts), + saved_input_alloc_attrs(*p.input_alloc_attrs), + params(p), + tagged_node(_tagged_node), + item(_item), + first_input(_first_input), + // ParamsButClearingEigenGPUDevice does equivalent of + // params.eigen_gpu_device = nullptr; + ctx(ParamsButClearingEigenGPUDevice(&params), item.num_outputs), + stats(_stats) { + params.inputs = &saved_inputs; + params.input_device_contexts = &saved_input_device_contexts; + params.input_alloc_attrs = &saved_input_alloc_attrs; + } -// Helpers to make a copy of 'p' and makes a copy of the input type -// vector and the device context vector. -// -// NOTE: We need to make a copy of p.input for asynchronous kernel -// because OpKernelContext methods like input_type(i) needs the param -// points to valid input type vector.
It's not an issue for sync -// kernels because the type vector is kept on the stack. -OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) { - OpKernelContext::Params* ret = new OpKernelContext::Params; - *ret = p; - // Ensure the copy of Params will make a new eigen GPU device if - // necessary. - ret->eigen_gpu_device = nullptr; - ret->inputs = new TensorValueVec(*p.inputs); - ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts); - ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs); - return ret; -} + TensorValueVec saved_inputs; + DeviceContextVec saved_input_device_contexts; + AllocatorAttributeVec saved_input_alloc_attrs; + OpKernelContext::Params params; + TaggedNode tagged_node; + NodeItem item; + Entry* first_input; + OpKernelContext ctx; + NodeExecStats* stats; -// Helpers to delete 'p' and copies made by CopyParams. -void DeleteParams(OpKernelContext::Params* p) { - // No need to delete p->eigen_gpu_device since that is deleted in - // p's destructor - delete p->inputs; - delete p->input_device_contexts; - delete p->input_alloc_attrs; - delete p; -} - -} // namespace + private: + OpKernelContext::Params* ParamsButClearingEigenGPUDevice( + OpKernelContext::Params* p) { + // Ensure OpKernelContext constructor will make a new eigen GPU device if + // necessary. + p->eigen_gpu_device = nullptr; // Force allocation + return p; + } +}; void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { const NodeItem* nodes = impl_->nodes_; TaggedNodeSeq ready; - std::deque<TaggedNode> inline_ready; + TaggedNodeReadyQueue inline_ready; // Parameters passed to OpKernel::Compute. TensorValueVec inputs; @@ -1059,20 +1108,25 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { AsyncOpKernel* async = item.kernel->AsAsync(); DCHECK(async != nullptr); launched_asynchronously = true; - auto pcopy = CopyParams(params); - auto ctx = new OpKernelContext(pcopy, item.num_outputs); - auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy, - device]() { + AsyncState* state = + new AsyncState(params, tagged_node, item, first_input, stats); + + auto done = [this, state]() { + + Device* device = impl_->params_.device; + NodeExecStats* stats = state->stats; // Shorthand + Entry* first_input = state->first_input; // Shorthand + if (vlog_) { VLOG(2) << this << " Async kernel done: " - << SummarizeNodeDef(item.node->def()); + << SummarizeNodeDef(state->item.node->def()); } if (stats_collector_) nodestats::SetOpEnd(stats); EntryVector outputs; - Status s = ProcessOutputs(item, ctx, &outputs, stats); - if (stats_collector_) nodestats::SetMemory(stats, ctx); + Status s = ProcessOutputs(state->item, &state->ctx, &outputs, stats); + if (stats_collector_) nodestats::SetMemory(stats, &state->ctx); // Clears inputs. - int num_inputs = item.num_inputs; + const int num_inputs = state->item.num_inputs; for (int i = 0; i < num_inputs; ++i) { (first_input + i)->val = *kEmptyTensor; } @@ -1080,31 +1134,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { // add better optional debugging support. 
if (vlog_ && VLOG_IS_ON(1)) { mutex_lock l(mu_); - tagged_node.input_frame->GetIteration(tagged_node.input_iter) - ->mark_completed(tagged_node.node->id()); + state->tagged_node.input_frame + ->GetIteration(state->tagged_node.input_iter) + ->mark_completed(state->tagged_node.node->id()); } TaggedNodeSeq ready; if (s.ok()) { - PropagateOutputs(tagged_node, outputs, &ready); + PropagateOutputs(state->tagged_node, outputs, &ready); } outputs.clear(); - if (s.ok() && pcopy->device->RequiresRecordingAccessedTensors()) { + if (s.ok() && + state->params.device->RequiresRecordingAccessedTensors()) { // Get the list of all tensors accessed during the execution TensorReferenceVector accessed; - ctx->retrieve_accessed_tensors(&accessed); + state->ctx.retrieve_accessed_tensors(&accessed); if (stats_collector_) nodestats::SetReferencedTensors(stats, accessed); // callee takes ownership of the vector - device->ConsumeListOfAccessedTensors(ctx->op_device_context(), + device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(), accessed); } - bool completed = NodeDone(s, item.node, ready, stats, nullptr); - delete ctx; - DeleteParams(pcopy); + bool completed = NodeDone(s, state->item.node, ready, stats, nullptr); + delete state; if (completed) Finish(); }; if (stats_collector_) nodestats::SetOpStart(stats); - device->ComputeAsync(async, ctx, done); + device->ComputeAsync(async, &state->ctx, done); } else { // Synchronous computes. OpKernelContext ctx(&params, item.num_outputs); @@ -1497,7 +1552,7 @@ void ExecutorState::AddLoopInv(FrameState* frame, const Node* node, bool ExecutorState::NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, NodeExecStats* stats, - std::deque<TaggedNode>* inline_ready) { + TaggedNodeReadyQueue* inline_ready) { if (stats_collector_) { nodestats::SetAllEnd(stats); stats_collector_->UpdateCostModelNode(stats, impl_->graph_, node); @@ -1542,7 +1597,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node, return completed; } -void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) { +void ExecutorState::ProcessInline(const TaggedNodeReadyQueue& inline_ready) { if (inline_ready.empty()) return; int64 scheduled_usec = 0; if (stats_collector_) { @@ -1554,7 +1609,7 @@ void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) { } void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready, - std::deque<TaggedNode>* inline_ready) { + TaggedNodeReadyQueue* inline_ready) { if (ready.empty()) return; int64 scheduled_usec = 0; diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc index 0feef05eab0..ea1b04feeb4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License.
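The TaggedNodeReadyQueue introduced above replaces std::deque with an inlined vector plus a consume index: pop_front() normally just bumps the index, and storage is reclaimed only when the queue drains or the dead prefix exceeds 16384 entries. A standalone sketch of that policy with int payloads (IndexedQueue is a hypothetical name; TensorFlow's version sits on gtl::InlinedVector):

```cpp
#include <cassert>
#include <vector>

class IndexedQueue {
 public:
  void push_back(int v) { buf_.push_back(v); }
  int front() const {
    assert(front_ < buf_.size());
    return buf_[front_];
  }
  void pop_front() {
    assert(front_ < buf_.size());
    ++front_;
    if (front_ == buf_.size()) {
      buf_.clear();  // queue drained: cheap reset
      front_ = 0;
    } else if (front_ > kCompactThreshold) {
      // Large dead prefix: slide the live elements down once, rather than
      // shifting on every pop the way erase-per-pop would.
      buf_.erase(buf_.begin(), buf_.begin() + front_);
      front_ = 0;
    }
  }
  bool empty() const { return front_ == buf_.size(); }

 private:
  static constexpr size_t kCompactThreshold = 16384;
  std::vector<int> buf_;
  size_t front_ = 0;
};
```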
==============================================================================*/ -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -30,7 +30,7 @@ void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - const string& tensor_name, + StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done); diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index 1cb238fe2c6..8b1430f0219 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -56,9 +56,9 @@ class GPUDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done) const override; - void CopyDeviceTensorToCPU(const Tensor* device_tensor, - const string& edge_name, Device* device, - Tensor* cpu_tensor, StatusCallback done) override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) override; void MaintainLifetimeOnStream( const Tensor* t, perftools::gputools::Stream* stream) const override {} diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index c205ad7eaa6..be8790f3512 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -39,7 +39,6 @@ namespace test { Benchmark::Benchmark(const string& device, Graph* g, const SessionOptions* options, Graph* init) { - SessionOptions default_options; if (!options) { options = &default_options; @@ -138,11 +137,15 @@ void Benchmark::RunWithArgs( }; for (int i = 0; i < 3; ++i) { for (const auto& p : in) { - rendez_->Send(p.first, Rendezvous::Args(), p.second, false); + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed)); + rendez_->Send(parsed, Rendezvous::Args(), p.second, false); } TF_CHECK_OK(exec_->Run(args)); for (const string& key : out) { - rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead); + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); + rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead); } } TF_CHECK_OK(device_->Sync()); @@ -150,11 +153,15 @@ void Benchmark::RunWithArgs( testing::StartTiming(); while (iters-- > 0) { for (const auto& p : in) { - rendez_->Send(p.first, Rendezvous::Args(), p.second, false); + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed)); + rendez_->Send(parsed, Rendezvous::Args(), p.second, false); } TF_CHECK_OK(exec_->Run(args)); for (const string& key : out) { - rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead); + Rendezvous::ParsedKey parsed; + TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); + rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead); } } diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 92a61efaaac..9ef85d645d5 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ 
b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -36,19 +36,17 @@ IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr) IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); } -Status IntraProcessRendezvous::Send(const string& key, +Status IntraProcessRendezvous::Send(const ParsedKey& parsed, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { - VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key; + VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey(); { mutex_lock l(mu_); if (!status_.ok()) return status_; } - Rendezvous::ParsedKey parsed; - TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); // Buffers "val" and "device_context" in local_. - return local_->Send(key, args, val, is_dead); + return local_->Send(parsed, args, val, is_dead); } Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src, @@ -111,24 +109,17 @@ void IntraProcessRendezvous::SameWorkerRecvDone( done); } -void IntraProcessRendezvous::RecvAsync(const string& key, +void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { - VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key; - - Rendezvous::ParsedKey parsed; - Status s = ParseKey(key, false /*!is_src*/, &parsed); - if (!s.ok()) { - done(s, Args(), recv_args, Tensor(), false); - return; - } + VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey(); // Recv the tensor from local_. - local_->RecvAsync(key, recv_args, [this, parsed, done]( - const Status& status, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& in, bool is_dead) { + local_->RecvAsync(parsed, recv_args, [this, parsed, done]( + const Status& status, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& in, bool is_dead) { Status s = status; Tensor* out = new Tensor; StatusCallback final_callback = [done, send_args, recv_args, out, diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h index 46f54706493..cb5848ede32 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.h +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -43,14 +43,14 @@ class IntraProcessRendezvous : public Rendezvous { // Forwards to local_, where the Tensor "val" will be buffered and // any waiting callback stored. - Status Send(const string& key, const Rendezvous::Args& args, + Status Send(const ParsedKey& key, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) override; // This method is called only by the RecvOp. It tests to see // whether the value will be produced by a local or remote device // and handles accordingly. In the local case it forwards to // local_, in the remote case it initiates an RPC request. 
- void RecvAsync(const string& key, const Rendezvous::Args& args, + void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; void StartAbort(const Status& status) override; diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index b049d7f1737..00c171ad686 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -60,23 +60,25 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) { return iter->second; } -void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key, +void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, + const Rendezvous::ParsedKey& parsed, Rendezvous::DoneCallback done) { BaseRemoteRendezvous* rendez = FindOrCreate(step_id); rendez->RecvLocalAsync( - key, [rendez, done](const Status& s, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& v, - bool dead) { + parsed, [rendez, done](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, + bool dead) { rendez->Unref(); done(s, send_args, recv_args, v, dead); }); } -Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key, +Status BaseRendezvousMgr::RecvLocal(int64 step_id, + const Rendezvous::ParsedKey& parsed, Tensor* val, bool* is_dead) { Status ret; Notification n; - RecvLocalAsync(step_id, key, + RecvLocalAsync(step_id, parsed, [val, is_dead, &ret, &n](const Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, @@ -140,38 +142,35 @@ static bool IsLocalDevice(const WorkerEnv& worker, return device_name.starts_with(worker.worker_name); } -Status BaseRemoteRendezvous::Send(const string& key, +Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { - VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key; + VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey(); { mutex_lock l(mu_); if (!status_.ok()) return status_; } - Rendezvous::ParsedKey parsed; - TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); if (!IsLocalDevice(*env_, parsed.src_device)) { - return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ", - env_->worker_name); + return errors::InvalidArgument("Invalid rendezvous key (src): ", + parsed.FullKey(), " @ ", env_->worker_name); } // Buffers "val" and "device_context" in local_. 
- return local_->Send(key, args, val, is_dead); + return local_->Send(parsed, args, val, is_dead); } -Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src, - Rendezvous::ParsedKey* parsed) { +Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, + bool is_src) { { mutex_lock l(mu_); if (!status_.ok()) return status_; } - TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed)); - if (is_src && !IsLocalDevice(*env_, parsed->src_device)) { - return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ", - env_->worker_name); + if (is_src && !IsLocalDevice(*env_, parsed.src_device)) { + return errors::InvalidArgument("Invalid rendezvous key (src): ", + parsed.FullKey(), " @ ", env_->worker_name); } - if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) { - return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ", - env_->worker_name); + if (!is_src && !IsLocalDevice(*env_, parsed.dst_device)) { + return errors::InvalidArgument("Invalid rendezvous key (dst): ", + parsed.FullKey(), " @ ", env_->worker_name); } return Status::OK(); } @@ -233,13 +232,11 @@ bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src, return DeviceNameUtils::IsSameAddressSpace(src, dst); } -void BaseRemoteRendezvous::RecvAsync(const string& key, +void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { - VLOG(1) << "RemoteRendezvous Recv " << this << " " << key; - - Rendezvous::ParsedKey parsed; - Status s = ParseKey(key, false /*!is_src*/, &parsed); + VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); + Status s = ValidateDevices(parsed, false /*!is_src*/); if (!s.ok()) { done(s, Args(), recv_args, Tensor(), false); return; @@ -247,12 +244,13 @@ void BaseRemoteRendezvous::RecvAsync(const string& key, // Are src and dst in the same worker? if (IsSameWorker(parsed.src, parsed.dst)) { + Rendezvous::ParsedKey parsed_copy = parsed; // Recv the tensor from local_. 
local_->RecvAsync( - key, recv_args, [this, parsed, done](const Status& status, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& in, bool is_dead) { + parsed_copy, recv_args, + [this, parsed_copy, done]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { Status s = status; Tensor* out = new Tensor; StatusCallback final_callback = [done, send_args, recv_args, out, @@ -262,27 +260,26 @@ void BaseRemoteRendezvous::RecvAsync(const string& key, }; if (s.ok()) { - SameWorkerRecvDone(parsed, send_args, recv_args, in, out, - final_callback); + SameWorkerRecvDone(parsed_copy, send_args, recv_args, in, out, + std::move(final_callback)); } else { final_callback(s); } }); return; } else { - RecvFromRemoteAsync(key, parsed, recv_args, done); + RecvFromRemoteAsync(parsed, recv_args, std::move(done)); } } -void BaseRemoteRendezvous::RecvLocalAsync(const string& key, +void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, DoneCallback done) { - Rendezvous::ParsedKey parsed; - Status s = ParseKey(key, true /* is_src */, &parsed); + Status s = ValidateDevices(parsed, true /* is_src */); if (!s.ok()) { done(s, Args(), Args(), Tensor(), false); return; } - local_->RecvAsync(key, Args(), done); + local_->RecvAsync(parsed, Args(), std::move(done)); } void BaseRemoteRendezvous::StartAbort(const Status& s) { diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index d869999a7fd..b208c0f8742 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -68,12 +68,12 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // "done" when the tensor for "key" is produced or an error occurs. // // This method is used by the rpc handler of RecvTensor. - void RecvLocalAsync(int64 step_id, const string& key, + void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed, Rendezvous::DoneCallback done) override; // Synchronous wrapper for RecvLocalAsync. - Status RecvLocal(int64 step_id, const string& key, Tensor* val, - bool* is_dead) override; + Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) override; // Removes rendezvous for "step_id". // @@ -116,14 +116,14 @@ class BaseRemoteRendezvous : public Rendezvous { // Forwards to local_, where the Tensor "val" will be buffered and // any waiting callback stored. - Status Send(const string& key, const Rendezvous::Args& args, + Status Send(const ParsedKey& key, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) override; // This method is called only by the RecvOp. It tests to see // whether the value will be produced by a local or remote device // and handles accordingly. In the local case it forwards to // local_, in the remote case it initiates an RPC request. - void RecvAsync(const string& key, const Rendezvous::Args& args, + void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; void StartAbort(const Status& status) override; @@ -134,15 +134,14 @@ class BaseRemoteRendezvous : public Rendezvous { // network. In either case it needs to retrieve a locally buffered // value from local_, and give it to its caller. 
// - // Runs "done" as soon as the tensor for "key" is available or an error + // Runs "done" as soon as the tensor for "parsed" is available or an error // is detected. // - // REQUIRES: "key" is one that will be Saved into the local rendezvous. - void RecvLocalAsync(const string& key, DoneCallback done); + // REQUIRES: "parsed" is one that will be Saved into the local rendezvous. + void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done); protected: - virtual void RecvFromRemoteAsync(const string& key, - const Rendezvous::ParsedKey& parsed, + virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, DoneCallback done) = 0; @@ -174,11 +173,10 @@ class BaseRemoteRendezvous : public Rendezvous { // Active outstanding RecvTensor calls. std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_); - // Parses "key" into "parsed". If "is_src" is true, checks that the - // rendezvous key's source is in this process. If "is_src" is false, - // checks that the rendezvous key's destination is in this process. - Status ParseKey(const string& key, bool is_src, - Rendezvous::ParsedKey* parsed); + // If "is_src" is true, checks that the rendezvous key "parsed"'s + // source is in this process. If "is_src" is false, checks that the + // rendezvous key "parsed"'s destination is in this process. + Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src); // Callback handling the case when a rendezvous has been // accomplished in local_ and the consumer is local to this process. diff --git a/tensorflow/core/distributed_runtime/call_options.cc b/tensorflow/core/distributed_runtime/call_options.cc index e01cc601fa1..3506191f0e7 100644 --- a/tensorflow/core/distributed_runtime/call_options.cc +++ b/tensorflow/core/distributed_runtime/call_options.cc @@ -33,7 +33,7 @@ void CallOptions::StartCancel() { void CallOptions::SetCancelCallback(CancelFunction cancel_func) { mutex_lock l(mu_); - cancel_func_ = cancel_func; + cancel_func_ = std::move(cancel_func); } void CallOptions::ClearCancelCallback() { diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc index befde83e99a..17843ff6b06 100644 --- a/tensorflow/core/distributed_runtime/executor_test.cc +++ b/tensorflow/core/distributed_runtime/executor_test.cc @@ -127,10 +127,15 @@ float V(const Tensor& tensor) { static uint64 kIncarnation = 1; // Uses in following tests. -string Key(const string& sender, const uint64 incarnation, - const string& receiver, const string& name) { - return Rendezvous::CreateKey(sender, incarnation, receiver, name, - FrameAndIter(0, 0)); +Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, + const string& receiver, const string& name) { + Rendezvous::ParsedKey result; + CHECK( + Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, + name, FrameAndIter(0, 0)), + &result) + .ok()); + return result; } #define ALICE "/job:j/replica:0/task:0/cpu:0" diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 05445e96d41..e6cc1bd33f6 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -306,10 +306,15 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); // Sends values specified by the caller. 
+ Rendezvous::ParsedKey parsed; for (const auto& p : in) { const string& key = p.first; const Tensor& val = p.second; - const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false); + + Status s = Rendezvous::ParseKey(key, &parsed); + if (s.ok()) { + s = rendezvous->Send(parsed, Rendezvous::Args(), val, false); + } if (!s.ok()) { done(s); item->Unref(); @@ -337,7 +342,10 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, LogMemory::RecordStep(args.step_id, handle); } thread::ThreadPool* pool = worker_env_->compute_pool; - args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); }; + using namespace std::placeholders; + // Line below is equivalent to this code, but does one less indirect call: + // args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); }; + args.runner = std::bind(&thread::ThreadPool::Schedule, pool, _1); for (const auto& unit : item->units) { unit.root->RunAsync(args, barrier->Get()); } @@ -347,11 +355,15 @@ void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out, StatusCallback done, Status s) { if (s.ok()) { // Receives values requested by the caller. + Rendezvous::ParsedKey parsed; for (auto& p : *out) { const string& key = p.first; Tensor* val = &p.second; bool is_dead = false; - s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead); + s = Rendezvous::ParseKey(key, &parsed); + if (s.ok()) { + s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead); + } if (is_dead) { s = errors::InvalidArgument("The tensor returned for ", key, " was not valid."); diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index 4b5721d3dd1..04c1fc248ef 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -57,12 +57,13 @@ class RendezvousMgrInterface { // "done" when the tensor for "key" is produced or an error occurs. // // This method is used by the rpc handler of RecvTensor. - virtual void RecvLocalAsync(int64 step_id, const string& key, + virtual void RecvLocalAsync(int64 step_id, + const Rendezvous::ParsedKey& parsed, Rendezvous::DoneCallback done) = 0; // Synchronous wrapper for RecvLocalAsync. - virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val, - bool* is_dead) = 0; + virtual Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) = 0; // Removes rendezvous for "step_id". // diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 0bdbd57b468..91c5714e49e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -97,49 +97,66 @@ class GrpcRemoteWorker : public WorkerInterface { req_copy->set_dma_ok(false); } // Type-specialized logging for this method. - StatusCallback logging_callback = [this, request, req_copy, response, done, - start_usec](Status s) { - if (logger_->LoggingActive()) { - int64 end_usec = Env::Default()->NowMicros(); - int64 step_id = request->step_id(); - int64 bytes = response->tensor().ByteSize(); - int64 send_start_usec = start_usec; - // If a send start time was reported by the other side, use - // that instead. Maybe we should mark the display if we're using - // our local time instead of the remote start time? 
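On the args.runner change in graph_mgr.cc above: both spellings produce a std::function<void(std::function<void()>)>, but the lambda wraps Schedule in an extra call frame while std::bind forwards to it directly, which is the "one less indirect call" the comment refers to. A toy illustration (ThreadPool here is a stand-in that runs closures inline, not tensorflow::thread::ThreadPool):

```cpp
#include <functional>
#include <iostream>

struct ThreadPool {
  void Schedule(std::function<void()> fn) { fn(); }  // toy: run inline
};

int main() {
  ThreadPool pool;
  using namespace std::placeholders;
  // Forwards straight to Schedule.
  std::function<void(std::function<void()>)> runner_bind =
      std::bind(&ThreadPool::Schedule, &pool, _1);
  // Equivalent behavior, one extra hop through the lambda body.
  std::function<void(std::function<void()>)> runner_lambda =
      [&pool](std::function<void()> fn) { pool.Schedule(std::move(fn)); };
  runner_bind([] { std::cout << "via bind\n"; });
  runner_lambda([] { std::cout << "via lambda\n"; });
  return 0;
}
```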
- if (response->send_start_micros()) { - // send_start_micros is the timestamp taken when the remote - // machine began to send the RecvTensor response. - // Due to clock skew between source and dest machines, it is - // possible that send_start_micros can be larger than end_usec or - // less than start_usec. - // To respect causality, we enforce the invariants that the RecvTensor - // response can not have been sent before the RecvTensor request, and - // must have been sent before it was received. - send_start_usec = std::max(start_usec, response->send_start_micros()); - send_start_usec = std::min(send_start_usec, end_usec - 1); + bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2); + StatusCallback wrapper_done; + const StatusCallback* cb_to_use; + if (!logging_active && req_copy == nullptr) { + cb_to_use = &done; // No additional work to do, so just use done directly + } else if (!logging_active) { + wrapper_done = [req_copy, done](Status s) { + delete req_copy; + done(s); + }; + cb_to_use = &wrapper_done; + } else { + wrapper_done = [this, request, req_copy, response, done, + start_usec](Status s) { + if (logger_->LoggingActive()) { + int64 end_usec = Env::Default()->NowMicros(); + int64 step_id = request->step_id(); + int64 bytes = response->tensor().ByteSize(); + int64 send_start_usec = start_usec; + // If a send start time was reported by the other side, use + // that instead. Maybe we should mark the display if we're using + // our local time instead of the remote start time? + if (response->send_start_micros()) { + // send_start_micros is the timestamp taken when the + // remote machine began to send the RecvTensor response. + // Due to clock skew between source and dest machines, it + // is possible that send_start_micros can be larger than + // end_usec or less than start_usec. + // + // To respect causality, we enforce the invariants that + // the RecvTensor response can not have been sent before + // the RecvTensor request, and must have been sent before + // it was received. + send_start_usec = + std::max(start_usec, response->send_start_micros()); + send_start_usec = std::min(send_start_usec, end_usec - 1); + } + const string& key = request->rendezvous_key(); + std::vector<string> key_parts = str_util::Split(key, ';'); + if (key_parts.size() != 5) { + LOG(WARNING) << "Bad key: " << key; + } else { + logger_->RecordRecvTensor(step_id, send_start_usec, end_usec, + key_parts[3], // tensor name + key_parts[0], // src_device + key_parts[2], // dst_device + bytes); + } } - const string& key = request->rendezvous_key(); - std::vector<string> key_parts = str_util::Split(key, ';'); - if (key_parts.size() != 5) { - LOG(WARNING) << "Bad key: " << key; - } else { - logger_->RecordRecvTensor(step_id, send_start_usec, end_usec, - key_parts[3], // tensor name - key_parts[0], // src_device - key_parts[2], // dst_device - bytes); - } - } - VLOG(2) << "done callback, req: " << request->DebugString() - << " response " << response->DebugString(); - delete req_copy; - done(s); - }; + VLOG(2) << "done callback, req: " << request->DebugString() + << " response " << response->DebugString(); + delete req_copy; + done(s); + }; + cb_to_use = &wrapper_done; + } IssueRequest(req_copy ? 
req_copy : request, response, - &grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback, - call_opts); + &grpc::WorkerService::Stub::AsyncRecvTensor, + std::move(*cb_to_use), call_opts); } void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index bba579a6a8d..bd5dbd2c5ba 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -349,11 +349,8 @@ class GrpcWorkerService : public AsyncServiceInterface { // Helper for RecvTensor. Validates "key" and returns the source // device in "*src_dev". - Status PrepareRecvTensor(const string& key, Device** src_dev) { - // Validate the key. - Rendezvous::ParsedKey parsed; - TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); - + Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, + Device** src_dev) { // Figures out which device the tensor is hosted on. TF_RETURN_IF_ERROR( env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); @@ -375,8 +372,12 @@ class GrpcWorkerService : public AsyncServiceInterface { const int64 step_id = call->request.step_id(); const string& key = call->request.rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); + Rendezvous::ParsedKey parsed; + Status s = Rendezvous::ParseKey(key, &parsed); Device* src_dev = nullptr; - Status s = PrepareRecvTensor(key, &src_dev); + if (s.ok()) { + s = PrepareRecvTensor(parsed, &src_dev); + } if (!s.ok()) { call->SendResponse(ToGrpcStatus(s)); return; @@ -388,7 +389,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // cancellation should abort the rendezvous. call->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); env_->rendezvous_mgr->RecvLocalAsync( - step_id, key, + step_id, parsed, [this, call, src_dev](const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index ea5a42333d4..96f7db2694b 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -41,8 +41,7 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { : BaseRemoteRendezvous(env, step_id, false) {} protected: - void RecvFromRemoteAsync(const string& key, - const Rendezvous::ParsedKey& parsed, + void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, DoneCallback done) override; @@ -55,23 +54,49 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { // Used only to retrieve tensors from remote processes. 
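The GrpcRemoteWorker rewrite above follows a pay-for-what-you-use pattern: when neither logging nor req_copy cleanup is needed, the caller's callback is passed through untouched instead of being captured into a wrapper std::function. A reduced sketch of the selection logic (names and the int "status" are placeholders):

```cpp
#include <functional>
#include <iostream>

using StatusCb = std::function<void(int)>;

void IssueRequest(const StatusCb& cb) { cb(0); }  // stand-in for the RPC

void RecvTensorLike(bool logging_active, const StatusCb& done) {
  StatusCb wrapper;
  const StatusCb* cb_to_use;
  if (!logging_active) {
    cb_to_use = &done;  // fast path: no wrapper allocation at all
  } else {
    // NOTE: by-reference capture is safe here only because IssueRequest
    // runs synchronously; async code like the real patch captures by value.
    wrapper = [&done](int s) {
      std::cout << "logged status " << s << "\n";  // extra bookkeeping
      done(s);
    };
    cb_to_use = &wrapper;
  }
  IssueRequest(*cb_to_use);
}

int main() {
  RecvTensorLike(false, [](int) { std::cout << "done\n"; });
  RecvTensorLike(true, [](int) { std::cout << "done\n"; });
  return 0;
}
```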
class RpcRecvTensorCall : public BaseRecvTensorCall { public: - RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi, - int64 step_id, const string& key, - const string& remote_dev, Allocator* allocator, - Device* dst_device) - : wi_(wi), - wc_(wc), - remote_dev_(remote_dev), - allocator_(allocator), - dst_(dst_device) { + RpcRecvTensorCall() + : wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {} + + void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id, + StringPiece key, Allocator* allocator, Device* dst_device, + const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) { + wi_ = wi; + wc_ = wc; + allocator_ = allocator; + dst_device_ = dst_device; + recv_args_ = recv_args; + done_ = std::move(done); req_.set_step_id(step_id); - req_.set_rendezvous_key(key); + req_.set_rendezvous_key(key.data(), key.size()); + } + + void Reset() { + delete wi_; + wi_ = nullptr; + wc_ = nullptr; + allocator_ = nullptr; + dst_device_ = nullptr; + // We don't clear opts_ and assume that Init will set up the state for + // opts_ appropriately. + req_.Clear(); + if (resp_.ByteSize() > 128) { + // Clear memory from resp_ if it is too large + RecvTensorResponse empty; + resp_.Swap(&empty); + } else { + resp_.Clear(); + } + { + mutex_lock l(mu_); + status_ = Status::OK(); + } + done_ = nullptr; } ~RpcRecvTensorCall() override { delete wi_; } void Start(std::function<void()> recv_done) override { - StartRTCall(recv_done); + StartRTCall(std::move(recv_done)); } void StartAbort(const Status& s) override { @@ -93,6 +118,10 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { bool is_dead() const { return resp_.is_dead(); } + Device* dst_device() const { return dst_device_; } + const Rendezvous::Args& recv_args() const { return recv_args_; } + const Rendezvous::DoneCallback& done() const { return done_; } + private: // Start the main RecvTensor call, checking for an async abort. void StartRTCall(std::function<void()> recv_done) { @@ -100,7 +129,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { nullptr /* TensorBufAllocator */, // done callback [this, recv_done](const Status& s) { - { + if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); } @@ -110,12 +139,13 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { WorkerInterface* wi_; // Owned. WorkerCacheInterface* wc_; // Not owned. 
- string remote_dev_; Allocator* allocator_; - Device* dst_; + Device* dst_device_; CallOptions opts_; RecvTensorRequest req_; RecvTensorResponse resp_; + Rendezvous::Args recv_args_; + Rendezvous::DoneCallback done_; mutable mutex mu_; Status status_ GUARDED_BY(mu_); @@ -123,10 +153,53 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall); }; +namespace { +class RpcRecvTensorFreeList { + public: + RpcRecvTensorFreeList() {} + ~RpcRecvTensorFreeList() { + for (int i = 0; i < objects_.size(); i++) { + delete objects_[i]; + } + } + + RpcRecvTensorCall* New() { + { + mutex_lock l(mu_); + if (!objects_.empty()) { + RpcRecvTensorCall* result = objects_.back(); + objects_.pop_back(); + return result; + } + } + return new RpcRecvTensorCall; + } + + void Release(RpcRecvTensorCall* obj) { + obj->Reset(); + { + mutex_lock l(mu_); + if (objects_.size() < kMaxObjects) { + objects_.push_back(obj); + return; + } + } + delete obj; + } + + private: + static const int kMaxObjects = 1000; + + mutex mu_; + std::vector<RpcRecvTensorCall*> objects_ GUARDED_BY(mu_); +}; + +static RpcRecvTensorFreeList call_freelist_; +} void RpcRemoteRendezvous::RecvFromRemoteAsync( - const string& key, const Rendezvous::ParsedKey& parsed, - const Rendezvous::Args& recv_args, DoneCallback done) { + const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, + DoneCallback done) { Status s; // key.src_device identifies a remote device. @@ -137,11 +210,15 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( s = errors::Internal(parsed.src_device, " is invalid remote source device."); } + // TODO(jeff): Consider checking for a valid worker_cache during the + // constructor of RpcRemoteRendezvous, rather than here, to simplify + // the twisty logic below. WorkerCacheInterface* worker_cache = env_->worker_cache; if (s.ok() && worker_cache == nullptr) { s = errors::Internal("No remote worker cache available."); } - WorkerInterface* rwi = env_->worker_cache->CreateWorker(src_worker); + WorkerInterface* rwi = + (worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr); if (s.ok() && rwi == nullptr) { s = errors::Internal("No worker known as ", src_worker); } @@ -157,15 +234,16 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs); // Prepare a RecvTensor call that can handle being aborted. - RpcRecvTensorCall* call = - new RpcRecvTensorCall(worker_cache, rwi, step_id_, key, - parsed.src_device, allocator, dst_device); + RpcRecvTensorCall* call = call_freelist_.New(); + + call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator, + dst_device, recv_args, std::move(done)); // Record "call" in active_ so that it can be aborted cleanly. RegisterCall(call); // Start "call". - call->Start([this, call, parsed, recv_args, done]() { + call->Start([this, call]() { // Removes "call" from active_. Prevent StartAbort(). 
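RpcRecvTensorFreeList above is a classic bounded free list: Reset() scrubs per-call state, New() pops from the pool when possible, and Release() either returns the object or deletes it once kMaxObjects is held. The same shape written generically (FreeList is a hypothetical template; it assumes T is default-constructible and exposes a Reset() method):

```cpp
#include <mutex>
#include <vector>

template <typename T, size_t kMaxObjects = 1000>
class FreeList {
 public:
  ~FreeList() {
    for (T* obj : objects_) delete obj;
  }
  T* New() {
    {
      std::lock_guard<std::mutex> l(mu_);
      if (!objects_.empty()) {
        T* result = objects_.back();
        objects_.pop_back();
        return result;
      }
    }
    return new T;  // pool empty: fall back to the heap
  }
  void Release(T* obj) {
    obj->Reset();  // drop per-use state before the object can be handed out
    {
      std::lock_guard<std::mutex> l(mu_);
      if (objects_.size() < kMaxObjects) {
        objects_.push_back(obj);
        return;
      }
    }
    delete obj;  // pool full: free instead of growing without bound
  }

 private:
  std::mutex mu_;
  std::vector<T*> objects_;
};
```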
DeregisterCall(call); // If StartAbort was called prior to DeregisterCall, then the @@ -173,24 +251,19 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( Status s = call->status(); Tensor val; if (s.ok()) { - Device* dst_device; - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); - if (s.ok()) { - s = dst_device->MakeTensorFromProto(call->tensor_proto(), - recv_args.alloc_attrs, &val); - } + s = call->dst_device()->MakeTensorFromProto( + call->tensor_proto(), call->recv_args().alloc_attrs, &val); } - done(s, Args(), recv_args, val, call->is_dead()); - delete call; + call->done()(s, Args(), call->recv_args(), val, call->is_dead()); + call_freelist_.Release(call); }); } } // namespace BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env) { + const WorkerEnv* worker_env) { return new RpcRemoteRendezvous(worker_env, step_id); } - } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 35290b9dea0..7e18278f309 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -40,15 +40,21 @@ string V(const Tensor& tensor) { return tensor.scalar<string>()(); } +Rendezvous::ParsedKey MakeKey(const string& s) { + Rendezvous::ParsedKey key; + CHECK(Rendezvous::ParseKey(s, &key).ok()); + return key; +} + TEST(RpcRendezvousMgrTest, LocalSendRecv) { WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; RpcRendezvousMgr rmgr(&env); const int64 step_id = 123; - const string key = Rendezvous::CreateKey( + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, - "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { Rendezvous* rendez = rmgr.Find(step_id); core::ScopedUnref unref(rendez); @@ -69,9 +75,9 @@ TEST(RpcRendezvousMgrTest, LocalAbort) { env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; RpcRendezvousMgr rmgr(&env); - const string key = Rendezvous::CreateKey( + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, - "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { // Explicit Abort(). 
const int64 step_id = 123; Rendezvous* rendez = rmgr.Find(step_id); @@ -105,9 +111,9 @@ TEST(RpcRendezvousMgrTest, CleanupAll) { env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; RpcRendezvousMgr rmgr(&env); - const string key = Rendezvous::CreateKey( + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, - "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { const int64 step_id = 123; Rendezvous* rendez = rmgr.Find(step_id); @@ -139,9 +145,9 @@ TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) { env.worker_name = "/job:mnist/replica:1/task:2"; RpcRendezvousMgr rmgr(&env); const int64 step_id = 123; - const string key = Rendezvous::CreateKey( + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, - "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)); + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { Rendezvous* rendez = rmgr.Find(step_id); core::ScopedUnref unref(rendez); diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 3ce661bd7cf..ba05efec2dc 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" namespace Eigen { @@ -80,7 +81,7 @@ class DeviceContext : public core::RefCounted { // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated // to be of the same size as "device_tensor". virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor, - const string& tensor_name, Device* device, + StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); } diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 337fe668c78..715397e6d60 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -30,6 +31,21 @@ limitations under the License. 
namespace tensorflow { +Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) { + const char* b_base = b.buf_.data(); + buf_ = b.buf_; + src_device.set(buf_.data() + (b.src_device.data() - b_base), + b.src_device.size()); + src = b.src; + src_incarnation = b.src_incarnation; + dst_device.set(buf_.data() + (b.dst_device.data() - b_base), + b.dst_device.size()); + dst = b.dst; + edge_name.set(buf_.data() + (b.edge_name.data() - b_base), + b.edge_name.size()); + return *this; +} + /* static */ string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation, const string& dst_device, const string& name, @@ -66,7 +82,8 @@ static StringPiece ConsumeNextPart(StringPiece* s, char delim) { /* static */ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { - StringPiece s(key); + out->buf_ = key; // Make a copy that our StringPieces can point at + StringPiece s(out->buf_); StringPiece parts[5]; for (int i = 0; i < 5; i++) { parts[i] = ConsumeNextPart(&s, ';'); @@ -77,9 +94,9 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { strings::HexStringToUint64(parts[1], &out->src_incarnation) && DeviceNameUtils::ParseFullName(parts[2], &out->dst) && !parts[3].empty()) { - out->src_device.assign(parts[0].data(), parts[0].size()); - out->dst_device.assign(parts[2].data(), parts[2].size()); - out->edge_name.assign(parts[3].data(), parts[3].size()); + out->src_device.set(parts[0].data(), parts[0].size()); + out->dst_device.set(parts[2].data(), parts[2].size()); + out->edge_name.set(parts[3].data(), parts[3].size()); return Status::OK(); } return errors::InvalidArgument("Invalid rendezvous key: ", key); @@ -87,8 +104,8 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) { Rendezvous::~Rendezvous() {} -Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val, - bool* is_dead) { +Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args, + Tensor* val, bool* is_dead) { Status ret; Notification n; RecvAsync(key, recv_args, @@ -109,18 +126,19 @@ class LocalRendezvousImpl : public Rendezvous { explicit LocalRendezvousImpl(bool tolerate_dup_recv) : tolerate_dup_recv_(tolerate_dup_recv) {} - Status Send(const string& key, const Args& send_args, const Tensor& val, + Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, const bool is_dead) override { - VLOG(2) << "Send " << this << " " << key; DoneCallback waiter = nullptr; Args recv_args; + uint64 key_hash = KeyHash(key.FullKey()); + VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); { mutex_lock l(mu_); if (!status_.ok()) { return status_; } Item* item = nullptr; - Table::iterator iter = table_.find(key); + Table::iterator iter = table_.find(key_hash); if (iter == table_.end()) { // There is no waiter for this message. Insert the message // into the waiters table. The waiter will pick it up when @@ -138,19 +156,24 @@ class LocalRendezvousImpl : public Rendezvous { // The allocator attributes of item->value. item->send_alloc_attrs = send_args.alloc_attrs; - CHECK(table_.insert({key, item}).second); + CHECK(table_.insert({key_hash, item}).second); return Status::OK(); } else { item = iter->second; + if (item->waiter == nullptr) { // There is already a message in the table under the key. // Should not happen unless it has a waiter. - return errors::Aborted("Duplicated send: ", key); + return errors::Aborted("Duplicated send: ", key.FullKey()); } // Mark item as complete. 
@@ -109,18 +126,19 @@ class LocalRendezvousImpl : public Rendezvous {
   explicit LocalRendezvousImpl(bool tolerate_dup_recv)
       : tolerate_dup_recv_(tolerate_dup_recv) {}
 
-  Status Send(const string& key, const Args& send_args, const Tensor& val,
+  Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
               const bool is_dead) override {
-    VLOG(2) << "Send " << this << " " << key;
     DoneCallback waiter = nullptr;
     Args recv_args;
+    uint64 key_hash = KeyHash(key.FullKey());
+    VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
     {
       mutex_lock l(mu_);
       if (!status_.ok()) {
         return status_;
       }
       Item* item = nullptr;
-      Table::iterator iter = table_.find(key);
+      Table::iterator iter = table_.find(key_hash);
       if (iter == table_.end()) {
         // There is no waiter for this message. Insert the message
         // into the waiters table. The waiter will pick it up when
@@ -138,19 +156,24 @@ class LocalRendezvousImpl : public Rendezvous {
         // The allocator attributes of item->value.
         item->send_alloc_attrs = send_args.alloc_attrs;
 
-        CHECK(table_.insert({key, item}).second);
+        CHECK(table_.insert({key_hash, item}).second);
         return Status::OK();
       } else {
         item = iter->second;
+
         if (item->waiter == nullptr) {
           // There is already a message in the table under the key.
           // Should not happen unless it has a waiter.
-          return errors::Aborted("Duplicated send: ", key);
+          return errors::Aborted("Duplicated send: ", key.FullKey());
         }
         // Mark item as complete.
         item->has_been_recvd = true;
-        waiter = item->waiter;
-        item->waiter = nullptr;
+
+        // Get item->waiter function into waiter and set item->waiter to null
+        std::swap(item->waiter, waiter);
+        DCHECK(item->waiter == nullptr);
+        DCHECK(waiter != nullptr);
+
         // The ref on recv_dev_context transfers below.
         recv_args.device_context = item->recv_dev_context;
         recv_args.alloc_attrs = item->recv_alloc_attrs;
@@ -173,9 +196,10 @@ class LocalRendezvousImpl : public Rendezvous {
     return Status::OK();
   }
 
-  void RecvAsync(const string& key, const Args& recv_args,
+  void RecvAsync(const ParsedKey& key, const Args& recv_args,
                  DoneCallback done) override {
-    VLOG(2) << "Recv " << this << " " << key;
+    uint64 key_hash = KeyHash(key.FullKey());
+    VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
     mu_.lock();
     if (!status_.ok()) {
       // Rendezvous has been aborted.
@@ -184,13 +208,13 @@ class LocalRendezvousImpl : public Rendezvous {
       done(s, Args(), recv_args, Tensor(), false);
       return;
     }
-    Table::iterator iter = table_.find(key);
+    Table::iterator iter = table_.find(key_hash);
     if (iter != table_.end()) {
       Item* item = iter->second;
       if (item->has_been_recvd && !tolerate_dup_recv_) {
         mu_.unlock();
-        done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
-             Tensor(), false);
+        done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
+             recv_args, Tensor(), false);
       } else if (item->waiter == nullptr || tolerate_dup_recv_) {
         // A message has already arrived and is stored in the table
         // under this key.  Consumes the message and invokes the done
@@ -218,8 +242,8 @@ class LocalRendezvousImpl : public Rendezvous {
         // Already have a waiter in the waiters table under this key,
         // which should not happen.
         mu_.unlock();
-        done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
-             Tensor(), false);
+        done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
+             recv_args, Tensor(), false);
       }
       return;
     }
@@ -227,13 +251,13 @@ class LocalRendezvousImpl : public Rendezvous {
     // waiting table. The done closure will be invoked when the
    // message arrives.
     Item* item = new Item;
-    item->waiter = done;
+    item->waiter = std::move(done);
     item->recv_alloc_attrs = recv_args.alloc_attrs;
     if (recv_args.device_context) {
       item->recv_dev_context = recv_args.device_context;
       item->recv_dev_context->Ref();
     }
-    CHECK(table_.insert({key, item}).second);
+    CHECK(table_.insert({key_hash, item}).second);
     mu_.unlock();
     return;
   }
@@ -280,7 +304,12 @@ class LocalRendezvousImpl : public Rendezvous {
       }
     }
   };
-  typedef std::unordered_map<string, Item*> Table;
+  // We key the hash table by KeyHash of the Rendezvous::CreateKey string
+  static uint64 KeyHash(const StringPiece& k) {
+    return Hash64(k.data(), k.size());
+  }
+
+  typedef std::unordered_map<uint64, Item*> Table;
 
   // TODO(zhifengc): shard table_.
   mutex mu_;
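The waiter table above changes from unordered_map<string, Item*> to
unordered_map<uint64, Item*>, keyed by Hash64 of the full key string: the map
then hashes and compares eight-byte integers on every probe instead of whole
key strings. This quietly assumes 64-bit collisions between live keys are rare
enough to ignore; there is no collision handling in the table. A
self-contained sketch of the pattern (std::hash standing in for
tensorflow::Hash64):

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    // Hash the full key once; use the integer as the map key thereafter.
    static std::uint64_t KeyHash(const std::string& k) {
      return std::hash<std::string>{}(k);  // stand-in for Hash64(k.data(), k.size())
    }

    int main() {
      std::unordered_map<std::uint64_t, int> table;  // was keyed by string
      table.insert({KeyHash("/cpu:0;1;/cpu:1;foo;1;2"), 42});
      return table.count(KeyHash("/cpu:0;1;/cpu:1;foo;1;2")) == 1 ? 0 : 1;
    }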
diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h
index 843374fa337..17cae351550 100644
--- a/tensorflow/core/framework/rendezvous.h
+++ b/tensorflow/core/framework/rendezvous.h
@@ -54,12 +54,22 @@ class Rendezvous : public core::RefCounted {
 
   // Parses the key constructed by CreateKey and parse src/dst device
   // names into structures respectively.
   struct ParsedKey {
-    string src_device;
+    StringPiece src_device;
     DeviceNameUtils::ParsedName src;
     uint64 src_incarnation = 0;
-    string dst_device;
+    StringPiece dst_device;
     DeviceNameUtils::ParsedName dst;
-    string edge_name;
+    StringPiece edge_name;
+
+    ParsedKey() {}
+    ParsedKey(const ParsedKey& b) { *this = b; }
+
+    ParsedKey& operator=(const ParsedKey& b);
+    StringPiece FullKey() const { return buf_; }
+
+   private:
+    friend class Rendezvous;
+    string buf_;
   };
   static Status ParseKey(const string& key, ParsedKey* out);
 
@@ -74,7 +84,7 @@ class Rendezvous : public core::RefCounted {
   // Send/Recv on the same worker.
   //
   // Send() never blocks.
-  virtual Status Send(const string& key, const Args& args, const Tensor& val,
+  virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
                       const bool is_dead) = 0;
 
   // Callback provided by a tensor consumer waiting on the rendezvous.
@@ -84,13 +94,15 @@ class Rendezvous : public core::RefCounted {
   // receiver, which may be needed when a non-CPU device is in use
   // by either side.
   typedef std::function<void(const Status&, const Args&, const Args&,
-                             const Tensor&, const bool)> DoneCallback;
+                             const Tensor&, const bool)>
+      DoneCallback;
 
-  virtual void RecvAsync(const string& key, const Args& args,
+  virtual void RecvAsync(const ParsedKey& key, const Args& args,
                          DoneCallback done) = 0;
 
   // Synchronous wrapper for RecvAsync.
-  Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
+  Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
+              bool* is_dead);
 
   // Aborts all pending and future Send/Recv with the given "status".
   //
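Because ParsedKey carries its own buf_, a parsed key stays valid after the
string handed to ParseKey is gone, and FullKey() recovers the original text
for logging and error messages. A usage sketch under that reading of the
header above (the wrapper function is hypothetical, for illustration only):

    #include "tensorflow/core/framework/rendezvous.h"
    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    void ParsedKeyOutlivesSource() {  // illustration only
      Rendezvous::ParsedKey parsed;
      {
        string key = Rendezvous::CreateKey("/cpu:0", 1, "/cpu:1", "edge",
                                           FrameAndIter(0, 0));
        TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
      }  // `key` is destroyed here; `parsed` still owns its own copy.
      VLOG(1) << "full key: " << parsed.FullKey()
              << " edge: " << parsed.edge_name;
    }

    }  // namespace tensorflow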
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
index ecafce5f785..663c449dbed 100644
--- a/tensorflow/core/framework/rendezvous_test.cc
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -96,13 +96,29 @@ string V(const Tensor& tensor) {
   return tensor.scalar<string>()();
 }
 
+const char* kFoo = "/cpu:0;1;/cpu:1;foo;1;2";
+const char* kBar = "/gpu:0;2;/gpu:1;bar;1;2";
+
+Rendezvous::ParsedKey MakeKey(const string& name) {
+  string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
+                                   "/job:mnist/replica:1/task:2/GPU:0", name,
+                                   FrameAndIter(0, 0));
+  Rendezvous::ParsedKey k;
+  TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
+  return k;
+}
+
+Rendezvous::ParsedKey KeyFoo() { return MakeKey("foo"); }
+Rendezvous::ParsedKey KeyBar() { return MakeKey("bar"); }
+
 TEST_F(LocalRendezvousTest, SendRecv) {
   Rendezvous::Args args;
-  TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
-  EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false)));
+  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
+  EXPECT_TRUE(
+      errors::IsAborted(rendez_->Send(KeyFoo(), args, V("hello"), false)));
   Tensor val(DT_STRING);
   bool is_dead = false;
-  TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
+  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
   EXPECT_EQ("hello", V(val));
 }
 
@@ -110,12 +126,12 @@ TEST_F(LocalRendezvousTest, RecvSend) {
   SchedClosure([this]() {
     Env::Default()->SleepForMicroseconds(10000);
     Rendezvous::Args args;
-    TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+    TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
   });
   Tensor val(DT_STRING);
   bool is_dead = false;
   Rendezvous::Args args;
-  TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
+  TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
   EXPECT_EQ("hello", V(val));
 }
 
@@ -124,16 +140,17 @@ TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) {
     Tensor t(DT_STRING);
     bool is_dead = false;
     Rendezvous::Args args;
-    TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
-    TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
+    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
+    TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
   });
   Env::Default()->SleepForMicroseconds(1000000);
   Tensor val(DT_STRING);
   bool val_dead = false;
   Rendezvous::Args args;
-  EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
-  TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
-  TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
+  EXPECT_TRUE(
+      errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
+  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
+  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
   EXPECT_EQ("secret msg", V(val));
 }
 
@@ -142,17 +159,18 @@ TEST_F(LocalRendezvousTest, DuplicateSerialRecv) {
     Tensor t(DT_STRING);
     bool is_dead = false;
     Rendezvous::Args args;
-    TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
-    TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
+    TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
+    TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
   });
   Env::Default()->SleepForMicroseconds(1000000);
   Tensor val(DT_STRING);
   bool val_dead = false;
   Rendezvous::Args args;
-  TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
-  TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
+  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
+  TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
   EXPECT_EQ("secret msg", V(val));
-  EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+  EXPECT_TRUE(
+      errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
 }
 
 // A simple structure that behaves a bit like a blocking counter.  The
@@ -174,7 +192,7 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
       random::SimplePhilox rnd(&philox);
       Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
       Rendezvous::Args args;
-      TF_ASSERT_OK(rendez_->Send(strings::StrCat(i), args,
+      TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
                                  V(strings::StrCat(i)), false));
     });
     SchedClosure([this, &state, i]() {
@@ -184,7 +202,8 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
       Tensor val(DT_STRING);
      bool val_dead = false;
       Rendezvous::Args args;
-      TF_ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead));
+      TF_ASSERT_OK(
+          rendez_->Recv(MakeKey(strings::StrCat(i)), args, &val, &val_dead));
       EXPECT_EQ(strings::StrCat(i), V(val));
       bool done = false;
       {
@@ -212,7 +231,7 @@ TEST_F(LocalRendezvousTest, RecvAbort) {
   Tensor val(DT_STRING);
   bool val_dead = false;
   Rendezvous::Args args;
-  Status status = rendez_->Recv("foo", args, &val, &val_dead);
+  Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
   EXPECT_TRUE(errors::IsAborted(status));
 }
 
@@ -228,7 +247,7 @@ TEST_F(LocalRendezvousTest, RecvSleepAbort) {
   Tensor val(DT_STRING);
   bool val_dead = false;
   Rendezvous::Args args;
-  Status status = rendez_->Recv("foo", args, &val, &val_dead);
+  Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
   EXPECT_TRUE(errors::IsAborted(status));
 }
 
@@ -237,8 +256,9 @@ TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
   Tensor val(DT_STRING);
   bool val_dead = false;
   Rendezvous::Args args;
-  EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead)));
-  EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+  EXPECT_TRUE(errors::IsAborted(rendez_->Send(KeyFoo(), args, val, val_dead)));
+  EXPECT_TRUE(
+      errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
 }
 
 class DummyDeviceContext : public DeviceContext {
@@ -255,15 +275,15 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
   Rendezvous::Args args;
   args.device_context = new DummyDeviceContext(123);
 
-  TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+  TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
 
   Notification n;
   Rendezvous::Args args1;
   args1.device_context = new DummyDeviceContext(1);
-  rendez_->RecvAsync("foo", args1, [&n](const Status& s,
-                                        const Rendezvous::Args& send_args,
-                                        const Rendezvous::Args& recv_args,
-                                        const Tensor& val, bool is_dead) {
+  rendez_->RecvAsync(KeyFoo(), args1, [&n](const Status& s,
+                                           const Rendezvous::Args& send_args,
+                                           const Rendezvous::Args& recv_args,
+                                           const Tensor& val, bool is_dead) {
     CHECK_EQ(123, dynamic_cast<const DummyDeviceContext*>(send_args.device_context)
                       ->stream_id());
@@ -284,8 +304,8 @@ static void BM_SendRecv(int iters) {
   Status s;
   if (iters > 0) {
     while (iters--) {
-      s = rendez->Send("foo", args, orig, is_dead);
-      s = rendez->Recv("foo", args, &val, &is_dead);
+      s = rendez->Send(KeyFoo(), args, orig, is_dead);
+      s = rendez->Recv(KeyFoo(), args, &val, &is_dead);
     }
     CHECK_EQ(V(val), V(orig));
   }
@@ -307,8 +327,8 @@ static void BM_RecvSend(int iters) {
     Rendezvous::Args args;
     Status s;
     for (int i = 0; i < iters / 2; ++i) {
-      s = rendez->Recv("foo", args, &foo, &is_dead);
-      s = rendez->Send("bar", args, bar, is_dead);
+      s = rendez->Recv(KeyFoo(), args, &foo, &is_dead);
+      s = rendez->Send(KeyBar(), args, bar, is_dead);
     }
     CHECK_EQ("foo", V(foo));
   });
@@ -318,8 +338,8 @@ static void BM_RecvSend(int iters) {
   Rendezvous::Args args;
   Status s;
   for (int i = 0; i < iters / 2; ++i) {
-    s = rendez->Send("foo", args, foo, is_dead);
rendez->Recv("bar", args, &bar, &is_dead); + s = rendez->Send(KeyFoo(), args, foo, is_dead); + s = rendez->Recv(KeyBar(), args, &bar, &is_dead); } CHECK_EQ("bar", V(bar)); delete pool; diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 64225be2a2d..613aaecabba 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -32,10 +32,11 @@ static string GetRendezvousKeyPrefix(const string& send_device, recv_device, ";", tensor_name); } -static string GetRendezvousKey(const string& key_prefix, - const FrameAndIter& frame_iter) { - return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":", - frame_iter.iter_id); +static void GetRendezvousKey(const string& key_prefix, + const FrameAndIter& frame_iter, string* key) { + key->clear(); + strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":", + frame_iter.iter_id); } SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -57,9 +58,13 @@ void SendOp::Compute(OpKernelContext* ctx) { OP_REQUIRES( ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous.")); - const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter()); + string key; + GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key); VLOG(2) << "Send " << key; + Rendezvous::ParsedKey parsed; + OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(key, &parsed)); + // The device context may be passed between the Send/Recv // boundary, so that the device context used to produce the Tensor // is used when performing the copy on the recv side (which may be @@ -67,9 +72,8 @@ void SendOp::Compute(OpKernelContext* ctx) { Rendezvous::Args args; args.device_context = ctx->op_device_context(); args.alloc_attrs = ctx->input_alloc_attr(0); - Status s = - ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead()); - ctx->SetStatus(s); + OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0), + ctx->is_input_dead())); } REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); @@ -98,16 +102,21 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { OP_REQUIRES( ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous.")); - const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter()); + string key; + GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key); VLOG(2) << "Recv " << key; + Rendezvous::ParsedKey parsed; + OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(key, &parsed), done); + Rendezvous::Args args; args.device_context = ctx->op_device_context(); args.alloc_attrs = ctx->output_alloc_attr(0); ctx->rendezvous()->RecvAsync( - key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& val, bool is_dead) { + parsed, args, + [ctx, done](const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, + bool is_dead) { ctx->SetStatus(s); if (s.ok()) { // 'ctx' allocates the output tensor of the expected type. 
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 25edfd1405a..5816dbd40cc 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -104,22 +104,29 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
   }
   StringPiece tmp;
   while (!fullname.empty()) {
+    bool progress = false;
     if (str_util::ConsumePrefix(&fullname, "/job:")) {
       p->has_job = !str_util::ConsumePrefix(&fullname, "*");
       if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
         return false;
       }
-    } else if (str_util::ConsumePrefix(&fullname, "/replica:")) {
+      progress = true;
+    }
+    if (str_util::ConsumePrefix(&fullname, "/replica:")) {
       p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
       if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
         return false;
       }
-    } else if (str_util::ConsumePrefix(&fullname, "/task:")) {
+      progress = true;
+    }
+    if (str_util::ConsumePrefix(&fullname, "/task:")) {
      p->has_task = !str_util::ConsumePrefix(&fullname, "*");
       if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
         return false;
       }
-    } else if (str_util::ConsumePrefix(&fullname, "/device:")) {
+      progress = true;
+    }
+    if (str_util::ConsumePrefix(&fullname, "/device:")) {
       p->has_type = !str_util::ConsumePrefix(&fullname, "*");
       if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
         return false;
@@ -132,24 +139,31 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
           return false;
         }
       }
+      progress = true;
+    }
 
-    } else if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
-               str_util::ConsumePrefix(&fullname, "/CPU:")) {
+    if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
+        str_util::ConsumePrefix(&fullname, "/CPU:")) {
       p->has_type = true;
       p->type = "CPU";  // Treat '/cpu:..' as uppercase '/device:CPU:...'
       p->has_id = !str_util::ConsumePrefix(&fullname, "*");
       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
         return false;
       }
-    } else if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
-               str_util::ConsumePrefix(&fullname, "/GPU:")) {
+      progress = true;
+    }
+    if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
+        str_util::ConsumePrefix(&fullname, "/GPU:")) {
       p->has_type = true;
       p->type = "GPU";  // Treat '/gpu:..' as uppercase '/device:GPU:...'
       p->has_id = !str_util::ConsumePrefix(&fullname, "*");
       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
         return false;
       }
-    } else {
+      progress = true;
+    }
+
+    if (!progress) {
       return false;
     }
   }
@@ -340,11 +354,22 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
                                       string* device) {
   ParsedName pn;
   if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
-    *task = strings::StrCat(
-        (pn.has_job ? strings::StrCat("/job:", pn.job) : ""),
-        (pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""),
-        (pn.has_task ? strings::StrCat("/task:", pn.task) : ""));
-    *device = strings::StrCat(pn.type, ":", pn.id);
+    task->clear();
+    task->reserve(
+        (pn.has_job ? (5 + pn.job.size()) : 0) +
+        (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
+        (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
+    if (pn.has_job) {
+      strings::StrAppend(task, "/job:", pn.job);
+    }
+    if (pn.has_replica) {
+      strings::StrAppend(task, "/replica:", pn.replica);
+    }
+    if (pn.has_task) {
+      strings::StrAppend(task, "/task:", pn.task);
+    }
+    device->clear();
+    strings::StrAppend(device, pn.type, ":", pn.id);
     return true;
   }
   return false;
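The ParseFullName rewrite above replaces an else-if chain with independent if
blocks plus a progress flag: a single pass of the while loop may now consume
several components, and the loop rejects the input only when a pass consumes
nothing, which also guards against spinning forever on an unrecognized
component. A self-contained sketch of the control-flow shape (toy prefixes,
not the real device-name grammar):

    #include <string>

    static bool ConsumePrefix(std::string* s, const std::string& p) {
      if (s->compare(0, p.size(), p) != 0) return false;
      s->erase(0, p.size());
      return true;
    }

    // Accepts any sequence of "/a" and "/b" components, in any order:
    // components missed in one pass are picked up on the next iteration.
    static bool ParseComponents(std::string s, bool* has_a, bool* has_b) {
      while (!s.empty()) {
        bool progress = false;
        if (ConsumePrefix(&s, "/a")) { *has_a = true; progress = true; }
        if (ConsumePrefix(&s, "/b")) { *has_b = true; progress = true; }
        if (!progress) return false;  // unknown component: fail, don't loop
      }
      return true;
    }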