A series of changes to significantly reduce the number of allocations
done by models distributed across many devices. A small microbenchmark model that runs two banks (A and B) of 30 nodes with a 30x30 full shuffle between them, where each of the nodes in A and in B run with one node on each of the 30 devices (so 30*29+30+30, or ~930 separate RPCs) was showing ~111,000 allocations per iteration of the graph. With the changes here, this is now down to ~64,300 allocations per iteration. Changes include: o DeviceContext::CopyDeviceTensorToCPU and related helper routines: use StringPiece instead of const string& for the tensor name (avoids creating a string in some cases where the caller only has a StringPiece available). o Change some Rendezvous and BaseRemoteRendezvous interfaces to take a 'const Rendezvous::ParsedKey& key', rather than 'const string& key'. In many cases, the callers were already having to parse the key into a ParsedKey, and so we were doing the parsing multiple times at different levels as we processed receiving or sending of a tensor. This reduces the number of times that we parse a key as it flows from a Send node through to a Recv node on another worker. o Changed Rendezvous::ParsedKey so that it makes a copy of the underlying full key, and then uses StringPiece objects to point into this copy for the src_device, dst_device, and edge_name pieces. This turns 3 string allocations into 1 per Rendezvous::ParseKey call. o Added new StringPiece Rendezvous::ParsedKey::FullKey() accessor to return a StringPiece for the underlying full key, and used that in a few places (mostly logging) where that is useful. o In many places, used std::move(function_variable) when assigning to an instance variable. This eliminates a very large number of excess std::function allocations/initializations (~56000 of the baseline allocations were related to std::function setup or cloning, and this is now down to ~11000 after this cl). o In the RPC-based remote workers (StubbyRemoteWorker and GrpcRemoteWorker), changed the code path in RecvTensorAsync to avoid creation of a std::function with 6 arguments unless necessary. There are three cases now handled separately: (a) We're not logging, and we didn't make a copy of the request that we need to free: just use the passed in 'StatusCallback done' object directly, without creating a wrapper std::function object at all (b) We're not logging, but we made a copy of the request that we need to free: we create a simple wrapper std::function that invokes the passed in 'done' callback, and then frees the req_copy request copy object. (c) We're logging: we create the std::function object with all the necessary state to log when the recv has finished. o Changed DeviceMgr::LookupDevice to take a StringPiece, rather than a const string&, and changed the hash table to use StringPiece keys. This allows clients that just have a StringPiece device name in their hand to avoid a string creation to lookup the Device* object. o Changed ExecutorState to use a specialized TaggedNodeReadyQueue that internally uses a gtl::InlinedVector<TaggedNode, 16>, rather than using a std::deque<TaggedNode> for keeping track of nodes ready to execute. This is faster because it avoids allocations entirely if the ready node queue doesn't get bigger than 16, and inlined vectors are generally faster than std::deque, at a minor risk of using more memory if this queue grows to very large numbers of ready nodes (mostly imaginable only in pathological graphs). o In ExecutorState::Process, allocated a single ExecutorState::AsyncState object to keep track of all the state we need to preserve for an asynchronously executed node, rather than keeping this state implicitly via a very large number of arguments to a lamda function. o Added new atomic std::atomic<bool> status_is_ok_ in BaseRemoteRendezvous. This allows us to avoid acquiring the lock when we just want to check if the status is non-OK in BaseRemoteRendezvous::Send and BaseRemoteRendezvous::ValidateDevices. o In GraphMgr::RunAllDone, changed assignment of args.runner to avoid one extra level of std::function indirection (binding the function directly to the ThreadPool::Schedule routine, rather than creating an intermediate lambda function that invokes this inside the body of the lambda. o Added freelist of RpcRecvTensorCall objects in third_party/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc o Changed third_party/tensorflow/core/framework/rendezvous.cc to keep the hashtable of Item* objects keyed by uint64 (hash of the tensor name), rather than the full-string tensor name. Collisions in the 64-bit hash space should basically never happen. o Sped up DeviceNameUtils::ParseFullName by optimizing for the common ordering of parts of /job, /replica, /task, /device. The parsing code was general enough to handle any order, but did so by comparing the prefixes 4, 3, 2, and 1 times, respectively, rather than 1, 1, 1, and 1 times. o Sped up DeviceNameUtils::SplitDeviceName to avoid extra string copies. Change: 125991891
This commit is contained in:
parent
0da0ecdf1f
commit
a120b0bec1
@ -184,27 +184,27 @@ class SimpleRendezvous : public Rendezvous {
|
||||
public:
|
||||
explicit SimpleRendezvous() {}
|
||||
|
||||
Status Send(const string& key, const Args& send_args, const Tensor& val,
|
||||
Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
if (is_dead) {
|
||||
return errors::Internal("Send of a dead tensor");
|
||||
}
|
||||
ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(ParseKey(key, &parsed));
|
||||
|
||||
mutex_lock l(mu_);
|
||||
if (table_.count(parsed.edge_name) > 0) {
|
||||
string edge_name = parsed.edge_name.ToString();
|
||||
if (table_.count(edge_name) > 0) {
|
||||
return errors::Internal("Send of an already sent tensor");
|
||||
}
|
||||
table_[parsed.edge_name] = val;
|
||||
table_[edge_name] = val;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecvAsync(const string& key, const Args& recv_args,
|
||||
void RecvAsync(const ParsedKey& parsed, const Args& recv_args,
|
||||
DoneCallback done) override {
|
||||
Tensor tensor;
|
||||
Status status = Status::OK();
|
||||
{
|
||||
string key = parsed.edge_name.ToString();
|
||||
mutex_lock l(mu_);
|
||||
if (table_.count(key) <= 0) {
|
||||
status = errors::Internal("Did not find key ", key);
|
||||
@ -417,7 +417,14 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
||||
// this node. Don't bother processing the rest of the nodes.
|
||||
return c > 0;
|
||||
}
|
||||
Status s = rendez->Recv(tensor_name, Rendezvous::Args(), &output, &is_dead);
|
||||
|
||||
string full_key = Rendezvous::CreateKey("/cpu:0", 1, "/cpu:1", tensor_name,
|
||||
FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(full_key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendez->Recv(parsed, Rendezvous::Args(), &output, &is_dead);
|
||||
}
|
||||
if (!s.ok() || is_dead) {
|
||||
return c > 0;
|
||||
}
|
||||
|
@ -43,8 +43,7 @@ std::vector<RegistrationInfo>* MutableRegistry() {
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
void CopyTensor::ViaDMA(const string& edge_name,
|
||||
DeviceContext* send_dev_context,
|
||||
void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
||||
DeviceContext* recv_dev_context, Device* src,
|
||||
Device* dst, const AllocatorAttributes src_alloc_attr,
|
||||
const AllocatorAttributes dst_alloc_attr,
|
||||
|
@ -42,7 +42,7 @@ class CopyTensor {
|
||||
// the type of devices and memory in use, the copy may be performed
|
||||
// synchronously or asynchronously. 'done' will be invoked only
|
||||
// after the copy is actually complete.
|
||||
static void ViaDMA(const string& edge_name, DeviceContext* send_dev_context,
|
||||
static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
|
||||
DeviceContext* recv_dev_context, Device* src, Device* dst,
|
||||
const AllocatorAttributes src_alloc_attr,
|
||||
const AllocatorAttributes dst_alloc_attr,
|
||||
|
@ -24,13 +24,17 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) {
|
||||
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
|
||||
: name_backing_store_(128) {
|
||||
for (Device* d : devices) {
|
||||
devices_.push_back(d);
|
||||
|
||||
// Register under both the full name and the local name.
|
||||
device_map_[d->name()] = d;
|
||||
device_map_[DeviceNameUtils::LocalName(d->name())] = d;
|
||||
string full_name = d->name();
|
||||
device_map_[CopyToBackingStore(full_name)] = d;
|
||||
|
||||
string lname = DeviceNameUtils::LocalName(d->name());
|
||||
device_map_[CopyToBackingStore(lname)] = d;
|
||||
device_type_counts_[d->device_type()]++;
|
||||
}
|
||||
}
|
||||
@ -39,6 +43,13 @@ DeviceMgr::~DeviceMgr() {
|
||||
for (auto p : devices_) delete p;
|
||||
}
|
||||
|
||||
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
|
||||
int n = s.size();
|
||||
char* space = name_backing_store_.Alloc(n);
|
||||
memcpy(space, s.data(), n);
|
||||
return StringPiece(space, n);
|
||||
}
|
||||
|
||||
void DeviceMgr::ListDeviceAttributes(
|
||||
std::vector<DeviceAttributes>* devices) const {
|
||||
devices->reserve(devices_.size());
|
||||
@ -70,7 +81,7 @@ string DeviceMgr::DeviceMappingString() const {
|
||||
return out;
|
||||
}
|
||||
|
||||
Status DeviceMgr::LookupDevice(const string& name, Device** device) const {
|
||||
Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
|
||||
Status s;
|
||||
auto iter = device_map_.find(name);
|
||||
if (iter == device_map_.end()) {
|
||||
|
@ -22,7 +22,9 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/core/arena.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
@ -49,7 +51,7 @@ class DeviceMgr {
|
||||
|
||||
// Assigns *device with pointer to Device of the given name.
|
||||
// Accepts either a full device name, or just the replica-local suffix.
|
||||
Status LookupDevice(const string& name, Device** device) const;
|
||||
Status LookupDevice(StringPiece name, Device** device) const;
|
||||
|
||||
// Clears given containers of all devices if 'container' is
|
||||
// non-empty. Otherwise, clears default containers of all devices.
|
||||
@ -60,7 +62,11 @@ class DeviceMgr {
|
||||
private:
|
||||
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
|
||||
DeviceVec devices_;
|
||||
std::unordered_map<string, Device*> device_map_;
|
||||
|
||||
StringPiece CopyToBackingStore(StringPiece s);
|
||||
|
||||
std::unordered_map<StringPiece, Device*, StringPiece::Hasher> device_map_;
|
||||
core::Arena name_backing_store_; // Storage for keys in device_map_
|
||||
std::unordered_map<string, int> device_type_counts_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
|
||||
|
@ -551,6 +551,7 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
|
||||
const ExecutorsAndKeys* executors_and_keys,
|
||||
IntraProcessRendezvous* rendez) {
|
||||
Status s;
|
||||
Rendezvous::ParsedKey parsed;
|
||||
// Insert the input tensors into the local rendezvous by their
|
||||
// rendezvous key.
|
||||
for (const auto& input : inputs) {
|
||||
@ -560,7 +561,14 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
|
||||
"' is not a pre-defined feed!");
|
||||
}
|
||||
const string& input_key = it->second;
|
||||
s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
|
||||
|
||||
s = Rendezvous::ParseKey(input_key, &parsed);
|
||||
if (!s.ok()) {
|
||||
rendez->StartAbort(s);
|
||||
return s;
|
||||
}
|
||||
|
||||
s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
|
||||
if (!s.ok()) {
|
||||
rendez->StartAbort(s);
|
||||
return s;
|
||||
@ -578,6 +586,7 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
||||
outputs->resize(output_names.size());
|
||||
}
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
// Get the outputs from the rendezvous
|
||||
for (size_t output_offset = 0; output_offset < output_names.size();
|
||||
++output_offset) {
|
||||
@ -591,14 +600,16 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
||||
const string& output_key = it->second;
|
||||
Tensor output_tensor;
|
||||
bool is_dead;
|
||||
|
||||
// Fetch data from the Rendezvous.
|
||||
IntraProcessRendezvous* rendez = run_state->rendez;
|
||||
s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
|
||||
if (is_dead && s.ok()) {
|
||||
s = errors::InvalidArgument("The tensor returned for ",
|
||||
output_names[output_offset],
|
||||
" was not valid.");
|
||||
|
||||
s = Rendezvous::ParseKey(output_key, &parsed);
|
||||
if (s.ok()) {
|
||||
// Fetch data from the Rendezvous.
|
||||
s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead);
|
||||
if (is_dead && s.ok()) {
|
||||
s = errors::InvalidArgument("The tensor returned for ", output_name,
|
||||
" was not valid.");
|
||||
}
|
||||
}
|
||||
if (!s.ok()) {
|
||||
rendez->StartAbort(s);
|
||||
|
@ -645,6 +645,43 @@ class ExecutorState {
|
||||
}
|
||||
};
|
||||
|
||||
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
|
||||
// have that many nodes in the ready queue, so we just use a vector and
|
||||
// don't free up memory from the queue as we consume nodes.
|
||||
class TaggedNodeReadyQueue {
|
||||
public:
|
||||
TaggedNodeReadyQueue() : front_index_(0) {}
|
||||
|
||||
void push_back(TaggedNode node) { ready_.push_back(node); }
|
||||
TaggedNode front() const {
|
||||
DCHECK_LT(front_index_, ready_.size());
|
||||
return ready_[front_index_];
|
||||
}
|
||||
void pop_front() {
|
||||
DCHECK_LT(front_index_, ready_.size());
|
||||
front_index_++;
|
||||
if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
|
||||
if (front_index_ == ready_.size()) {
|
||||
ready_.clear();
|
||||
} else {
|
||||
// Lots of unused entries at beginning of vector: move everything down
|
||||
// to start of vector.
|
||||
ready_.erase(ready_.begin(), ready_.begin() + front_index_);
|
||||
}
|
||||
front_index_ = 0;
|
||||
}
|
||||
}
|
||||
bool empty() const { return ready_.empty(); }
|
||||
const TaggedNode* begin() const { return ready_.begin() + front_index_; }
|
||||
const TaggedNode* end() const { return ready_.end(); }
|
||||
|
||||
private:
|
||||
gtl::InlinedVector<TaggedNode, 16> ready_;
|
||||
int front_index_;
|
||||
};
|
||||
|
||||
struct AsyncState;
|
||||
|
||||
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
|
||||
typedef gtl::InlinedVector<Entry, 4> EntryVector;
|
||||
|
||||
@ -767,15 +804,15 @@ class ExecutorState {
|
||||
// "node" just finishes. Takes ownership of "stats". Returns true if
|
||||
// execution has completed.
|
||||
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
|
||||
NodeExecStats* stats, std::deque<TaggedNode>* inline_ready);
|
||||
NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready);
|
||||
|
||||
// Call Process() on all nodes in 'inline_ready'.
|
||||
void ProcessInline(const std::deque<TaggedNode>& inline_ready);
|
||||
void ProcessInline(const TaggedNodeReadyQueue& inline_ready);
|
||||
|
||||
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
|
||||
// nodes in 'ready' into 'inline_ready'.
|
||||
void ScheduleReady(const TaggedNodeSeq& ready,
|
||||
std::deque<TaggedNode>* inline_ready);
|
||||
TaggedNodeReadyQueue* inline_ready);
|
||||
|
||||
// Provide debugging output about an outstanding node in the executor.
|
||||
void DumpCompletedNodeState(const int node_id, const Entry* input_vector);
|
||||
@ -905,43 +942,55 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// State kept alive for executing an asynchronous node in another
|
||||
// thread. NOTE: We need to make a copy of p.input,
|
||||
// p.input_device_contexts, and p.input_alloc_attrs for asynchronous
|
||||
// kernels because OpKernelContext methods like input_type(i) needs
|
||||
// the param points to valid input type vector. It's not an issue for
|
||||
// sync kernels because these vectors are kept on the stack.
|
||||
struct ExecutorState::AsyncState {
|
||||
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
|
||||
const NodeItem& _item, Entry* _first_input, NodeExecStats* _stats)
|
||||
: saved_inputs(*p.inputs),
|
||||
saved_input_device_contexts(*p.input_device_contexts),
|
||||
saved_input_alloc_attrs(*p.input_alloc_attrs),
|
||||
params(p),
|
||||
tagged_node(_tagged_node),
|
||||
item(_item),
|
||||
first_input(_first_input),
|
||||
// ParamsButClearingEigenGPUDevice does equivalent of
|
||||
// params.eigen_gpu_device = nullptr;
|
||||
ctx(ParamsButClearingEigenGPUDevice(¶ms), item.num_outputs),
|
||||
stats(_stats) {
|
||||
params.inputs = &saved_inputs;
|
||||
params.input_device_contexts = &saved_input_device_contexts;
|
||||
params.input_alloc_attrs = &saved_input_alloc_attrs;
|
||||
}
|
||||
|
||||
// Helpers to make a copy of 'p' and makes a copy of the input type
|
||||
// vector and the device context vector.
|
||||
//
|
||||
// NOTE: We need to make a copy of p.input for asynchronous kernel
|
||||
// because OpKernelContext methods like input_type(i) needs the param
|
||||
// points to valid input type vector. It's not an issue for sync
|
||||
// kernels because the type vector is kept on the stack.
|
||||
OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
|
||||
OpKernelContext::Params* ret = new OpKernelContext::Params;
|
||||
*ret = p;
|
||||
// Ensure the copy of Params will make a new eigen GPU device if
|
||||
// necessary.
|
||||
ret->eigen_gpu_device = nullptr;
|
||||
ret->inputs = new TensorValueVec(*p.inputs);
|
||||
ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
|
||||
ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
|
||||
return ret;
|
||||
}
|
||||
TensorValueVec saved_inputs;
|
||||
DeviceContextVec saved_input_device_contexts;
|
||||
AllocatorAttributeVec saved_input_alloc_attrs;
|
||||
OpKernelContext::Params params;
|
||||
TaggedNode tagged_node;
|
||||
NodeItem item;
|
||||
Entry* first_input;
|
||||
OpKernelContext ctx;
|
||||
NodeExecStats* stats;
|
||||
|
||||
// Helpers to delete 'p' and copies made by CopyParams.
|
||||
void DeleteParams(OpKernelContext::Params* p) {
|
||||
// No need to delete p->eigen_gpu_device since that is deleted in
|
||||
// p's destructor
|
||||
delete p->inputs;
|
||||
delete p->input_device_contexts;
|
||||
delete p->input_alloc_attrs;
|
||||
delete p;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
private:
|
||||
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
|
||||
OpKernelContext::Params* p) {
|
||||
// Ensure OpKernelContext constructor will make a new eigen GPU device if
|
||||
// necessary.
|
||||
p->eigen_gpu_device = nullptr; // Force allocation
|
||||
return p;
|
||||
}
|
||||
};
|
||||
|
||||
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
const NodeItem* nodes = impl_->nodes_;
|
||||
TaggedNodeSeq ready;
|
||||
std::deque<TaggedNode> inline_ready;
|
||||
TaggedNodeReadyQueue inline_ready;
|
||||
|
||||
// Parameters passed to OpKernel::Compute.
|
||||
TensorValueVec inputs;
|
||||
@ -1059,20 +1108,25 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
AsyncOpKernel* async = item.kernel->AsAsync();
|
||||
DCHECK(async != nullptr);
|
||||
launched_asynchronously = true;
|
||||
auto pcopy = CopyParams(params);
|
||||
auto ctx = new OpKernelContext(pcopy, item.num_outputs);
|
||||
auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy,
|
||||
device]() {
|
||||
AsyncState* state =
|
||||
new AsyncState(params, tagged_node, item, first_input, stats);
|
||||
|
||||
auto done = [this, state]() {
|
||||
|
||||
Device* device = impl_->params_.device;
|
||||
NodeExecStats* stats = state->stats; // Shorthand
|
||||
Entry* first_input = state->first_input; // Shorthand
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << this << " Async kernel done: "
|
||||
<< SummarizeNodeDef(item.node->def());
|
||||
<< SummarizeNodeDef(state->item.node->def());
|
||||
}
|
||||
if (stats_collector_) nodestats::SetOpEnd(stats);
|
||||
EntryVector outputs;
|
||||
Status s = ProcessOutputs(item, ctx, &outputs, stats);
|
||||
if (stats_collector_) nodestats::SetMemory(stats, ctx);
|
||||
Status s = ProcessOutputs(state->item, &state->ctx, &outputs, stats);
|
||||
if (stats_collector_) nodestats::SetMemory(stats, &state->ctx);
|
||||
// Clears inputs.
|
||||
int num_inputs = item.num_inputs;
|
||||
const int num_inputs = state->item.num_inputs;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
(first_input + i)->val = *kEmptyTensor;
|
||||
}
|
||||
@ -1080,31 +1134,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
// add better optional debugging support.
|
||||
if (vlog_ && VLOG_IS_ON(1)) {
|
||||
mutex_lock l(mu_);
|
||||
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
|
||||
->mark_completed(tagged_node.node->id());
|
||||
state->tagged_node.input_frame
|
||||
->GetIteration(state->tagged_node.input_iter)
|
||||
->mark_completed(state->tagged_node.node->id());
|
||||
}
|
||||
TaggedNodeSeq ready;
|
||||
if (s.ok()) {
|
||||
PropagateOutputs(tagged_node, outputs, &ready);
|
||||
PropagateOutputs(state->tagged_node, outputs, &ready);
|
||||
}
|
||||
outputs.clear();
|
||||
if (s.ok() && pcopy->device->RequiresRecordingAccessedTensors()) {
|
||||
if (s.ok() &&
|
||||
state->params.device->RequiresRecordingAccessedTensors()) {
|
||||
// Get the list of all tensors accessed during the execution
|
||||
TensorReferenceVector accessed;
|
||||
ctx->retrieve_accessed_tensors(&accessed);
|
||||
state->ctx.retrieve_accessed_tensors(&accessed);
|
||||
if (stats_collector_)
|
||||
nodestats::SetReferencedTensors(stats, accessed);
|
||||
// callee takes ownership of the vector
|
||||
device->ConsumeListOfAccessedTensors(ctx->op_device_context(),
|
||||
device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
|
||||
accessed);
|
||||
}
|
||||
bool completed = NodeDone(s, item.node, ready, stats, nullptr);
|
||||
delete ctx;
|
||||
DeleteParams(pcopy);
|
||||
bool completed = NodeDone(s, state->item.node, ready, stats, nullptr);
|
||||
delete state;
|
||||
if (completed) Finish();
|
||||
};
|
||||
if (stats_collector_) nodestats::SetOpStart(stats);
|
||||
device->ComputeAsync(async, ctx, done);
|
||||
device->ComputeAsync(async, &state->ctx, done);
|
||||
} else {
|
||||
// Synchronous computes.
|
||||
OpKernelContext ctx(¶ms, item.num_outputs);
|
||||
@ -1497,7 +1552,7 @@ void ExecutorState::AddLoopInv(FrameState* frame, const Node* node,
|
||||
|
||||
bool ExecutorState::NodeDone(const Status& s, const Node* node,
|
||||
const TaggedNodeSeq& ready, NodeExecStats* stats,
|
||||
std::deque<TaggedNode>* inline_ready) {
|
||||
TaggedNodeReadyQueue* inline_ready) {
|
||||
if (stats_collector_) {
|
||||
nodestats::SetAllEnd(stats);
|
||||
stats_collector_->UpdateCostModelNode(stats, impl_->graph_, node);
|
||||
@ -1542,7 +1597,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
|
||||
return completed;
|
||||
}
|
||||
|
||||
void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
|
||||
void ExecutorState::ProcessInline(const TaggedNodeReadyQueue& inline_ready) {
|
||||
if (inline_ready.empty()) return;
|
||||
int64 scheduled_usec = 0;
|
||||
if (stats_collector_) {
|
||||
@ -1554,7 +1609,7 @@ void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
|
||||
}
|
||||
|
||||
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
|
||||
std::deque<TaggedNode>* inline_ready) {
|
||||
TaggedNodeReadyQueue* inline_ready) {
|
||||
if (ready.empty()) return;
|
||||
|
||||
int64 scheduled_usec = 0;
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/common_runtime/gpu_device_context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -30,7 +30,7 @@ void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
}
|
||||
|
||||
void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
const string& tensor_name,
|
||||
StringPiece tensor_name,
|
||||
Device* device, Tensor* cpu_tensor,
|
||||
StatusCallback done) {
|
||||
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
|
||||
|
@ -56,9 +56,9 @@ class GPUDeviceContext : public DeviceContext {
|
||||
Tensor* device_tensor,
|
||||
StatusCallback done) const override;
|
||||
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
const string& edge_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) override;
|
||||
void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name,
|
||||
Device* device, Tensor* cpu_tensor,
|
||||
StatusCallback done) override;
|
||||
|
||||
void MaintainLifetimeOnStream(
|
||||
const Tensor* t, perftools::gputools::Stream* stream) const override {}
|
||||
|
@ -39,7 +39,6 @@ namespace test {
|
||||
|
||||
Benchmark::Benchmark(const string& device, Graph* g,
|
||||
const SessionOptions* options, Graph* init) {
|
||||
|
||||
SessionOptions default_options;
|
||||
if (!options) {
|
||||
options = &default_options;
|
||||
@ -138,11 +137,15 @@ void Benchmark::RunWithArgs(
|
||||
};
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
for (const auto& p : in) {
|
||||
rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||
rendez_->Send(parsed, Rendezvous::Args(), p.second, false);
|
||||
}
|
||||
TF_CHECK_OK(exec_->Run(args));
|
||||
for (const string& key : out) {
|
||||
rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||
rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead);
|
||||
}
|
||||
}
|
||||
TF_CHECK_OK(device_->Sync());
|
||||
@ -150,11 +153,15 @@ void Benchmark::RunWithArgs(
|
||||
testing::StartTiming();
|
||||
while (iters-- > 0) {
|
||||
for (const auto& p : in) {
|
||||
rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||
rendez_->Send(parsed, Rendezvous::Args(), p.second, false);
|
||||
}
|
||||
TF_CHECK_OK(exec_->Run(args));
|
||||
for (const string& key : out) {
|
||||
rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||
rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,19 +36,17 @@ IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
|
||||
|
||||
IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
|
||||
|
||||
Status IntraProcessRendezvous::Send(const string& key,
|
||||
Status IntraProcessRendezvous::Send(const ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key;
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << parsed.FullKey();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_->Send(key, args, val, is_dead);
|
||||
return local_->Send(parsed, args, val, is_dead);
|
||||
}
|
||||
|
||||
Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
|
||||
@ -111,24 +109,17 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
done);
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::RecvAsync(const string& key,
|
||||
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, false /*!is_src*/, &parsed);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
}
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << parsed.FullKey();
|
||||
|
||||
// Recv the tensor from local_.
|
||||
local_->RecvAsync(key, recv_args, [this, parsed, done](
|
||||
const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
local_->RecvAsync(parsed, recv_args, [this, parsed, done](
|
||||
const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
Status s = status;
|
||||
Tensor* out = new Tensor;
|
||||
StatusCallback final_callback = [done, send_args, recv_args, out,
|
||||
|
@ -43,14 +43,14 @@ class IntraProcessRendezvous : public Rendezvous {
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
Status Send(const string& key, const Rendezvous::Args& args,
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const string& key, const Rendezvous::Args& args,
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
|
@ -60,23 +60,25 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
|
||||
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
Rendezvous::DoneCallback done) {
|
||||
BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
|
||||
rendez->RecvLocalAsync(
|
||||
key, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v,
|
||||
bool dead) {
|
||||
parsed, [rendez, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& v,
|
||||
bool dead) {
|
||||
rendez->Unref();
|
||||
done(s, send_args, recv_args, v, dead);
|
||||
});
|
||||
}
|
||||
|
||||
Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key,
|
||||
Status BaseRendezvousMgr::RecvLocal(int64 step_id,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
Tensor* val, bool* is_dead) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RecvLocalAsync(step_id, key,
|
||||
RecvLocalAsync(step_id, parsed,
|
||||
[val, is_dead, &ret, &n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
@ -140,38 +142,35 @@ static bool IsLocalDevice(const WorkerEnv& worker,
|
||||
return device_name.starts_with(worker.worker_name);
|
||||
}
|
||||
|
||||
Status BaseRemoteRendezvous::Send(const string& key,
|
||||
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
|
||||
VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
if (!IsLocalDevice(*env_, parsed.src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
||||
parsed.FullKey(), " @ ", env_->worker_name);
|
||||
}
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_->Send(key, args, val, is_dead);
|
||||
return local_->Send(parsed, args, val, is_dead);
|
||||
}
|
||||
|
||||
Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
|
||||
Rendezvous::ParsedKey* parsed) {
|
||||
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
|
||||
bool is_src) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) return status_;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
|
||||
if (is_src && !IsLocalDevice(*env_, parsed->src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
if (is_src && !IsLocalDevice(*env_, parsed.src_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (src): ",
|
||||
parsed.FullKey(), " @ ", env_->worker_name);
|
||||
}
|
||||
if (!is_src && !IsLocalDevice(*env_, parsed->dst_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (dst): ", key, " @ ",
|
||||
env_->worker_name);
|
||||
if (!is_src && !IsLocalDevice(*env_, parsed.dst_device)) {
|
||||
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
|
||||
parsed.FullKey(), " @ ", env_->worker_name);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -233,13 +232,11 @@ bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
|
||||
return DeviceNameUtils::IsSameAddressSpace(src, dst);
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, false /*!is_src*/, &parsed);
|
||||
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
|
||||
Status s = ValidateDevices(parsed, false /*!is_src*/);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
@ -247,12 +244,13 @@ void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
|
||||
// Are src and dst in the same worker?
|
||||
if (IsSameWorker(parsed.src, parsed.dst)) {
|
||||
Rendezvous::ParsedKey parsed_copy = parsed;
|
||||
// Recv the tensor from local_.
|
||||
local_->RecvAsync(
|
||||
key, recv_args, [this, parsed, done](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
parsed_copy, recv_args,
|
||||
[this, parsed_copy, done](
|
||||
const Status& status, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
|
||||
Status s = status;
|
||||
Tensor* out = new Tensor;
|
||||
StatusCallback final_callback = [done, send_args, recv_args, out,
|
||||
@ -262,27 +260,26 @@ void BaseRemoteRendezvous::RecvAsync(const string& key,
|
||||
};
|
||||
|
||||
if (s.ok()) {
|
||||
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
|
||||
final_callback);
|
||||
SameWorkerRecvDone(parsed_copy, send_args, recv_args, in, out,
|
||||
std::move(final_callback));
|
||||
} else {
|
||||
final_callback(s);
|
||||
}
|
||||
});
|
||||
return;
|
||||
} else {
|
||||
RecvFromRemoteAsync(key, parsed, recv_args, done);
|
||||
RecvFromRemoteAsync(parsed, recv_args, std::move(done));
|
||||
}
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
|
||||
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
|
||||
DoneCallback done) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = ParseKey(key, true /* is_src */, &parsed);
|
||||
Status s = ValidateDevices(parsed, true /* is_src */);
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), Args(), Tensor(), false);
|
||||
return;
|
||||
}
|
||||
local_->RecvAsync(key, Args(), done);
|
||||
local_->RecvAsync(parsed, Args(), std::move(done));
|
||||
}
|
||||
|
||||
void BaseRemoteRendezvous::StartAbort(const Status& s) {
|
||||
|
@ -68,12 +68,12 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
|
||||
// "done" when the tensor for "key" is produced or an error occurs.
|
||||
//
|
||||
// This method is used by the rpc handler of RecvTensor.
|
||||
void RecvLocalAsync(int64 step_id, const string& key,
|
||||
void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed,
|
||||
Rendezvous::DoneCallback done) override;
|
||||
|
||||
// Synchronous wrapper for RecvLocalAsync.
|
||||
Status RecvLocal(int64 step_id, const string& key, Tensor* val,
|
||||
bool* is_dead) override;
|
||||
Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
|
||||
Tensor* val, bool* is_dead) override;
|
||||
|
||||
// Removes rendezvous for "step_id".
|
||||
//
|
||||
@ -116,14 +116,14 @@ class BaseRemoteRendezvous : public Rendezvous {
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
Status Send(const string& key, const Rendezvous::Args& args,
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const string& key, const Rendezvous::Args& args,
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
@ -134,15 +134,14 @@ class BaseRemoteRendezvous : public Rendezvous {
|
||||
// network. In either case it needs to retrieve a locally buffered
|
||||
// value from local_, and give it to its caller.
|
||||
//
|
||||
// Runs "done" as soon as the tensor for "key" is available or an error
|
||||
// Runs "done" as soon as the tensor for "parsed" is available or an error
|
||||
// is detected.
|
||||
//
|
||||
// REQUIRES: "key" is one that will be Saved into the local rendezvous.
|
||||
void RecvLocalAsync(const string& key, DoneCallback done);
|
||||
// REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
|
||||
void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);
|
||||
|
||||
protected:
|
||||
virtual void RecvFromRemoteAsync(const string& key,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
@ -174,11 +173,10 @@ class BaseRemoteRendezvous : public Rendezvous {
|
||||
// Active outstanding RecvTensor calls.
|
||||
std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
|
||||
|
||||
// Parses "key" into "parsed". If "is_src" is true, checks that the
|
||||
// rendezvous key's source is in this process. If "is_src" is false,
|
||||
// checks that the rendezvous key's destination is in this process.
|
||||
Status ParseKey(const string& key, bool is_src,
|
||||
Rendezvous::ParsedKey* parsed);
|
||||
// If "is_src" is true, checks that the rendezvous key "parsed"'s
|
||||
// source is in this process. If "is_src" is false, checks that the
|
||||
// rendezvous key "parsed"'s destination is in this process.
|
||||
Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);
|
||||
|
||||
// Callback handling the case when a rendezvous has been
|
||||
// accomplished in local_ and the consumer is local to this process.
|
||||
|
@ -33,7 +33,7 @@ void CallOptions::StartCancel() {
|
||||
|
||||
void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
|
||||
mutex_lock l(mu_);
|
||||
cancel_func_ = cancel_func;
|
||||
cancel_func_ = std::move(cancel_func);
|
||||
}
|
||||
|
||||
void CallOptions::ClearCancelCallback() {
|
||||
|
@ -127,10 +127,15 @@ float V(const Tensor& tensor) {
|
||||
|
||||
static uint64 kIncarnation = 1; // Uses in following tests.
|
||||
|
||||
string Key(const string& sender, const uint64 incarnation,
|
||||
const string& receiver, const string& name) {
|
||||
return Rendezvous::CreateKey(sender, incarnation, receiver, name,
|
||||
FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
|
||||
const string& receiver, const string& name) {
|
||||
Rendezvous::ParsedKey result;
|
||||
CHECK(
|
||||
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
|
||||
name, FrameAndIter(0, 0)),
|
||||
&result)
|
||||
.ok());
|
||||
return result;
|
||||
}
|
||||
|
||||
#define ALICE "/job:j/replica:0/task:0/cpu:0"
|
||||
|
@ -306,10 +306,15 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||
|
||||
// Sends values specified by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (const auto& p : in) {
|
||||
const string& key = p.first;
|
||||
const Tensor& val = p.second;
|
||||
const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false);
|
||||
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
item->Unref();
|
||||
@ -337,7 +342,10 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||
LogMemory::RecordStep(args.step_id, handle);
|
||||
}
|
||||
thread::ThreadPool* pool = worker_env_->compute_pool;
|
||||
args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
|
||||
using namespace std::placeholders;
|
||||
// Line below is equivalent to this code, but does one less indirect call:
|
||||
// args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
|
||||
args.runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
|
||||
for (const auto& unit : item->units) {
|
||||
unit.root->RunAsync(args, barrier->Get());
|
||||
}
|
||||
@ -347,11 +355,15 @@ void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
|
||||
StatusCallback done, Status s) {
|
||||
if (s.ok()) {
|
||||
// Receives values requested by the caller.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
for (auto& p : *out) {
|
||||
const string& key = p.first;
|
||||
Tensor* val = &p.second;
|
||||
bool is_dead = false;
|
||||
s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead);
|
||||
s = Rendezvous::ParseKey(key, &parsed);
|
||||
if (s.ok()) {
|
||||
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
|
||||
}
|
||||
if (is_dead) {
|
||||
s = errors::InvalidArgument("The tensor returned for ", key,
|
||||
" was not valid.");
|
||||
|
@ -57,12 +57,13 @@ class RendezvousMgrInterface {
|
||||
// "done" when the tensor for "key" is produced or an error occurs.
|
||||
//
|
||||
// This method is used by the rpc handler of RecvTensor.
|
||||
virtual void RecvLocalAsync(int64 step_id, const string& key,
|
||||
virtual void RecvLocalAsync(int64 step_id,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
Rendezvous::DoneCallback done) = 0;
|
||||
|
||||
// Synchronous wrapper for RecvLocalAsync.
|
||||
virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val,
|
||||
bool* is_dead) = 0;
|
||||
virtual Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
|
||||
Tensor* val, bool* is_dead) = 0;
|
||||
|
||||
// Removes rendezvous for "step_id".
|
||||
//
|
||||
|
@ -97,49 +97,66 @@ class GrpcRemoteWorker : public WorkerInterface {
|
||||
req_copy->set_dma_ok(false);
|
||||
}
|
||||
// Type-specialized logging for this method.
|
||||
StatusCallback logging_callback = [this, request, req_copy, response, done,
|
||||
start_usec](Status s) {
|
||||
if (logger_->LoggingActive()) {
|
||||
int64 end_usec = Env::Default()->NowMicros();
|
||||
int64 step_id = request->step_id();
|
||||
int64 bytes = response->tensor().ByteSize();
|
||||
int64 send_start_usec = start_usec;
|
||||
// If a send start time was reported by the other side, use
|
||||
// that instead. Maybe we should mark the display if we're using
|
||||
// our local time instead of the remote start time?
|
||||
if (response->send_start_micros()) {
|
||||
// send_start_micros is the timestamp taken when the remote
|
||||
// machine began to send the RecvTensor response.
|
||||
// Due to clock skew between source and dest machines, it is
|
||||
// possible that send_start_micros can be larger than end_usec or
|
||||
// less than start_usec.
|
||||
// To respect causality, we enforce the invariants that the RecvTensor
|
||||
// response can not have been sent before the RecvTensor request, and
|
||||
// must have been sent before it was received.
|
||||
send_start_usec = std::max(start_usec, response->send_start_micros());
|
||||
send_start_usec = std::min(send_start_usec, end_usec - 1);
|
||||
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
|
||||
StatusCallback wrapper_done;
|
||||
const StatusCallback* cb_to_use;
|
||||
if (!logging_active && req_copy == nullptr) {
|
||||
cb_to_use = &done; // No additional work to do, so just use done directly
|
||||
} else if (!logging_active) {
|
||||
wrapper_done = [req_copy, done](Status s) {
|
||||
delete req_copy;
|
||||
done(s);
|
||||
};
|
||||
cb_to_use = &wrapper_done;
|
||||
} else {
|
||||
wrapper_done = [this, request, req_copy, response, done,
|
||||
start_usec](Status s) {
|
||||
if (logger_->LoggingActive()) {
|
||||
int64 end_usec = Env::Default()->NowMicros();
|
||||
int64 step_id = request->step_id();
|
||||
int64 bytes = response->tensor().ByteSize();
|
||||
int64 send_start_usec = start_usec;
|
||||
// If a send start time was reported by the other side, use
|
||||
// that instead. Maybe we should mark the display if we're using
|
||||
// our local time instead of the remote start time?
|
||||
if (response->send_start_micros()) {
|
||||
// send_start_micros is the timestamp taken when the
|
||||
// remote machine began to send the RecvTensor response.
|
||||
// Due to clock skew between source and dest machines, it
|
||||
// is possible that send_start_micros can be larger than
|
||||
// end_usec or less than start_usec.
|
||||
//
|
||||
// To respect causality, we enforce the invariants that
|
||||
// the RecvTensor response can not have been sent before
|
||||
// the RecvTensor request, and must have been sent before
|
||||
// it was received.
|
||||
send_start_usec =
|
||||
std::max(start_usec, response->send_start_micros());
|
||||
send_start_usec = std::min(send_start_usec, end_usec - 1);
|
||||
}
|
||||
const string& key = request->rendezvous_key();
|
||||
std::vector<string> key_parts = str_util::Split(key, ';');
|
||||
if (key_parts.size() != 5) {
|
||||
LOG(WARNING) << "Bad key: " << key;
|
||||
} else {
|
||||
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
|
||||
key_parts[3], // tensor name
|
||||
key_parts[0], // src_device
|
||||
key_parts[2], // dst_device
|
||||
bytes);
|
||||
}
|
||||
}
|
||||
const string& key = request->rendezvous_key();
|
||||
std::vector<string> key_parts = str_util::Split(key, ';');
|
||||
if (key_parts.size() != 5) {
|
||||
LOG(WARNING) << "Bad key: " << key;
|
||||
} else {
|
||||
logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
|
||||
key_parts[3], // tensor name
|
||||
key_parts[0], // src_device
|
||||
key_parts[2], // dst_device
|
||||
bytes);
|
||||
}
|
||||
}
|
||||
VLOG(2) << "done callback, req: " << request->DebugString()
|
||||
<< " response " << response->DebugString();
|
||||
delete req_copy;
|
||||
done(s);
|
||||
};
|
||||
VLOG(2) << "done callback, req: " << request->DebugString()
|
||||
<< " response " << response->DebugString();
|
||||
delete req_copy;
|
||||
done(s);
|
||||
};
|
||||
cb_to_use = &wrapper_done;
|
||||
}
|
||||
|
||||
IssueRequest(req_copy ? req_copy : request, response,
|
||||
&grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback,
|
||||
call_opts);
|
||||
&grpc::WorkerService::Stub::AsyncRecvTensor,
|
||||
std::move(*cb_to_use), call_opts);
|
||||
}
|
||||
|
||||
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
|
||||
|
@ -349,11 +349,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
|
||||
// Helper for RecvTensor. Validates "key" and returns the source
|
||||
// device in "*src_dev".
|
||||
Status PrepareRecvTensor(const string& key, Device** src_dev) {
|
||||
// Validate the key.
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
|
||||
Device** src_dev) {
|
||||
// Figures out which device the tensor is hosted on.
|
||||
TF_RETURN_IF_ERROR(
|
||||
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
|
||||
@ -375,8 +372,12 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
const int64 step_id = call->request.step_id();
|
||||
const string& key = call->request.rendezvous_key();
|
||||
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
Device* src_dev = nullptr;
|
||||
Status s = PrepareRecvTensor(key, &src_dev);
|
||||
if (s.ok()) {
|
||||
s = PrepareRecvTensor(parsed, &src_dev);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
return;
|
||||
@ -388,7 +389,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
// cancellation should abort the rendezvous.
|
||||
call->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
|
||||
env_->rendezvous_mgr->RecvLocalAsync(
|
||||
step_id, key,
|
||||
step_id, parsed,
|
||||
[this, call, src_dev](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
|
@ -41,8 +41,7 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
: BaseRemoteRendezvous(env, step_id, false) {}
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const string& key,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
@ -55,23 +54,49 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
// Used only to retrieve tensors from remote processes.
|
||||
class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
public:
|
||||
RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi,
|
||||
int64 step_id, const string& key,
|
||||
const string& remote_dev, Allocator* allocator,
|
||||
Device* dst_device)
|
||||
: wi_(wi),
|
||||
wc_(wc),
|
||||
remote_dev_(remote_dev),
|
||||
allocator_(allocator),
|
||||
dst_(dst_device) {
|
||||
RpcRecvTensorCall()
|
||||
: wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
|
||||
|
||||
void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id,
|
||||
StringPiece key, Allocator* allocator, Device* dst_device,
|
||||
const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
|
||||
wi_ = wi;
|
||||
wc_ = wc;
|
||||
allocator_ = allocator;
|
||||
dst_device_ = dst_device;
|
||||
recv_args_ = recv_args;
|
||||
done_ = std::move(done);
|
||||
req_.set_step_id(step_id);
|
||||
req_.set_rendezvous_key(key);
|
||||
req_.set_rendezvous_key(key.data(), key.size());
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
delete wi_;
|
||||
wi_ = nullptr;
|
||||
wc_ = nullptr;
|
||||
allocator_ = nullptr;
|
||||
dst_device_ = nullptr;
|
||||
// We don't clear opts_ and assume that Init will set up the state for
|
||||
// opts_ appropriately.
|
||||
req_.Clear();
|
||||
if (resp_.ByteSize() > 128) {
|
||||
// Clear memory from resp_ if it is too large
|
||||
RecvTensorResponse empty;
|
||||
resp_.Swap(&empty);
|
||||
} else {
|
||||
resp_.Clear();
|
||||
}
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
status_ = Status::OK();
|
||||
}
|
||||
done_ = nullptr;
|
||||
}
|
||||
|
||||
~RpcRecvTensorCall() override { delete wi_; }
|
||||
|
||||
void Start(std::function<void()> recv_done) override {
|
||||
StartRTCall(recv_done);
|
||||
StartRTCall(std::move(recv_done));
|
||||
}
|
||||
|
||||
void StartAbort(const Status& s) override {
|
||||
@ -93,6 +118,10 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
|
||||
bool is_dead() const { return resp_.is_dead(); }
|
||||
|
||||
Device* dst_device() const { return dst_device_; }
|
||||
const Rendezvous::Args& recv_args() const { return recv_args_; }
|
||||
const Rendezvous::DoneCallback& done() const { return done_; }
|
||||
|
||||
private:
|
||||
// Start the main RecvTensor call, checking for an async abort.
|
||||
void StartRTCall(std::function<void()> recv_done) {
|
||||
@ -100,7 +129,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
nullptr /* TensorBufAllocator */,
|
||||
// done callback
|
||||
[this, recv_done](const Status& s) {
|
||||
{
|
||||
if (!s.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
status_.Update(s);
|
||||
}
|
||||
@ -110,12 +139,13 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
|
||||
WorkerInterface* wi_; // Owned.
|
||||
WorkerCacheInterface* wc_; // Not owned.
|
||||
string remote_dev_;
|
||||
Allocator* allocator_;
|
||||
Device* dst_;
|
||||
Device* dst_device_;
|
||||
CallOptions opts_;
|
||||
RecvTensorRequest req_;
|
||||
RecvTensorResponse resp_;
|
||||
Rendezvous::Args recv_args_;
|
||||
Rendezvous::DoneCallback done_;
|
||||
|
||||
mutable mutex mu_;
|
||||
Status status_ GUARDED_BY(mu_);
|
||||
@ -123,10 +153,53 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
|
||||
};
|
||||
|
||||
namespace {
|
||||
class RpcRecvTensorFreeList {
|
||||
public:
|
||||
RpcRecvTensorFreeList() {}
|
||||
~RpcRecvTensorFreeList() {
|
||||
for (int i = 0; i < objects_.size(); i++) {
|
||||
delete objects_[i];
|
||||
}
|
||||
}
|
||||
|
||||
RpcRecvTensorCall* New() {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!objects_.empty()) {
|
||||
RpcRecvTensorCall* result = objects_.back();
|
||||
objects_.pop_back();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return new RpcRecvTensorCall;
|
||||
}
|
||||
|
||||
void Release(RpcRecvTensorCall* obj) {
|
||||
obj->Reset();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (objects_.size() < kMaxObjects) {
|
||||
objects_.push_back(obj);
|
||||
return;
|
||||
}
|
||||
}
|
||||
delete obj;
|
||||
}
|
||||
|
||||
private:
|
||||
static const int kMaxObjects = 1000;
|
||||
|
||||
mutex mu_;
|
||||
std::vector<RpcRecvTensorCall*> objects_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
static RpcRecvTensorFreeList call_freelist_;
|
||||
}
|
||||
|
||||
void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
const string& key, const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args, DoneCallback done) {
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
Status s;
|
||||
|
||||
// key.src_device identifies a remote device.
|
||||
@ -137,11 +210,15 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
s = errors::Internal(parsed.src_device,
|
||||
" is invalid remote source device.");
|
||||
}
|
||||
// TODO(jeff): Consider checking for a valid worker_cache during the
|
||||
// constructor of RpcRemoteRendezvous, rather than here, to simplify
|
||||
// the twisty logic below.
|
||||
WorkerCacheInterface* worker_cache = env_->worker_cache;
|
||||
if (s.ok() && worker_cache == nullptr) {
|
||||
s = errors::Internal("No remote worker cache available.");
|
||||
}
|
||||
WorkerInterface* rwi = env_->worker_cache->CreateWorker(src_worker);
|
||||
WorkerInterface* rwi =
|
||||
(worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr);
|
||||
if (s.ok() && rwi == nullptr) {
|
||||
s = errors::Internal("No worker known as ", src_worker);
|
||||
}
|
||||
@ -157,15 +234,16 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
|
||||
|
||||
// Prepare a RecvTensor call that can handle being aborted.
|
||||
RpcRecvTensorCall* call =
|
||||
new RpcRecvTensorCall(worker_cache, rwi, step_id_, key,
|
||||
parsed.src_device, allocator, dst_device);
|
||||
RpcRecvTensorCall* call = call_freelist_.New();
|
||||
|
||||
call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator,
|
||||
dst_device, recv_args, std::move(done));
|
||||
|
||||
// Record "call" in active_ so that it can be aborted cleanly.
|
||||
RegisterCall(call);
|
||||
|
||||
// Start "call".
|
||||
call->Start([this, call, parsed, recv_args, done]() {
|
||||
call->Start([this, call]() {
|
||||
// Removes "call" from active_. Prevent StartAbort().
|
||||
DeregisterCall(call);
|
||||
// If StartAbort was called prior to DeregisterCall, then the
|
||||
@ -173,24 +251,19 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
Status s = call->status();
|
||||
Tensor val;
|
||||
if (s.ok()) {
|
||||
Device* dst_device;
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
if (s.ok()) {
|
||||
s = dst_device->MakeTensorFromProto(call->tensor_proto(),
|
||||
recv_args.alloc_attrs, &val);
|
||||
}
|
||||
s = call->dst_device()->MakeTensorFromProto(
|
||||
call->tensor_proto(), call->recv_args().alloc_attrs, &val);
|
||||
}
|
||||
done(s, Args(), recv_args, val, call->is_dead());
|
||||
delete call;
|
||||
call->done()(s, Args(), call->recv_args(), val, call->is_dead());
|
||||
call_freelist_.Release(call);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) {
|
||||
const WorkerEnv* worker_env) {
|
||||
return new RpcRemoteRendezvous(worker_env, step_id);
|
||||
}
|
||||
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -40,15 +40,21 @@ string V(const Tensor& tensor) {
|
||||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
Rendezvous::ParsedKey MakeKey(const string& s) {
|
||||
Rendezvous::ParsedKey key;
|
||||
CHECK(Rendezvous::ParseKey(s, &key).ok());
|
||||
return key;
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
@ -69,9 +75,9 @@ TEST(RpcRendezvousMgrTest, LocalAbort) {
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{ // Explicit Abort().
|
||||
const int64 step_id = 123;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
@ -105,9 +111,9 @@ TEST(RpcRendezvousMgrTest, CleanupAll) {
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
const int64 step_id = 123;
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
@ -139,9 +145,9 @@ TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const string key = Rendezvous::CreateKey(
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
|
||||
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
|
||||
{
|
||||
Rendezvous* rendez = rmgr.Find(step_id);
|
||||
core::ScopedUnref unref(rendez);
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace Eigen {
|
||||
@ -80,7 +81,7 @@ class DeviceContext : public core::RefCounted {
|
||||
// device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
|
||||
// to be of the same size as "device_tensor".
|
||||
virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
const string& tensor_name, Device* device,
|
||||
StringPiece tensor_name, Device* device,
|
||||
Tensor* cpu_tensor, StatusCallback done) {
|
||||
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -30,6 +31,21 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) {
|
||||
const char* b_base = b.buf_.data();
|
||||
buf_ = b.buf_;
|
||||
src_device.set(buf_.data() + (b.src_device.data() - b_base),
|
||||
b.src_device.size());
|
||||
src = b.src;
|
||||
src_incarnation = b.src_incarnation;
|
||||
dst_device.set(buf_.data() + (b.dst_device.data() - b_base),
|
||||
b.dst_device.size());
|
||||
dst = b.dst;
|
||||
edge_name.set(buf_.data() + (b.edge_name.data() - b_base),
|
||||
b.edge_name.size());
|
||||
return *this;
|
||||
}
|
||||
|
||||
/* static */
|
||||
string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
|
||||
const string& dst_device, const string& name,
|
||||
@ -66,7 +82,8 @@ static StringPiece ConsumeNextPart(StringPiece* s, char delim) {
|
||||
|
||||
/* static */
|
||||
Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
|
||||
StringPiece s(key);
|
||||
out->buf_ = key; // Make a copy that our StringPieces can point at
|
||||
StringPiece s(out->buf_);
|
||||
StringPiece parts[5];
|
||||
for (int i = 0; i < 5; i++) {
|
||||
parts[i] = ConsumeNextPart(&s, ';');
|
||||
@ -77,9 +94,9 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
|
||||
strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
|
||||
DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
|
||||
!parts[3].empty()) {
|
||||
out->src_device.assign(parts[0].data(), parts[0].size());
|
||||
out->dst_device.assign(parts[2].data(), parts[2].size());
|
||||
out->edge_name.assign(parts[3].data(), parts[3].size());
|
||||
out->src_device.set(parts[0].data(), parts[0].size());
|
||||
out->dst_device.set(parts[2].data(), parts[2].size());
|
||||
out->edge_name.set(parts[3].data(), parts[3].size());
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Invalid rendezvous key: ", key);
|
||||
@ -87,8 +104,8 @@ Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
|
||||
|
||||
Rendezvous::~Rendezvous() {}
|
||||
|
||||
Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val,
|
||||
bool* is_dead) {
|
||||
Status Rendezvous::Recv(const ParsedKey& key, const Args& recv_args,
|
||||
Tensor* val, bool* is_dead) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RecvAsync(key, recv_args,
|
||||
@ -109,18 +126,19 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
explicit LocalRendezvousImpl(bool tolerate_dup_recv)
|
||||
: tolerate_dup_recv_(tolerate_dup_recv) {}
|
||||
|
||||
Status Send(const string& key, const Args& send_args, const Tensor& val,
|
||||
Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
VLOG(2) << "Send " << this << " " << key;
|
||||
DoneCallback waiter = nullptr;
|
||||
Args recv_args;
|
||||
uint64 key_hash = KeyHash(key.FullKey());
|
||||
VLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!status_.ok()) {
|
||||
return status_;
|
||||
}
|
||||
Item* item = nullptr;
|
||||
Table::iterator iter = table_.find(key);
|
||||
Table::iterator iter = table_.find(key_hash);
|
||||
if (iter == table_.end()) {
|
||||
// There is no waiter for this message. Insert the message
|
||||
// into the waiters table. The waiter will pick it up when
|
||||
@ -138,19 +156,24 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
// The allocator attributes of item->value.
|
||||
item->send_alloc_attrs = send_args.alloc_attrs;
|
||||
|
||||
CHECK(table_.insert({key, item}).second);
|
||||
CHECK(table_.insert({key_hash, item}).second);
|
||||
return Status::OK();
|
||||
} else {
|
||||
item = iter->second;
|
||||
|
||||
if (item->waiter == nullptr) {
|
||||
// There is already a message in the table under the key.
|
||||
// Should not happen unless it has a waiter.
|
||||
return errors::Aborted("Duplicated send: ", key);
|
||||
return errors::Aborted("Duplicated send: ", key.FullKey());
|
||||
}
|
||||
// Mark item as complete.
|
||||
item->has_been_recvd = true;
|
||||
waiter = item->waiter;
|
||||
item->waiter = nullptr;
|
||||
|
||||
// Get item->waiter function into waiter and set item->waiter to null
|
||||
std::swap(item->waiter, waiter);
|
||||
DCHECK(item->waiter == nullptr);
|
||||
DCHECK(waiter != nullptr);
|
||||
|
||||
// The ref on recv_dev_context transfers below.
|
||||
recv_args.device_context = item->recv_dev_context;
|
||||
recv_args.alloc_attrs = item->recv_alloc_attrs;
|
||||
@ -173,9 +196,10 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecvAsync(const string& key, const Args& recv_args,
|
||||
void RecvAsync(const ParsedKey& key, const Args& recv_args,
|
||||
DoneCallback done) override {
|
||||
VLOG(2) << "Recv " << this << " " << key;
|
||||
uint64 key_hash = KeyHash(key.FullKey());
|
||||
VLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
|
||||
mu_.lock();
|
||||
if (!status_.ok()) {
|
||||
// Rendezvous has been aborted.
|
||||
@ -184,13 +208,13 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
done(s, Args(), recv_args, Tensor(), false);
|
||||
return;
|
||||
}
|
||||
Table::iterator iter = table_.find(key);
|
||||
Table::iterator iter = table_.find(key_hash);
|
||||
if (iter != table_.end()) {
|
||||
Item* item = iter->second;
|
||||
if (item->has_been_recvd && !tolerate_dup_recv_) {
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
|
||||
Tensor(), false);
|
||||
done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
|
||||
recv_args, Tensor(), false);
|
||||
} else if (item->waiter == nullptr || tolerate_dup_recv_) {
|
||||
// A message has already arrived and is stored in the table
|
||||
// under this key. Consumes the message and invokes the done
|
||||
@ -218,8 +242,8 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
// Already have a waiter in the waiters table under this key,
|
||||
// which should not happen.
|
||||
mu_.unlock();
|
||||
done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
|
||||
Tensor(), false);
|
||||
done(errors::Aborted("Duplicated recv: ", key.FullKey()), Args(),
|
||||
recv_args, Tensor(), false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -227,13 +251,13 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
// waiting table. The done closure will be invoked when the
|
||||
// message arrives.
|
||||
Item* item = new Item;
|
||||
item->waiter = done;
|
||||
item->waiter = std::move(done);
|
||||
item->recv_alloc_attrs = recv_args.alloc_attrs;
|
||||
if (recv_args.device_context) {
|
||||
item->recv_dev_context = recv_args.device_context;
|
||||
item->recv_dev_context->Ref();
|
||||
}
|
||||
CHECK(table_.insert({key, item}).second);
|
||||
CHECK(table_.insert({key_hash, item}).second);
|
||||
mu_.unlock();
|
||||
return;
|
||||
}
|
||||
@ -280,7 +304,12 @@ class LocalRendezvousImpl : public Rendezvous {
|
||||
}
|
||||
}
|
||||
};
|
||||
typedef std::unordered_map<string, Item*> Table;
|
||||
// We key the hash table by KeyHash of the Rendezvous::CreateKey string
|
||||
static uint64 KeyHash(const StringPiece& k) {
|
||||
return Hash64(k.data(), k.size());
|
||||
}
|
||||
|
||||
typedef std::unordered_map<uint64, Item*> Table;
|
||||
|
||||
// TODO(zhifengc): shard table_.
|
||||
mutex mu_;
|
||||
|
@ -54,12 +54,22 @@ class Rendezvous : public core::RefCounted {
|
||||
// Parses the key constructed by CreateKey and parse src/dst device
|
||||
// names into structures respectively.
|
||||
struct ParsedKey {
|
||||
string src_device;
|
||||
StringPiece src_device;
|
||||
DeviceNameUtils::ParsedName src;
|
||||
uint64 src_incarnation = 0;
|
||||
string dst_device;
|
||||
StringPiece dst_device;
|
||||
DeviceNameUtils::ParsedName dst;
|
||||
string edge_name;
|
||||
StringPiece edge_name;
|
||||
|
||||
ParsedKey() {}
|
||||
ParsedKey(const ParsedKey& b) { *this = b; }
|
||||
|
||||
ParsedKey& operator=(const ParsedKey& b);
|
||||
StringPiece FullKey() const { return buf_; }
|
||||
|
||||
private:
|
||||
friend class Rendezvous;
|
||||
string buf_;
|
||||
};
|
||||
static Status ParseKey(const string& key, ParsedKey* out);
|
||||
|
||||
@ -74,7 +84,7 @@ class Rendezvous : public core::RefCounted {
|
||||
// Send/Recv on the same worker.
|
||||
//
|
||||
// Send() never blocks.
|
||||
virtual Status Send(const string& key, const Args& args, const Tensor& val,
|
||||
virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
|
||||
const bool is_dead) = 0;
|
||||
|
||||
// Callback provided by a tensor consumer waiting on the rendezvous.
|
||||
@ -84,13 +94,15 @@ class Rendezvous : public core::RefCounted {
|
||||
// receiver, which may be needed when a non-CPU device is in use
|
||||
// by either side.
|
||||
typedef std::function<void(const Status&, const Args&, const Args&,
|
||||
const Tensor&, const bool)> DoneCallback;
|
||||
const Tensor&, const bool)>
|
||||
DoneCallback;
|
||||
|
||||
virtual void RecvAsync(const string& key, const Args& args,
|
||||
virtual void RecvAsync(const ParsedKey& key, const Args& args,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
// Synchronous wrapper for RecvAsync.
|
||||
Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
|
||||
Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
|
||||
bool* is_dead);
|
||||
|
||||
// Aborts all pending and future Send/Recv with the given "status".
|
||||
//
|
||||
|
@ -96,13 +96,29 @@ string V(const Tensor& tensor) {
|
||||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
const char* kFoo = "/cpu:0;1;/cpu:1;foo;1;2";
|
||||
const char* kBar = "/gpu:0;2;/gpu:1;bar;1;2";
|
||||
|
||||
Rendezvous::ParsedKey MakeKey(const string& name) {
|
||||
string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/GPU:0", name,
|
||||
FrameAndIter(0, 0));
|
||||
Rendezvous::ParsedKey k;
|
||||
TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
|
||||
return k;
|
||||
}
|
||||
|
||||
Rendezvous::ParsedKey KeyFoo() { return MakeKey("foo"); }
|
||||
Rendezvous::ParsedKey KeyBar() { return MakeKey("bar"); }
|
||||
|
||||
TEST_F(LocalRendezvousTest, SendRecv) {
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false)));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Send(KeyFoo(), args, V("hello"), false)));
|
||||
Tensor val(DT_STRING);
|
||||
bool is_dead = false;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
|
||||
EXPECT_EQ("hello", V(val));
|
||||
}
|
||||
|
||||
@ -110,12 +126,12 @@ TEST_F(LocalRendezvousTest, RecvSend) {
|
||||
SchedClosure([this]() {
|
||||
Env::Default()->SleepForMicroseconds(10000);
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
|
||||
});
|
||||
Tensor val(DT_STRING);
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
|
||||
EXPECT_EQ("hello", V(val));
|
||||
}
|
||||
|
||||
@ -124,16 +140,17 @@ TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) {
|
||||
Tensor t(DT_STRING);
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
|
||||
});
|
||||
Env::Default()->SleepForMicroseconds(1000000);
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
|
||||
EXPECT_EQ("secret msg", V(val));
|
||||
}
|
||||
|
||||
@ -142,17 +159,18 @@ TEST_F(LocalRendezvousTest, DuplicateSerialRecv) {
|
||||
Tensor t(DT_STRING);
|
||||
bool is_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
|
||||
});
|
||||
Env::Default()->SleepForMicroseconds(1000000);
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
|
||||
TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
|
||||
EXPECT_EQ("secret msg", V(val));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
|
||||
}
|
||||
|
||||
// A simple structure that behaves a bit like a blocking counter. The
|
||||
@ -174,7 +192,7 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
|
||||
random::SimplePhilox rnd(&philox);
|
||||
Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Send(strings::StrCat(i), args,
|
||||
TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
|
||||
V(strings::StrCat(i)), false));
|
||||
});
|
||||
SchedClosure([this, &state, i]() {
|
||||
@ -184,7 +202,8 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
TF_ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead));
|
||||
TF_ASSERT_OK(
|
||||
rendez_->Recv(MakeKey(strings::StrCat(i)), args, &val, &val_dead));
|
||||
EXPECT_EQ(strings::StrCat(i), V(val));
|
||||
bool done = false;
|
||||
{
|
||||
@ -212,7 +231,7 @@ TEST_F(LocalRendezvousTest, RecvAbort) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
Status status = rendez_->Recv("foo", args, &val, &val_dead);
|
||||
Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
|
||||
EXPECT_TRUE(errors::IsAborted(status));
|
||||
}
|
||||
|
||||
@ -228,7 +247,7 @@ TEST_F(LocalRendezvousTest, RecvSleepAbort) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
Status status = rendez_->Recv("foo", args, &val, &val_dead);
|
||||
Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
|
||||
EXPECT_TRUE(errors::IsAborted(status));
|
||||
}
|
||||
|
||||
@ -237,8 +256,9 @@ TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
|
||||
Tensor val(DT_STRING);
|
||||
bool val_dead = false;
|
||||
Rendezvous::Args args;
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead)));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
|
||||
EXPECT_TRUE(errors::IsAborted(rendez_->Send(KeyFoo(), args, val, val_dead)));
|
||||
EXPECT_TRUE(
|
||||
errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
|
||||
}
|
||||
|
||||
class DummyDeviceContext : public DeviceContext {
|
||||
@ -255,15 +275,15 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
|
||||
Rendezvous::Args args;
|
||||
args.device_context = new DummyDeviceContext(123);
|
||||
|
||||
TF_ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
|
||||
TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
|
||||
|
||||
Notification n;
|
||||
Rendezvous::Args args1;
|
||||
args1.device_context = new DummyDeviceContext(1);
|
||||
rendez_->RecvAsync("foo", args1, [&n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, bool is_dead) {
|
||||
rendez_->RecvAsync(KeyFoo(), args1, [&n](const Status& s,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, bool is_dead) {
|
||||
CHECK_EQ(123,
|
||||
dynamic_cast<const DummyDeviceContext*>(send_args.device_context)
|
||||
->stream_id());
|
||||
@ -284,8 +304,8 @@ static void BM_SendRecv(int iters) {
|
||||
Status s;
|
||||
if (iters > 0) {
|
||||
while (iters--) {
|
||||
s = rendez->Send("foo", args, orig, is_dead);
|
||||
s = rendez->Recv("foo", args, &val, &is_dead);
|
||||
s = rendez->Send(KeyFoo(), args, orig, is_dead);
|
||||
s = rendez->Recv(KeyFoo(), args, &val, &is_dead);
|
||||
}
|
||||
CHECK_EQ(V(val), V(orig));
|
||||
}
|
||||
@ -307,8 +327,8 @@ static void BM_RecvSend(int iters) {
|
||||
Rendezvous::Args args;
|
||||
Status s;
|
||||
for (int i = 0; i < iters / 2; ++i) {
|
||||
s = rendez->Recv("foo", args, &foo, &is_dead);
|
||||
s = rendez->Send("bar", args, bar, is_dead);
|
||||
s = rendez->Recv(KeyFoo(), args, &foo, &is_dead);
|
||||
s = rendez->Send(KeyBar(), args, bar, is_dead);
|
||||
}
|
||||
CHECK_EQ("foo", V(foo));
|
||||
});
|
||||
@ -318,8 +338,8 @@ static void BM_RecvSend(int iters) {
|
||||
Rendezvous::Args args;
|
||||
Status s;
|
||||
for (int i = 0; i < iters / 2; ++i) {
|
||||
s = rendez->Send("foo", args, foo, is_dead);
|
||||
s = rendez->Recv("bar", args, &bar, &is_dead);
|
||||
s = rendez->Send(KeyFoo(), args, foo, is_dead);
|
||||
s = rendez->Recv(KeyBar(), args, &bar, &is_dead);
|
||||
}
|
||||
CHECK_EQ("bar", V(bar));
|
||||
delete pool;
|
||||
|
@ -32,10 +32,11 @@ static string GetRendezvousKeyPrefix(const string& send_device,
|
||||
recv_device, ";", tensor_name);
|
||||
}
|
||||
|
||||
static string GetRendezvousKey(const string& key_prefix,
|
||||
const FrameAndIter& frame_iter) {
|
||||
return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":",
|
||||
frame_iter.iter_id);
|
||||
static void GetRendezvousKey(const string& key_prefix,
|
||||
const FrameAndIter& frame_iter, string* key) {
|
||||
key->clear();
|
||||
strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":",
|
||||
frame_iter.iter_id);
|
||||
}
|
||||
|
||||
SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@ -57,9 +58,13 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES(
|
||||
ctx, ctx->rendezvous() != nullptr,
|
||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||
const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
|
||||
string key;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key);
|
||||
VLOG(2) << "Send " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(key, &parsed));
|
||||
|
||||
// The device context may be passed between the Send/Recv
|
||||
// boundary, so that the device context used to produce the Tensor
|
||||
// is used when performing the copy on the recv side (which may be
|
||||
@ -67,9 +72,8 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
args.alloc_attrs = ctx->input_alloc_attr(0);
|
||||
Status s =
|
||||
ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead());
|
||||
ctx->SetStatus(s);
|
||||
OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0),
|
||||
ctx->is_input_dead()));
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
@ -98,16 +102,21 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
OP_REQUIRES(
|
||||
ctx, ctx->rendezvous() != nullptr,
|
||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||
const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
|
||||
string key;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &key);
|
||||
VLOG(2) << "Recv " << key;
|
||||
|
||||
Rendezvous::ParsedKey parsed;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(key, &parsed), done);
|
||||
|
||||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
args.alloc_attrs = ctx->output_alloc_attr(0);
|
||||
ctx->rendezvous()->RecvAsync(
|
||||
key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& val, bool is_dead) {
|
||||
parsed, args,
|
||||
[ctx, done](const Status& s, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& val,
|
||||
bool is_dead) {
|
||||
ctx->SetStatus(s);
|
||||
if (s.ok()) {
|
||||
// 'ctx' allocates the output tensor of the expected type. The
|
||||
|
@ -104,22 +104,29 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
||||
}
|
||||
StringPiece tmp;
|
||||
while (!fullname.empty()) {
|
||||
bool progress = false;
|
||||
if (str_util::ConsumePrefix(&fullname, "/job:")) {
|
||||
p->has_job = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/replica:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/replica:")) {
|
||||
p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/task:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/task:")) {
|
||||
p->has_task = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/device:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/device:")) {
|
||||
p->has_type = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
|
||||
return false;
|
||||
@ -132,24 +139,31 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress = true;
|
||||
}
|
||||
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
|
||||
if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/CPU:")) {
|
||||
p->has_type = true;
|
||||
p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
|
||||
p->has_id = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
|
||||
return false;
|
||||
}
|
||||
} else if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/GPU:")) {
|
||||
progress = true;
|
||||
}
|
||||
if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
|
||||
str_util::ConsumePrefix(&fullname, "/GPU:")) {
|
||||
p->has_type = true;
|
||||
p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
|
||||
p->has_id = !str_util::ConsumePrefix(&fullname, "*");
|
||||
if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
progress = true;
|
||||
}
|
||||
|
||||
if (!progress) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -340,11 +354,22 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
|
||||
string* device) {
|
||||
ParsedName pn;
|
||||
if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
|
||||
*task = strings::StrCat(
|
||||
(pn.has_job ? strings::StrCat("/job:", pn.job) : ""),
|
||||
(pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""),
|
||||
(pn.has_task ? strings::StrCat("/task:", pn.task) : ""));
|
||||
*device = strings::StrCat(pn.type, ":", pn.id);
|
||||
task->clear();
|
||||
task->reserve(
|
||||
(pn.has_job ? (5 + pn.job.size()) : 0) +
|
||||
(pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
|
||||
(pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
|
||||
if (pn.has_job) {
|
||||
strings::StrAppend(task, "/job:", pn.job);
|
||||
}
|
||||
if (pn.has_replica) {
|
||||
strings::StrAppend(task, "/replica:", pn.replica);
|
||||
}
|
||||
if (pn.has_task) {
|
||||
strings::StrAppend(task, "/task:", pn.task);
|
||||
}
|
||||
device->clear();
|
||||
strings::StrAppend(device, pn.type, ":", pn.id);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user