Refresh device in EagerContext and pflr when device is updated. This is required to allow RuntimeFallback and KernelFallback to access TPU device created by tfrt.
PiperOrigin-RevId: 355932739 Change-Id: I043d217e06612734ba0e2b0bbf1cd672b4c9bf44
This commit is contained in:
parent
6fd7d69823
commit
4413f34d5c
@ -1277,7 +1277,7 @@ Status EagerContext::UpdateRemoteMaster(
|
||||
context_view_id_++;
|
||||
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
pflr_->InitializeDeviceSet();
|
||||
pflr_->InitializeDeviceAndFlr();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
|
||||
default_executor_.ClearError();
|
||||
@ -1496,7 +1496,7 @@ Status EagerContext::UpdateRemoteWorker(
|
||||
remote_contexts_ = remote_contexts;
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
InitPrioritizedDeviceTypeList();
|
||||
pflr_->InitializeDeviceSet();
|
||||
pflr_->InitializeDeviceAndFlr();
|
||||
}
|
||||
|
||||
// No need to update remote_device_manager_ since it's not owned for remote
|
||||
|
@ -486,6 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||
|
||||
const SessionOptions& session_options() const { return opts_; }
|
||||
void InitPrioritizedDeviceTypeList();
|
||||
|
||||
private:
|
||||
Rendezvous* CreateRendezvous(int64 step_id) const {
|
||||
@ -510,7 +511,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
|
||||
~EagerContext() override;
|
||||
|
||||
void InitPrioritizedDeviceTypeList();
|
||||
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
||||
Status RegisterExistingFunctionsOnRemoteWorkers(
|
||||
const std::vector<string>& remote_workers);
|
||||
|
@ -100,7 +100,9 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
std::unique_ptr<FunctionLibraryRuntime>>),
|
||||
next_handle_(0),
|
||||
session_metadata_(session_metadata),
|
||||
rendezvous_factory_(std::move(rendezvous_factory)) {
|
||||
rendezvous_factory_(std::move(rendezvous_factory)),
|
||||
optimizer_options_(optimizer_options),
|
||||
graph_def_version_(graph_def_version) {
|
||||
if (device_mgr == nullptr) {
|
||||
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
|
||||
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
|
||||
@ -108,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
session_metadata_, this);
|
||||
return;
|
||||
}
|
||||
for (Device* d : device_mgr->ListDevices()) {
|
||||
(*flr_map_)[d] = NewFunctionLibraryRuntime(
|
||||
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
|
||||
lib_def_, default_thread_pool, optimizer_options, session_metadata_,
|
||||
this);
|
||||
}
|
||||
|
||||
InitializeDeviceSet();
|
||||
InitializeDeviceAndFlr();
|
||||
}
|
||||
|
||||
/* static */
|
||||
@ -214,7 +209,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
||||
"function executions");
|
||||
}
|
||||
|
||||
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
|
||||
void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
|
||||
DeviceMgr const* all_devices = device_mgr_;
|
||||
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
||||
all_devices = parent_->remote_device_mgr();
|
||||
@ -225,6 +220,14 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
|
||||
for (auto d : all_devices->ListDevices()) {
|
||||
device_set_->AddDevice(d);
|
||||
}
|
||||
for (Device* d : device_mgr_->ListDevices()) {
|
||||
if ((*flr_map_)[d] == nullptr) {
|
||||
(*flr_map_)[d] = NewFunctionLibraryRuntime(
|
||||
device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
|
||||
graph_def_version_, lib_def_, default_thread_pool_,
|
||||
optimizer_options_, session_metadata_, this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||
|
@ -207,8 +207,9 @@ class ProcessFunctionLibraryRuntime {
|
||||
return device_set_;
|
||||
}
|
||||
|
||||
// Initialize the set of local and remote devices for op device selection.
|
||||
void InitializeDeviceSet();
|
||||
// Initialize the set of local and remote devices and corresponding flr for op
|
||||
// device selection.
|
||||
void InitializeDeviceAndFlr();
|
||||
|
||||
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
|
||||
|
||||
@ -478,6 +479,9 @@ class ProcessFunctionLibraryRuntime {
|
||||
int next_handle_ TF_GUARDED_BY(mu_);
|
||||
const SessionMetadata* const session_metadata_;
|
||||
const Rendezvous::Factory rendezvous_factory_;
|
||||
|
||||
const OptimizerOptions optimizer_options_;
|
||||
const int graph_def_version_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user