Share devices between worker session and eager context in remote eager server.

PiperOrigin-RevId: 239661121
This commit is contained in:
Akshay Modi 2019-03-21 13:32:51 -07:00 committed by TensorFlower Gardener
parent ef936e008e
commit 08cbe99299

View File

@@ -90,21 +90,6 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
// TODO(nareshmodi): Correctly set the SessionOptions.
SessionOptions(),
strings::Printf("/job:%s/replica:0/task:%d",
request->server_def().job_name().data(),
request->server_def().task_index()),
&devices));
response->mutable_device_attributes()->Reserve(devices.size());
for (const auto& d : devices) {
*response->add_device_attributes() = d->attributes();
}
std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
absl::make_unique<DeviceMgr>(std::move(devices));
auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
auto session_name = strings::StrCat("eager_", request->rendezvous_id());
TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
@@ -114,13 +99,22 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
// Initialize remote tensor communication based on worker session.
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
request->async(), std::move(device_mgr), r));
request->async(), device_mgr, false, r));
std::vector<DeviceAttributes> device_attributes;
device_mgr->ListDeviceAttributes(&device_attributes);
for (const auto& da : device_attributes) {
*response->add_device_attributes() = da;
}
uint64 context_id;
{