Rolling forward "Add PrivateIntraProcessRendezvous." with a fix.
For now, when using ExecutorBarrier with multiple devices, we continue to use a RefCountedIntraProcessRendezvous. PiperOrigin-RevId: 282882166 Change-Id: I775818ba60db43f34745fb221e63e4d6ca065121
This commit is contained in:
parent
33877af57f
commit
64ba82b7d3
@ -521,7 +521,6 @@ Status DirectSession::RunInternal(
|
||||
executor_step_count, &debugger_state));
|
||||
}
|
||||
|
||||
run_state.rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
|
||||
#ifndef __ANDROID__
|
||||
// Set up for collectives if ExecutorsAndKeys declares a key.
|
||||
if (executors_and_keys->collective_graph_key !=
|
||||
@ -616,7 +615,6 @@ Status DirectSession::RunInternal(
|
||||
Executor::Args args;
|
||||
args.step_id = step_id;
|
||||
args.call_frame = call_frame;
|
||||
args.rendezvous = run_state.rendez.get();
|
||||
args.collective_executor =
|
||||
(run_state.collective_executor ? run_state.collective_executor->get()
|
||||
: nullptr);
|
||||
@ -688,14 +686,21 @@ Status DirectSession::RunInternal(
|
||||
};
|
||||
|
||||
if (can_execute_synchronously) {
|
||||
PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
|
||||
args.rendezvous = &rendezvous;
|
||||
|
||||
const auto& item = executors_and_keys->items[0];
|
||||
set_threadpool_args_for_item(item, &args);
|
||||
run_status = item.executor->Run(args);
|
||||
} else {
|
||||
core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
|
||||
new RefCountedIntraProcessRendezvous(device_mgr_.get()));
|
||||
args.rendezvous = rendezvous.get();
|
||||
|
||||
// `barrier` will delete itself after the final executor finishes.
|
||||
Notification executors_done;
|
||||
ExecutorBarrier* barrier =
|
||||
new ExecutorBarrier(num_executors, run_state.rendez.get(),
|
||||
new ExecutorBarrier(num_executors, rendezvous.get(),
|
||||
[&run_state, &executors_done](const Status& ret) {
|
||||
{
|
||||
mutex_lock l(run_state.mu);
|
||||
@ -1139,7 +1144,7 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
|
||||
|
||||
Status DirectSession::RecvPRunOutputs(
|
||||
const std::vector<string>& output_names,
|
||||
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
|
||||
const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
|
||||
std::vector<Tensor>* outputs) {
|
||||
Status s;
|
||||
if (!output_names.empty()) {
|
||||
|
@ -191,7 +191,6 @@ class DirectSession : public Session {
|
||||
struct RunState {
|
||||
mutex mu;
|
||||
Status status GUARDED_BY(mu);
|
||||
core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr;
|
||||
std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
|
||||
std::unique_ptr<StepStatsCollector> collector;
|
||||
TensorStore tensor_store;
|
||||
@ -208,6 +207,7 @@ class DirectSession : public Session {
|
||||
Notification executors_done;
|
||||
std::unordered_map<string, bool> pending_inputs; // true if fed
|
||||
std::unordered_map<string, bool> pending_outputs; // true if fetched
|
||||
core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr;
|
||||
|
||||
PartialRunState(const std::vector<string>& pending_input_names,
|
||||
const std::vector<string>& pending_output_names,
|
||||
@ -282,7 +282,7 @@ class DirectSession : public Session {
|
||||
// tensors are computed.
|
||||
::tensorflow::Status RecvPRunOutputs(
|
||||
const std::vector<string>& output_names,
|
||||
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
|
||||
const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
|
||||
std::vector<Tensor>* outputs);
|
||||
|
||||
// Check if the specified fetches can be computed from the feeds
|
||||
|
@ -1116,11 +1116,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
}
|
||||
Options run_opts = opts;
|
||||
if (opts.create_rendezvous) {
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
|
||||
auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
|
||||
run_opts.rendezvous = rendezvous;
|
||||
run_opts.create_rendezvous = false;
|
||||
done = [done = std::move(done), rendezvous](const Status& status) {
|
||||
rendezvous->Unref();
|
||||
done = [done = std::move(done), rendezvous](const Status& status) mutable {
|
||||
delete rendezvous;
|
||||
done(status);
|
||||
};
|
||||
}
|
||||
@ -1187,11 +1187,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
|
||||
Options run_opts = opts;
|
||||
if (opts.create_rendezvous) {
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
|
||||
auto* rendezvous = new PrivateIntraProcessRendezvous(device_mgr_);
|
||||
run_opts.rendezvous = rendezvous;
|
||||
run_opts.create_rendezvous = false;
|
||||
done = [done = std::move(done), rendezvous](const Status& status) {
|
||||
rendezvous->Unref();
|
||||
done = [done = std::move(done), rendezvous](const Status& status) mutable {
|
||||
delete rendezvous;
|
||||
done(status);
|
||||
};
|
||||
}
|
||||
|
@ -1854,8 +1854,8 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
|
||||
Tensor y;
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||
opts.rendezvous = rendezvous;
|
||||
PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
|
||||
opts.rendezvous = &rendezvous;
|
||||
opts.source_device = "/device:CPU:1";
|
||||
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
||||
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
|
||||
@ -1870,7 +1870,6 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
||||
y,
|
||||
test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
|
||||
TensorShape({})));
|
||||
rendezvous->Unref();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -110,12 +110,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
|
||||
~ProcessFunctionLibraryRuntimeTest() override {
|
||||
if (rendezvous_ != nullptr) {
|
||||
rendezvous_->Unref();
|
||||
}
|
||||
}
|
||||
|
||||
void Init(const std::vector<FunctionDef>& flib,
|
||||
const SessionMetadata* session_metadata = nullptr) {
|
||||
FunctionDefLibrary proto;
|
||||
@ -127,7 +121,8 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
|
||||
nullptr, session_metadata));
|
||||
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
|
||||
rendezvous_ =
|
||||
absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_.get());
|
||||
}
|
||||
|
||||
Status Instantiate(
|
||||
@ -263,7 +258,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
test::function::FunctionTestSchedClosure(fn);
|
||||
};
|
||||
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.runner = &runner;
|
||||
Status status;
|
||||
Notification done;
|
||||
@ -292,7 +287,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||
std::unique_ptr<TestClusterFLR> cluster_flr_;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
||||
IntraProcessRendezvous* rendezvous_ = nullptr;
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous_ = nullptr;
|
||||
};
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
|
||||
@ -344,7 +339,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
||||
Init({test::function::XTimesTwo()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -359,7 +354,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -375,7 +370,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -392,7 +387,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
|
||||
@ -411,7 +406,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
Tensor y;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0;
|
||||
@ -432,7 +427,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
|
||||
@ -462,7 +457,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
|
||||
@ -509,7 +504,7 @@ void TestTwoDeviceMult(
|
||||
const string& error = "") {
|
||||
fixture->Init({test::function::TwoDeviceMult()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = fixture->rendezvous_;
|
||||
opts.rendezvous = fixture->rendezvous_.get();
|
||||
auto x = test::AsTensor<float>({1, 2, 3});
|
||||
Tensor y_cpu;
|
||||
Tensor y_gpu;
|
||||
@ -542,7 +537,7 @@ void TestTwoDeviceInputOutput(
|
||||
fixture->Init({test::function::TwoDeviceInputOutput()});
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = fixture->rendezvous_;
|
||||
opts.rendezvous = fixture->rendezvous_.get();
|
||||
Tensor x1 = test::AsTensor<float>({1, 2});
|
||||
if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
|
||||
x1 = fixture->CPUToGPU(x1);
|
||||
@ -743,7 +738,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
|
||||
|
||||
// Run the function taking a resource and outputting it
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
Tensor x1 = CPUToGPU(test::AsTensor<float>({1, 2}));
|
||||
Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
|
||||
"/job:a/replica:0/task:0/device:GPU:0");
|
||||
@ -985,7 +980,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
|
||||
Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -1001,7 +996,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
|
||||
Init({SessionMetadataReaderOpFn()}, &session_metadata);
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
@ -1027,7 +1022,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
|
||||
TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||
opts.rendezvous = rendezvous_;
|
||||
opts.rendezvous = rendezvous_.get();
|
||||
opts.remote_execution = true;
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
|
||||
|
@ -32,23 +32,12 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
|
||||
IntraProcessRendezvous::~IntraProcessRendezvous() {}
|
||||
|
||||
Status IntraProcessRendezvous::Send(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) {
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
|
||||
// Buffers "val" and "device_context" in local_.
|
||||
return local_.Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
|
||||
StatusCallback done) {
|
||||
namespace {
|
||||
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
|
||||
const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
Tensor* out, StatusCallback done) {
|
||||
// Do a quick copy (sharing the underlying buffer) if both tensors
|
||||
// are on host memory.
|
||||
const bool src_host =
|
||||
@ -73,13 +62,13 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
}
|
||||
|
||||
Device* src_device;
|
||||
Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
|
||||
Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
Device* dst_device;
|
||||
s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
|
||||
s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
@ -116,16 +105,18 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
|
||||
out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
|
||||
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
|
||||
LocalRendezvous* local,
|
||||
const RendezvousInterface::ParsedKey& parsed,
|
||||
const Rendezvous::Args& recv_args,
|
||||
RendezvousInterface::DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();
|
||||
|
||||
MEMDEBUG_CACHE_OP("RecvAsync");
|
||||
// Recv the tensor from local_.
|
||||
local_.RecvAsync(
|
||||
key, args,
|
||||
[this, key, done = std::move(done)](
|
||||
local->RecvAsync(
|
||||
parsed, recv_args,
|
||||
[device_mgr, parsed, done = std::move(done)](
|
||||
const Status& status, const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
bool is_dead) mutable {
|
||||
@ -141,7 +132,7 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
};
|
||||
|
||||
if (status.ok() && in.IsInitialized()) {
|
||||
SameWorkerRecvDone(key, send_args, recv_args, in, out,
|
||||
SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
|
||||
std::move(final_callback));
|
||||
} else {
|
||||
final_callback(status);
|
||||
@ -149,8 +140,56 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
});
|
||||
}
|
||||
|
||||
void IntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
CHECK(!s.ok());
|
||||
} // namespace
|
||||
|
||||
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
|
||||
RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
|
||||
|
||||
Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val,
|
||||
const bool is_dead) {
|
||||
VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
|
||||
return local_.Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
|
||||
IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
|
||||
}
|
||||
|
||||
void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
local_.StartAbort(s);
|
||||
}
|
||||
|
||||
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
|
||||
const DeviceMgr* device_mgr)
|
||||
: device_mgr_(device_mgr) {}
|
||||
|
||||
PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
|
||||
|
||||
Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
const Tensor& val,
|
||||
const bool is_dead) {
|
||||
DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
|
||||
return local_.Send(key, args, val, is_dead);
|
||||
}
|
||||
|
||||
void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) {
|
||||
DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
|
||||
<< key.FullKey();
|
||||
IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
|
||||
}
|
||||
|
||||
void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
|
||||
local_.StartAbort(s);
|
||||
}
|
||||
|
||||
|
@ -30,48 +30,61 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// IntraProcessRendezvous is a Rendezvous which expects all producers
|
||||
// and consumers to be devices immediately accessible within the
|
||||
// process. That is, it will never be necessary to perform an RPC to
|
||||
// The IntraProcessRendezvous classes are implementations of a Rendezvous that
|
||||
// expects all producers and consumers to be devices immediately accessible
|
||||
// within the process. That is, it will never be necessary to perform an RPC to
|
||||
// communicate with either.
|
||||
//
|
||||
// Buffering of Tensor values is delegated to a `LocalRendezvous`. This class
|
||||
// just adds functionality to coordinate multiple process-local devices.
|
||||
class IntraProcessRendezvous : public Rendezvous {
|
||||
public:
|
||||
explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
|
||||
// Buffering of Tensor values is delegated to a `LocalRendezvous`. An
|
||||
// IntraProcessRendezvous just adds functionality to coordinate multiple
|
||||
// process-local devices.
|
||||
|
||||
// Forwards to local_, where the Tensor "val" will be buffered and
|
||||
// any waiting callback stored.
|
||||
// Reference-counted implementation that may be shared between multiple threads.
|
||||
class RefCountedIntraProcessRendezvous : public Rendezvous {
|
||||
public:
|
||||
explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr);
|
||||
|
||||
// Implementation of RendezvousInterface methods.
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
|
||||
// This method is called only by the RecvOp. It tests to see
|
||||
// whether the value will be produced by a local or remote device
|
||||
// and handles accordingly. In the local case it forwards to
|
||||
// local_, in the remote case it initiates an RPC request.
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
void StartAbort(const Status& status) override;
|
||||
|
||||
private:
|
||||
const DeviceMgr* device_mgr_;
|
||||
LocalRendezvous local_;
|
||||
|
||||
~IntraProcessRendezvous() override;
|
||||
~RefCountedIntraProcessRendezvous() override;
|
||||
|
||||
// Callback handling the case when a rendezvous has been
|
||||
// accomplished in local_ and the consumer is local to this process.
|
||||
// Tensor "in" will be copied into "out". The key "parsed" encodes
|
||||
// the src and dst devices.
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args, const Tensor& in,
|
||||
Tensor* out, StatusCallback done);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RefCountedIntraProcessRendezvous);
|
||||
};
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
|
||||
// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for
|
||||
// backwards compatibility with existing users.
|
||||
using IntraProcessRendezvous = RefCountedIntraProcessRendezvous;
|
||||
|
||||
// Non-reference-counted implementation that may be stack-allocated for
|
||||
// performance.
|
||||
//
|
||||
// Prefer to use PrivateIntraProcessRendezvous in new code.
|
||||
class PrivateIntraProcessRendezvous : public RendezvousInterface {
|
||||
public:
|
||||
explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr);
|
||||
~PrivateIntraProcessRendezvous() override;
|
||||
|
||||
// Implementation of RendezvousInterface methods.
|
||||
Status Send(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
const Tensor& val, const bool is_dead) override;
|
||||
void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
void StartAbort(const Status& status) override;
|
||||
|
||||
private:
|
||||
const DeviceMgr* device_mgr_;
|
||||
LocalRendezvous local_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PrivateIntraProcessRendezvous);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user