Refresh device in EagerContext and pflr when device is updated. This is required to allow RuntimeFallback and KernelFallback to access TPU device created by tfrt.

PiperOrigin-RevId: 355932739
Change-Id: I043d217e06612734ba0e2b0bbf1cd672b4c9bf44
This commit is contained in:
Xiao Yu 2021-02-05 14:47:25 -08:00 committed by TensorFlower Gardener
parent 6fd7d69823
commit 4413f34d5c
4 changed files with 22 additions and 15 deletions

View File

@ -1277,7 +1277,7 @@ Status EagerContext::UpdateRemoteMaster(
context_view_id_++;
remote_eager_workers_ = std::move(remote_eager_workers);
pflr_->InitializeDeviceSet();
pflr_->InitializeDeviceAndFlr();
InitPrioritizedDeviceTypeList();
default_executor_.ClearError();
@ -1496,7 +1496,7 @@ Status EagerContext::UpdateRemoteWorker(
remote_contexts_ = remote_contexts;
remote_eager_workers_ = std::move(remote_eager_workers);
InitPrioritizedDeviceTypeList();
pflr_->InitializeDeviceSet();
pflr_->InitializeDeviceAndFlr();
}
// No need to update remote_device_manager_ since it's not owned for remote

View File

@ -486,6 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
const SessionOptions& session_options() const { return opts_; }
void InitPrioritizedDeviceTypeList();
private:
Rendezvous* CreateRendezvous(int64 step_id) const {
@ -510,7 +511,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
~EagerContext() override;
void InitPrioritizedDeviceTypeList();
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
Status RegisterExistingFunctionsOnRemoteWorkers(
const std::vector<string>& remote_workers);

View File

@ -100,7 +100,9 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
std::unique_ptr<FunctionLibraryRuntime>>),
next_handle_(0),
session_metadata_(session_metadata),
rendezvous_factory_(std::move(rendezvous_factory)) {
rendezvous_factory_(std::move(rendezvous_factory)),
optimizer_options_(optimizer_options),
graph_def_version_(graph_def_version) {
if (device_mgr == nullptr) {
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
@ -108,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
session_metadata_, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
(*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
lib_def_, default_thread_pool, optimizer_options, session_metadata_,
this);
}
InitializeDeviceSet();
InitializeDeviceAndFlr();
}
/* static */
@ -214,7 +209,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
"function executions");
}
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
DeviceMgr const* all_devices = device_mgr_;
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
all_devices = parent_->remote_device_mgr();
@ -225,6 +220,14 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
for (auto d : all_devices->ListDevices()) {
device_set_->AddDevice(d);
}
for (Device* d : device_mgr_->ListDevices()) {
if ((*flr_map_)[d] == nullptr) {
(*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
graph_def_version_, lib_def_, default_thread_pool_,
optimizer_options_, session_metadata_, this);
}
}
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(

View File

@ -207,8 +207,9 @@ class ProcessFunctionLibraryRuntime {
return device_set_;
}
// Initialize the set of local and remote devices for op device selection.
void InitializeDeviceSet();
// Initialize the set of local and remote devices and corresponding flr for op
// device selection.
void InitializeDeviceAndFlr();
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
@ -478,6 +479,9 @@ class ProcessFunctionLibraryRuntime {
int next_handle_ TF_GUARDED_BY(mu_);
const SessionMetadata* const session_metadata_;
const Rendezvous::Factory rendezvous_factory_;
const OptimizerOptions optimizer_options_;
const int graph_def_version_;
};
} // namespace tensorflow