Remove devices_ and devices_map_ from EagerContext.

We currently look up devices either through the device managers or through the devices_map_ field. To make the runtime fault-tolerant, the list of devices (especially remote ones) must be able to change dynamically; devices_map_ can then become inconsistent with the device managers, and accessing it often leads to errors.

PiperOrigin-RevId: 285475152
Change-Id: I37016875d94e481c4cb5143f19092fc46a481ff2
This commit is contained in:
Haoyu Zhang 2019-12-13 14:43:38 -08:00 committed by TensorFlower Gardener
parent 3e72418d7c
commit 84b1a3498c
5 changed files with 25 additions and 86 deletions

View File

@ -1045,7 +1045,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
status->status = ctx->context->FindDeviceByName(device_name, &device);
status->status = ctx->context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;

View File

@ -79,7 +79,7 @@ EagerContext::EagerContext(
: default_device_placement_policy_(default_device_placement_policy),
default_mirroring_policy_(default_mirroring_policy),
local_device_manager_(device_mgr, device_mgr_owned),
devices_(device_mgr->ListDevices()),
host_cpu_device_(device_mgr->ListDevices()[0]),
rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
custom_kernel_creator_(custom_kernel_creator),
@ -102,7 +102,7 @@ EagerContext::EagerContext(
// currently a no-op.
eager_context_created->GetCell()->Set(true);
monitoring::StartExporter();
InitDeviceMapAndAsync();
InitPrioritizedDeviceTypeList();
runner_ = [this](std::function<void()> closure) {
this->thread_pool_->Schedule(std::move(closure));
};
@ -140,24 +140,17 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
}
}
void EagerContext::InitDeviceMapAndAsync() {
for (auto* device : devices_) {
devices_map_[device->name()] = device;
}
if (remote_device_mgr() != nullptr) {
for (auto* device : remote_device_mgr()->ListDevices()) {
if (devices_map_.find(device->name()) == devices_map_.end()) {
devices_map_[device->name()] = device;
devices_.push_back(device);
}
}
}
void EagerContext::InitPrioritizedDeviceTypeList() {
DeviceSet ds;
for (Device* d : devices_) {
for (Device* d : local_device_mgr()->ListDevices()) {
ds.AddDevice(d);
}
auto remote_device_manager = remote_device_mgr();
if (remote_device_manager != nullptr) {
for (Device* d : remote_device_manager->ListDevices()) {
ds.AddDevice(d);
}
}
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
@ -391,17 +384,6 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
return result;
}
// TODO(gjn): Delete in favour of FindDeviceFromName
Status EagerContext::FindDeviceByName(const string& name,
Device** result) const {
auto it = devices_map_.find(name);
if (it == devices_map_.end()) {
return errors::InvalidArgument(name, " unknown device.");
}
*result = it->second;
return Status::OK();
}
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
void EagerContext::StartStep() {
@ -634,7 +616,7 @@ Status EagerContext::CPUDeviceOnTask(const Device* device,
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
device->name(), &cpu_device_name));
return FindDeviceByName(cpu_device_name, cpu_device);
return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
}
namespace {
@ -718,11 +700,9 @@ Status EagerContext::StoreCollectiveOpsServer(
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
local_device_manager_.Reset(device_mgr);
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
devices_ = local_device_manager_.Get()->ListDevices();
devices_map_.clear();
InitDeviceMapAndAsync();
InitPrioritizedDeviceTypeList();
ClearCaches();
default_executor_.ClearError();
{
@ -860,9 +840,7 @@ Status EagerContext::SetMasterContextState(
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
local_device_manager_.Reset(local_device_mgr);
devices_ = local_device_manager_.Get()->ListDevices();
devices_map_.clear();
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
if (rendezvous_ != nullptr) rendezvous_->Unref();
rendezvous_ = r;
@ -893,7 +871,7 @@ Status EagerContext::SetMasterContextState(
DCHECK(remote_device_manager_.Owned());
ResetClusterFLR(cluster_flr);
InitDeviceMapAndAsync();
InitPrioritizedDeviceTypeList();
ClearCaches();
default_executor_.ClearError();
@ -1009,7 +987,7 @@ Status EagerContext::InitializeRemoteWorker(
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
InitDeviceMapAndAsync();
InitPrioritizedDeviceTypeList();
ClearCaches();
default_executor_.ClearError();
@ -1048,9 +1026,7 @@ Status EagerContext::UpdateRemoteWorker(
ResetClusterFLR(cluster_flr);
remote_device_manager_.Reset(remote_device_mgr);
devices_ = worker_session_device_mgr->ListDevices();
devices_map_.clear();
InitDeviceMapAndAsync();
InitPrioritizedDeviceTypeList();
ClearCaches();
default_executor_.ClearError();

View File

@ -141,13 +141,6 @@ class EagerContext : public core::RefCounted {
// Specify a executor for this thread.
void SetExecutorForThread(EagerExecutor* executor);
// TODO(apassos) make this return a constant reference
gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
return &devices_map_;
}
// TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; }
const std::vector<DeviceType>& prioritized_device_type_list() {
return prioritized_device_type_list_;
}
@ -178,9 +171,7 @@ class EagerContext : public core::RefCounted {
const FunctionDef* FindFunctionDef(const string& name);
Status FindDeviceByName(const string& name, Device** result) const;
Device* HostCPU() const { return devices_[0]; }
Device* HostCPU() const { return host_cpu_device_; }
Device* CanonicalDevice(Device* d) const {
return HostCPU() == d ? nullptr : d;
}
@ -386,7 +377,7 @@ class EagerContext : public core::RefCounted {
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
private:
void InitDeviceMapAndAsync();
void InitPrioritizedDeviceTypeList();
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
Status RegisterExistingFunctionsOnRemoteWorkers(
const std::vector<const FunctionDef*>& function_defs,
@ -453,11 +444,8 @@ class EagerContext : public core::RefCounted {
// multi-device function on remote worker.
OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
// Devices owned by device_manager
std::vector<Device*> devices_;
Device* host_cpu_device_; // Owned by device_manager
std::vector<DeviceType> prioritized_device_type_list_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_;
std::function<Rendezvous*(const int64)> rendezvous_creator_;

View File

@ -85,30 +85,6 @@ std::vector<string> DevicesToString(const std::vector<Device*> devices) {
return v;
}
// Initializes the step stats if needed.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
// Lazily initialize the RunMetadata with information about all devices if
// this is the first call.
while (step_stats->dev_stats_size() < ctx->devices()->size()) {
int device_idx = step_stats->dev_stats_size();
auto* dev_stats = step_stats->add_dev_stats();
dev_stats->set_device(ctx->devices()->at(device_idx)->name());
}
}
int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
Device* device) {
// Find the current device's index.
for (int i = 0; i < ctx->devices()->size(); ++i) {
if (ctx->devices()->at(i) == device ||
ctx->devices()->at(i)->name() == device->name()) {
return i;
}
}
// TODO(apassos) do not fall back to host CPU if device is unknown.
return 0;
}
const string& DeviceNameOrUnspecified(Device* device) {
static string* unspecified_string = new string("<unspecified>");
return (device == nullptr) ? *unspecified_string : device->name();
@ -716,7 +692,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
if (op->Device() == nullptr) {
tensorflow::Device* device = nullptr;
string device_name = op->GetDeviceName();
TF_RETURN_IF_ERROR(ctx->FindDeviceByName(device_name, &device));
TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name.c_str(), &device));
op->SetDevice(device);
}

View File

@ -657,13 +657,12 @@ Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
if (ctx == nullptr) {
return nullptr;
}
const auto& map = *ctx->device_map();
auto it = map.find(handle.device());
if (it == map.end()) {
Device* device = nullptr;
if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
return nullptr;
}
return it->second;
return device;
}
string TensorHandle::DebugString() const {