Add PrivateIntraProcessRendezvous.
PrivateIntraProcessRendezvous is a version of the existing IntraProcessRendezvous (now renamed to RefcountedIntraProcessRendezvous with a forwarding alias) that is compatible with stack allocation. It allows users to avoid the overhead of dynamically allocating/destroying an IntraProcessRendezvous and the atomic operations involved in manipulating its reference count. This change modifies some users of IntraProcessRendezvous to use PrivateIntraProcessRendezvous, where appropriate. In particular, it uses a stack-allocated PrivateIntraProcessRendezvous on the DirectSession::RunInternal() path. PiperOrigin-RevId: 282847328 Change-Id: I3c54024ea658afb2e2bd27ef35dc421653abc1a8
This commit is contained in:
parent
03e56f176c
commit
a10dc73356
@ -521,7 +521,8 @@ Status DirectSession::RunInternal(
|
||||
executor_step_count, &debugger_state));
|
||||
}
|
||||
|
||||
run_state.rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
|
||||
PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
|
||||
|
||||
#ifndef __ANDROID__
|
||||
// Set up for collectives if ExecutorsAndKeys declares a key.
|
||||
if (executors_and_keys->collective_graph_key !=
|
||||
@ -616,7 +617,7 @@ Status DirectSession::RunInternal(
|
||||
Executor::Args args;
|
||||
args.step_id = step_id;
|
||||
args.call_frame = call_frame;
|
||||
args.rendezvous = run_state.rendez.get();
|
||||
args.rendezvous = &rendezvous;
|
||||
args.collective_executor =
|
||||
(run_state.collective_executor ? run_state.collective_executor->get()
|
||||
: nullptr);
|
||||
@ -695,7 +696,7 @@ Status DirectSession::RunInternal(
|
||||
// `barrier` will delete itself after the final executor finishes.
|
||||
Notification executors_done;
|
||||
ExecutorBarrier* barrier =
|
||||
new ExecutorBarrier(num_executors, run_state.rendez.get(),
|
||||
new ExecutorBarrier(num_executors, &rendezvous,
|
||||
[&run_state, &executors_done](const Status& ret) {
|
||||
{
|
||||
mutex_lock l(run_state.mu);
|
||||
@ -1139,7 +1140,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
|
||||
|
||||
Status DirectSession::RecvPRunOutputs(
|
||||
const std::vector<string>& output_names,
|
||||
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
|
||||
const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
|
||||
std::vector<Tensor>* outputs) {
|
||||
Status s;
|
||||
if (!output_names.empty()) {
|
||||
|
@ -191,7 +191,6 @@ class DirectSession : public Session {
|
||||
struct RunState {
|
||||
mutex mu;
|
||||
Status status GUARDED_BY(mu);
|
||||
core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr;
|
||||
std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
|
||||
std::unique_ptr<StepStatsCollector> collector;
|
||||
TensorStore tensor_store;
|
||||
@ -208,6 +207,7 @@ class DirectSession : public Session {
|
||||
Notification executors_done;
|
||||
std::unordered_map<string, bool> pending_inputs; // true if fed
|
||||
std::unordered_map<string, bool> pending_outputs; // true if fetched
|
||||
core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr;
|
||||
|
||||
PartialRunState(const std::vector<string>& pending_input_names,
|
||||
const std::vector<string>& pending_output_names,
|
||||
@ -282,7 +282,7 @@ class DirectSession : public Session {
|
||||
// tensors are computed.
|
||||
::tensorflow::Status RecvPRunOutputs(
|
||||
const std::vector<string>& output_names,
|
||||
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
|
||||
const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
|
||||
std::vector<Tensor>* outputs);
|
||||
|
||||
// Check if the specified fetches can be computed from the feeds
|
||||
|
@ -166,8 +166,8 @@ class ExecutorBarrier {
|
||||
//
|
||||
// 'done' is called after the last executor completes, and
|
||||
// ExecutorBarrier is deleted.
|
||||
ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
|
||||
: rendez_(r), done_cb_(done), pending_(num) {}
|
||||
ExecutorBarrier(size_t num, RendezvousInterface* r, StatusCallback done)
|
||||
: rendez_(r), done_cb_(std::move(done)), pending_(num) {}
|
||||
|
||||
~ExecutorBarrier() {}
|
||||
|
||||
@ -178,7 +178,7 @@ class ExecutorBarrier {
|
||||
}
|
||||
|
||||
private:
|
||||
Rendezvous* rendez_ = nullptr;
|
||||
RendezvousInterface* rendez_ = nullptr; // Not owned.
|
||||
StatusCallback done_cb_ = nullptr;
|
||||
|
||||
mutable mutex mu_;
|
||||
@ -186,7 +186,7 @@ class ExecutorBarrier {
|
||||
StatusGroup status_group_ GUARDED_BY(mu_);
|
||||
|
||||
void WhenDone(const Status& s) {
|
||||
Rendezvous* error_rendez = nullptr;
|
||||
RendezvousInterface* error_rendez = nullptr;
|
||||
StatusCallback done = nullptr;
|
||||
Status status;
|
||||
|
||||
@ -197,7 +197,6 @@ class ExecutorBarrier {
|
||||
// Rendezvous object by this thread only.
|
||||
if (status_group_.ok() && !s.ok()) {
|
||||
error_rendez = rendez_;
|
||||
error_rendez->Ref();
|
||||
}
|
||||
|
||||
if (!s.ok() && !StatusGroup::IsDerived(s) &&
|
||||
@ -219,7 +218,6 @@ class ExecutorBarrier {
|
||||
if (error_rendez != nullptr) {
|
||||
error_rendez->StartAbort(
|
||||
errors::Aborted("Stopping remaining executors."));
|
||||
error_rendez->Unref();
|
||||
}
|
||||
|
||||
if (done != nullptr) {
|
||||
|
@ -1116,11 +1116,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
}
|
||||
Options run_opts = opts;
|
||||
if (opts.create_rendezvous) {
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
|
||||
auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
|
||||
run_opts.rendezvous = rendezvous;
|
||||
run_opts.create_rendezvous = false;
|
||||
done = [done = std::move(done), rendezvous](const Status& status) {
|
||||
rendezvous->Unref();
|
||||
done = [done = std::move(done), rendezvous](const Status& status) mutable {
|
||||
delete rendezvous;
|
||||
done(status);
|
||||
};
|
||||
}
|
||||
@ -1187,11 +1187,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
|
||||
Options run_opts = opts;
|
||||
if (opts.create_rendezvous) {
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
|
||||
auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
|
||||
run_opts.rendezvous = rendezvous;
|
||||
run_opts.create_rendezvous = false;
|
||||
done = [done = std::move(done), rendezvous](const Status& status) {
|
||||
rendezvous->Unref();
|
||||
done = [done = std::move(done), rendezvous](const Status& status) mutable {
|
||||
delete rendezvous;
|
||||
done(status);
|
||||
};
|
||||
}
|
||||
|
@ -1854,8 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
|
||||
Tensor y;
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||
opts.rendezvous = rendezvous;
|
||||
PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
|
||||
opts.rendezvous = &rendezvous;
|
||||
opts.source_device = "/device:CPU:1";
|
||||
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
||||
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
|
||||
@ -1870,7 +1870,6 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
y,
|
||||
test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
|
||||
TensorShape({})));
|
||||
rendezvous->Unref();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -110,12 +110,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
|
||||
~ProcessFunctionLibraryRuntimeTest() override {
|
||||
if (rendezvous_ != nullptr) {
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
void Init(const std::vector<FunctionDef>& flib,
|
||||
const SessionMetadata* session_metadata = nullptr) {
|
||||
FunctionDefLibrary proto;
|
||||
@ -127,7 +121,8 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
|
||||
nullptr, session_metadata));
|
||||
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
|
||||
rendezvous_ =
|
||||
absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_.get());
|
||||
}
|
||||
|
||||
Status Instantiate(
|
||||
@ -263,7 +258,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
test::function::FunctionTestSchedClosure(fn);
|
||||
};
|
||||
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.runner = &runner;
|
||||
Status status;
|
||||
Notification done;
|
||||
@ -292,7 +287,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||
std::unique_ptr<TestClusterFLR> cluster_flr_;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
||||
IntraProcessRendezvous* rendezvous_ = nullptr;
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous_ = nullptr;
|
||||
};
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
|
||||
@ -344,7 +339,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
||||
Init({test::function::XTimesTwo()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -359,7 +354,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -375,7 +370,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -392,7 +387,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
|
||||
@ -411,7 +406,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
Tensor y;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0;
|
||||
@ -432,7 +427,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
|
||||
@ -462,7 +457,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
|
||||
@ -509,7 +504,7 @@ void TestTwoDeviceMult(
|
||||
const string& error = "") {
|
||||
fixture->Init({test::function::TwoDeviceMult()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = fixture->rendezvous_;
|
||||
opts.rendezvous = fixture->rendezvous_.get();
|
||||
auto x = test::AsTensor<float>({1, 2, 3});
|
||||
Tensor y_cpu;
|
||||
Tensor y_gpu;
|
||||
@ -542,7 +537,7 @@ void TestTwoDeviceInputOutput(
|
||||
fixture->Init({test::function::TwoDeviceInputOutput()});
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = fixture->rendezvous_;
|
||||
opts.rendezvous = fixture->rendezvous_.get();
|
||||
Tensor x1 = test::AsTensor<float>({1, 2});
|
||||
if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
|
||||
x1 = fixture->CPUToGPU(x1);
|
||||
@ -743,7 +738,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
|
||||
|
||||
// Run the function taking a resource and outputing it
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
Tensor x1 = CPUToGPU(test::AsTensor<float>({1, 2}));
|
||||
Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
|
||||
"/job:a/replica:0/task:0/device:GPU:0");
|
||||
@ -985,7 +980,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
|
||||
Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -1001,7 +996,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
|
||||
Init({SessionMetadataReaderOpFn()}, &session_metadata);
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -1027,7 +1022,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
|
||||
TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
|
@ -32,23 +32,12 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
|
||||
IntraProcessRendezvous::~IntraProcessRendezvous() {}
|
||||
|
||||
Status IntraProcessRendezvous::Send(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_.Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
|
||||
StatusCallback done) {
|
||||
namespace {
|
||||
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
Tensor* out, StatusCallback done) {
|
||||
// Do a quick copy (sharing the underlying buffer) if both tensors
|
||||
// are on host memory.
|
||||
const bool src_host =
|
||||
@ -73,13 +62,13 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
}
|
||||
|
||||
Device* src_device;
|
||||
Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
|
||||
Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
Device* dst_device;
|
||||
s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
|
||||
s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
@ -116,16 +105,18 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
|
||||
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
|
||||
LocalRendezvous* local,
|
||||
const RendezvousInterface::ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args,
|
||||
RendezvousInterface::DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();
|
||||
|
||||
MEMDEBUG_CACHE_OP("RecvAsync");
|
||||
// Recv the tensor from local_.
|
||||
local_.RecvAsync(
|
||||
key, args,
|
||||
[this, key, done = std::move(done)](
|
||||
local->RecvAsync(
|
||||
parsed, recv_args,
|
||||
[device_mgr, parsed, done = std::move(done)](
|
||||
const Status& status, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
bool is_dead) mutable {
|
||||
@ -141,7 +132,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
};
|
||||
|
||||
if (status.ok() && in.IsInitialized()) {
|
||||
SameWorkerRecvDone(key, send_args, recv_args, in, out,
|
||||
SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
|
||||
std::move(final_callback));
|
||||
} else {
|
||||
final_callback(status);
|
||||
@ -149,8 +140,56 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
});
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
CHECK(!s.ok());
|
||||
} // namespace
|
||||
|
||||
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
|
||||
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
|
||||
|
||||
Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val,
|
||||
const bool is_dead) {
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
|
||||
return local_.Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
|
||||
IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
|
||||
}
|
||||
|
||||
void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
local_.StartAbort(s);
|
||||
}
|
||||
|
||||
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
|
||||
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
|
||||
|
||||
Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val,
|
||||
const bool is_dead) {
|
||||
DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
|
||||
return local_.Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
|
||||
<< key.FullKey();
|
||||
IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
|
||||
}
|
||||
|
||||
void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
local_.StartAbort(s);
|
||||
}
|
||||
|
||||
|
@ -30,48 +30,61 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// IntraProcessRendezvous is a Rendezvous which expects all producers
|
||||
// and consumers to be devices immediately accessible within the
|
||||
// process. That is, it will never be necessary to perform an RPC to
|
||||
// The IntraProcessRendezvous classes are implementations of a Rendezvous that
|
||||
// expects all producers and consumers to be devices immediately accessible
|
||||
// within the process. That is, it will never be necessary to perform an RPC to
|
||||
// communicate with either.
|
||||
//
|
||||
// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class
|
||||
// just adds functionality to coordinate multiple process-local devices.
|
||||
class IntraProcessRendezvous : public Rendezvous {
|
||||
public:
|
||||
explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
|
||||
// Buffering of Tensor values is delegated to a `LocalRendezvous`. An
|
||||
// IntraProcessRendezvous. just adds functionality to coordinate multiple
|
||||
// process-local devices.
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
// Reference-counted implementation that may be shared between multiple threads.
|
||||
class RefCountedIntraProcessRendezvous : public Rendezvous {
|
||||
public:
|
||||
explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr);
|
||||
|
||||
// Implementation of RendezvousInterface methods.
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
|
||||
private:
|
||||
const DeviceMgr* device_mgr_;
|
||||
LocalRendezvous local_;
|
||||
|
||||
~IntraProcessRendezvous() override;
|
||||
~RefCountedIntraProcessRendezvous() override;
|
||||
|
||||
// Callback handling the case when a rendezvous has been
|
||||
// accomplished in local_ and the consumer is local to this process.
|
||||
// Tensor "in" will be copied into "out". The key "parsed" encodes
|
||||
// the src and dst devices.
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
Tensor* out, StatusCallback done);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RefCountedIntraProcessRendezvous);
|
||||
};
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
|
||||
// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for
|
||||
// backwards compatibility with existing users.
|
||||
using IntraProcessRendezvous = RefCountedIntraProcessRendezvous;
|
||||
|
||||
// Non-reference-counted implementation that may be stack-allocated for
|
||||
// performance.
|
||||
//
|
||||
// Prefer to use PrivateIntraProcessRendezvous in new code.
|
||||
class PrivateIntraProcessRendezvous : public RendezvousInterface {
|
||||
public:
|
||||
explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr);
|
||||
~PrivateIntraProcessRendezvous() override;
|
||||
|
||||
// Implementation of RendezvousInterface methods.
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
void StartAbort(const Status& status) override;
|
||||
|
||||
private:
|
||||
const DeviceMgr* device_mgr_;
|
||||
LocalRendezvous local_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PrivateIntraProcessRendezvous);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user