Remove devices_ and devices_map_ from EagerContext.

We currently look up devices either through device managers or the devices_map_ field. To make the runtime fault tolerant, the list of devices (especially remote ones) must change dynamically, and accessing the inconsistent devices_map_ often leads to errors.

PiperOrigin-RevId: 285475152
Change-Id: I37016875d94e481c4cb5143f19092fc46a481ff2
This commit is contained in:
Haoyu Zhang 2019-12-13 14:43:38 -08:00 committed by TensorFlower Gardener
parent 3e72418d7c
commit 84b1a3498c
5 changed files with 25 additions and 86 deletions

View File

@ -1045,7 +1045,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg), void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) { void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device; tensorflow::Device* device;
status->status = ctx->context->FindDeviceByName(device_name, &device); status->status = ctx->context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) { if (!status->status.ok()) {
deallocator(data, len, deallocator_arg); deallocator(data, len, deallocator_arg);
return nullptr; return nullptr;

View File

@ -79,7 +79,7 @@ EagerContext::EagerContext(
: default_device_placement_policy_(default_device_placement_policy), : default_device_placement_policy_(default_device_placement_policy),
default_mirroring_policy_(default_mirroring_policy), default_mirroring_policy_(default_mirroring_policy),
local_device_manager_(device_mgr, device_mgr_owned), local_device_manager_(device_mgr, device_mgr_owned),
devices_(device_mgr->ListDevices()), host_cpu_device_(device_mgr->ListDevices()[0]),
rendezvous_(rendezvous), rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)), thread_pool_(NewThreadPoolFromSessionOptions(opts)),
custom_kernel_creator_(custom_kernel_creator), custom_kernel_creator_(custom_kernel_creator),
@ -102,7 +102,7 @@ EagerContext::EagerContext(
// currently a no-op. // currently a no-op.
eager_context_created->GetCell()->Set(true); eager_context_created->GetCell()->Set(true);
monitoring::StartExporter(); monitoring::StartExporter();
InitDeviceMapAndAsync(); InitPrioritizedDeviceTypeList();
runner_ = [this](std::function<void()> closure) { runner_ = [this](std::function<void()> closure) {
this->thread_pool_->Schedule(std::move(closure)); this->thread_pool_->Schedule(std::move(closure));
}; };
@ -140,24 +140,17 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
} }
} }
void EagerContext::InitDeviceMapAndAsync() { void EagerContext::InitPrioritizedDeviceTypeList() {
for (auto* device : devices_) {
devices_map_[device->name()] = device;
}
if (remote_device_mgr() != nullptr) {
for (auto* device : remote_device_mgr()->ListDevices()) {
if (devices_map_.find(device->name()) == devices_map_.end()) {
devices_map_[device->name()] = device;
devices_.push_back(device);
}
}
}
DeviceSet ds; DeviceSet ds;
for (Device* d : devices_) { for (Device* d : local_device_mgr()->ListDevices()) {
ds.AddDevice(d); ds.AddDevice(d);
} }
auto remote_device_manager = remote_device_mgr();
if (remote_device_manager != nullptr) {
for (Device* d : remote_device_manager->ListDevices()) {
ds.AddDevice(d);
}
}
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList(); prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
} }
@ -391,17 +384,6 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
return result; return result;
} }
// TODO(gjn): Delete in favour of FindDeviceFromName
Status EagerContext::FindDeviceByName(const string& name,
Device** result) const {
auto it = devices_map_.find(name);
if (it == devices_map_.end()) {
return errors::InvalidArgument(name, " unknown device.");
}
*result = it->second;
return Status::OK();
}
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); } void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
void EagerContext::StartStep() { void EagerContext::StartStep() {
@ -634,7 +616,7 @@ Status EagerContext::CPUDeviceOnTask(const Device* device,
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
device->name(), &cpu_device_name)); device->name(), &cpu_device_name));
return FindDeviceByName(cpu_device_name, cpu_device); return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
} }
namespace { namespace {
@ -718,11 +700,9 @@ Status EagerContext::StoreCollectiveOpsServer(
collective_executor_mgr_.Reset(rpc_collective_executor_mgr); collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
local_device_manager_.Reset(device_mgr); local_device_manager_.Reset(device_mgr);
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
devices_ = local_device_manager_.Get()->ListDevices(); InitPrioritizedDeviceTypeList();
devices_map_.clear();
InitDeviceMapAndAsync();
ClearCaches(); ClearCaches();
default_executor_.ClearError(); default_executor_.ClearError();
{ {
@ -860,9 +840,7 @@ Status EagerContext::SetMasterContextState(
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true); ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
local_device_manager_.Reset(local_device_mgr); local_device_manager_.Reset(local_device_mgr);
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
devices_ = local_device_manager_.Get()->ListDevices();
devices_map_.clear();
if (rendezvous_ != nullptr) rendezvous_->Unref(); if (rendezvous_ != nullptr) rendezvous_->Unref();
rendezvous_ = r; rendezvous_ = r;
@ -893,7 +871,7 @@ Status EagerContext::SetMasterContextState(
DCHECK(remote_device_manager_.Owned()); DCHECK(remote_device_manager_.Owned());
ResetClusterFLR(cluster_flr); ResetClusterFLR(cluster_flr);
InitDeviceMapAndAsync(); InitPrioritizedDeviceTypeList();
ClearCaches(); ClearCaches();
default_executor_.ClearError(); default_executor_.ClearError();
@ -1009,7 +987,7 @@ Status EagerContext::InitializeRemoteWorker(
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION, ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, config->graph_options().optimizer_options(), &func_lib_def_, config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_); thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
InitDeviceMapAndAsync(); InitPrioritizedDeviceTypeList();
ClearCaches(); ClearCaches();
default_executor_.ClearError(); default_executor_.ClearError();
@ -1048,9 +1026,7 @@ Status EagerContext::UpdateRemoteWorker(
ResetClusterFLR(cluster_flr); ResetClusterFLR(cluster_flr);
remote_device_manager_.Reset(remote_device_mgr); remote_device_manager_.Reset(remote_device_mgr);
devices_ = worker_session_device_mgr->ListDevices(); InitPrioritizedDeviceTypeList();
devices_map_.clear();
InitDeviceMapAndAsync();
ClearCaches(); ClearCaches();
default_executor_.ClearError(); default_executor_.ClearError();

View File

@ -141,13 +141,6 @@ class EagerContext : public core::RefCounted {
// Specify a executor for this thread. // Specify a executor for this thread.
void SetExecutorForThread(EagerExecutor* executor); void SetExecutorForThread(EagerExecutor* executor);
// TODO(apassos) make this return a constant reference
gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
return &devices_map_;
}
// TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; }
const std::vector<DeviceType>& prioritized_device_type_list() { const std::vector<DeviceType>& prioritized_device_type_list() {
return prioritized_device_type_list_; return prioritized_device_type_list_;
} }
@ -178,9 +171,7 @@ class EagerContext : public core::RefCounted {
const FunctionDef* FindFunctionDef(const string& name); const FunctionDef* FindFunctionDef(const string& name);
Status FindDeviceByName(const string& name, Device** result) const; Device* HostCPU() const { return host_cpu_device_; }
Device* HostCPU() const { return devices_[0]; }
Device* CanonicalDevice(Device* d) const { Device* CanonicalDevice(Device* d) const {
return HostCPU() == d ? nullptr : d; return HostCPU() == d ? nullptr : d;
} }
@ -386,7 +377,7 @@ class EagerContext : public core::RefCounted {
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
private: private:
void InitDeviceMapAndAsync(); void InitPrioritizedDeviceTypeList();
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
Status RegisterExistingFunctionsOnRemoteWorkers( Status RegisterExistingFunctionsOnRemoteWorkers(
const std::vector<const FunctionDef*>& function_defs, const std::vector<const FunctionDef*>& function_defs,
@ -453,11 +444,8 @@ class EagerContext : public core::RefCounted {
// multi-device function on remote worker. // multi-device function on remote worker.
OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_; OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
// Devices owned by device_manager Device* host_cpu_device_; // Owned by device_manager
std::vector<Device*> devices_;
std::vector<DeviceType> prioritized_device_type_list_; std::vector<DeviceType> prioritized_device_type_list_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_; Rendezvous* rendezvous_;
std::function<Rendezvous*(const int64)> rendezvous_creator_; std::function<Rendezvous*(const int64)> rendezvous_creator_;

View File

@ -85,30 +85,6 @@ std::vector<string> DevicesToString(const std::vector<Device*> devices) {
return v; return v;
} }
// Initializes the step stats if needed.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
// Lazily initialize the RunMetadata with information about all devices if
// this is the first call.
while (step_stats->dev_stats_size() < ctx->devices()->size()) {
int device_idx = step_stats->dev_stats_size();
auto* dev_stats = step_stats->add_dev_stats();
dev_stats->set_device(ctx->devices()->at(device_idx)->name());
}
}
int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
Device* device) {
// Find the current device's index.
for (int i = 0; i < ctx->devices()->size(); ++i) {
if (ctx->devices()->at(i) == device ||
ctx->devices()->at(i)->name() == device->name()) {
return i;
}
}
// TODO(apassos) do not fall back to host CPU if device is unknown.
return 0;
}
const string& DeviceNameOrUnspecified(Device* device) { const string& DeviceNameOrUnspecified(Device* device) {
static string* unspecified_string = new string("<unspecified>"); static string* unspecified_string = new string("<unspecified>");
return (device == nullptr) ? *unspecified_string : device->name(); return (device == nullptr) ? *unspecified_string : device->name();
@ -716,7 +692,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
if (op->Device() == nullptr) { if (op->Device() == nullptr) {
tensorflow::Device* device = nullptr; tensorflow::Device* device = nullptr;
string device_name = op->GetDeviceName(); string device_name = op->GetDeviceName();
TF_RETURN_IF_ERROR(ctx->FindDeviceByName(device_name, &device)); TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name.c_str(), &device));
op->SetDevice(device); op->SetDevice(device);
} }

View File

@ -657,13 +657,12 @@ Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
if (ctx == nullptr) { if (ctx == nullptr) {
return nullptr; return nullptr;
} }
const auto& map = *ctx->device_map(); Device* device = nullptr;
auto it = map.find(handle.device()); if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
if (it == map.end()) {
LOG(ERROR) << "Cannot find resource device: " << handle.device() << "."; LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
return nullptr; return nullptr;
} }
return it->second; return device;
} }
string TensorHandle::DebugString() const { string TensorHandle::DebugString() const {