Share devices between worker session and eager context in remote eager server.
PiperOrigin-RevId: 239661121
This commit is contained in:
parent
ef936e008e
commit
08cbe99299
@ -90,21 +90,6 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
||||
}
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
|
||||
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
|
||||
// TODO(nareshmodi): Correctly set the SessionOptions.
|
||||
SessionOptions(),
|
||||
strings::Printf("/job:%s/replica:0/task:%d",
|
||||
request->server_def().job_name().data(),
|
||||
request->server_def().task_index()),
|
||||
&devices));
|
||||
response->mutable_device_attributes()->Reserve(devices.size());
|
||||
for (const auto& d : devices) {
|
||||
*response->add_device_attributes() = d->attributes();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
|
||||
absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
|
||||
auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
|
||||
auto session_name = strings::StrCat("eager_", request->rendezvous_id());
|
||||
TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
|
||||
@ -114,13 +99,22 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
||||
TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
request->async(), std::move(device_mgr), r));
|
||||
request->async(), device_mgr, false, r));
|
||||
|
||||
std::vector<DeviceAttributes> device_attributes;
|
||||
device_mgr->ListDeviceAttributes(&device_attributes);
|
||||
|
||||
for (const auto& da : device_attributes) {
|
||||
*response->add_device_attributes() = da;
|
||||
}
|
||||
|
||||
uint64 context_id;
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user