Refresh device in EagerContext and pflr when device is updated. This is required to allow RuntimeFallback and KernelFallback to access TPU device created by tfrt.
PiperOrigin-RevId: 355932739 Change-Id: I043d217e06612734ba0e2b0bbf1cd672b4c9bf44
This commit is contained in:
parent
6fd7d69823
commit
4413f34d5c
@ -1277,7 +1277,7 @@ Status EagerContext::UpdateRemoteMaster(
|
|||||||
context_view_id_++;
|
context_view_id_++;
|
||||||
|
|
||||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||||
pflr_->InitializeDeviceSet();
|
pflr_->InitializeDeviceAndFlr();
|
||||||
InitPrioritizedDeviceTypeList();
|
InitPrioritizedDeviceTypeList();
|
||||||
|
|
||||||
default_executor_.ClearError();
|
default_executor_.ClearError();
|
||||||
@ -1496,7 +1496,7 @@ Status EagerContext::UpdateRemoteWorker(
|
|||||||
remote_contexts_ = remote_contexts;
|
remote_contexts_ = remote_contexts;
|
||||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||||
InitPrioritizedDeviceTypeList();
|
InitPrioritizedDeviceTypeList();
|
||||||
pflr_->InitializeDeviceSet();
|
pflr_->InitializeDeviceAndFlr();
|
||||||
}
|
}
|
||||||
|
|
||||||
// No need to update remote_device_manager_ since it's not owned for remote
|
// No need to update remote_device_manager_ since it's not owned for remote
|
||||||
|
@ -486,6 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||||
|
|
||||||
const SessionOptions& session_options() const { return opts_; }
|
const SessionOptions& session_options() const { return opts_; }
|
||||||
|
void InitPrioritizedDeviceTypeList();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Rendezvous* CreateRendezvous(int64 step_id) const {
|
Rendezvous* CreateRendezvous(int64 step_id) const {
|
||||||
@ -510,7 +511,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
|
|
||||||
~EagerContext() override;
|
~EagerContext() override;
|
||||||
|
|
||||||
void InitPrioritizedDeviceTypeList();
|
|
||||||
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
||||||
Status RegisterExistingFunctionsOnRemoteWorkers(
|
Status RegisterExistingFunctionsOnRemoteWorkers(
|
||||||
const std::vector<string>& remote_workers);
|
const std::vector<string>& remote_workers);
|
||||||
|
@ -100,7 +100,9 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
|||||||
std::unique_ptr<FunctionLibraryRuntime>>),
|
std::unique_ptr<FunctionLibraryRuntime>>),
|
||||||
next_handle_(0),
|
next_handle_(0),
|
||||||
session_metadata_(session_metadata),
|
session_metadata_(session_metadata),
|
||||||
rendezvous_factory_(std::move(rendezvous_factory)) {
|
rendezvous_factory_(std::move(rendezvous_factory)),
|
||||||
|
optimizer_options_(optimizer_options),
|
||||||
|
graph_def_version_(graph_def_version) {
|
||||||
if (device_mgr == nullptr) {
|
if (device_mgr == nullptr) {
|
||||||
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
|
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
|
||||||
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
|
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
|
||||||
@ -108,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
|||||||
session_metadata_, this);
|
session_metadata_, this);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (Device* d : device_mgr->ListDevices()) {
|
InitializeDeviceAndFlr();
|
||||||
(*flr_map_)[d] = NewFunctionLibraryRuntime(
|
|
||||||
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
|
|
||||||
lib_def_, default_thread_pool, optimizer_options, session_metadata_,
|
|
||||||
this);
|
|
||||||
}
|
|
||||||
|
|
||||||
InitializeDeviceSet();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
@ -214,7 +209,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
|||||||
"function executions");
|
"function executions");
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
|
void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
|
||||||
DeviceMgr const* all_devices = device_mgr_;
|
DeviceMgr const* all_devices = device_mgr_;
|
||||||
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
||||||
all_devices = parent_->remote_device_mgr();
|
all_devices = parent_->remote_device_mgr();
|
||||||
@ -225,6 +220,14 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
|
|||||||
for (auto d : all_devices->ListDevices()) {
|
for (auto d : all_devices->ListDevices()) {
|
||||||
device_set_->AddDevice(d);
|
device_set_->AddDevice(d);
|
||||||
}
|
}
|
||||||
|
for (Device* d : device_mgr_->ListDevices()) {
|
||||||
|
if ((*flr_map_)[d] == nullptr) {
|
||||||
|
(*flr_map_)[d] = NewFunctionLibraryRuntime(
|
||||||
|
device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
|
||||||
|
graph_def_version_, lib_def_, default_thread_pool_,
|
||||||
|
optimizer_options_, session_metadata_, this);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||||
|
@ -207,8 +207,9 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
return device_set_;
|
return device_set_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the set of local and remote devices for op device selection.
|
// Initialize the set of local and remote devices and corresponding flr for op
|
||||||
void InitializeDeviceSet();
|
// device selection.
|
||||||
|
void InitializeDeviceAndFlr();
|
||||||
|
|
||||||
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
|
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
|
||||||
|
|
||||||
@ -478,6 +479,9 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
int next_handle_ TF_GUARDED_BY(mu_);
|
int next_handle_ TF_GUARDED_BY(mu_);
|
||||||
const SessionMetadata* const session_metadata_;
|
const SessionMetadata* const session_metadata_;
|
||||||
const Rendezvous::Factory rendezvous_factory_;
|
const Rendezvous::Factory rendezvous_factory_;
|
||||||
|
|
||||||
|
const OptimizerOptions optimizer_options_;
|
||||||
|
const int graph_def_version_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
x
Reference in New Issue
Block a user