From 84b1a3498cd89ea95204a1955f0aa2363e560ad4 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Fri, 13 Dec 2019 14:43:38 -0800 Subject: [PATCH] Remove `devices_` and `devices_map_` from EagerContext. We currently look up devices either through device managers or the devices_map_ field. To make the runtime fault tolerant, the list of devices (especially remote ones) must change dynamically, and accessing the inconsistent devices_map_ often leads to errors. PiperOrigin-RevId: 285475152 Change-Id: I37016875d94e481c4cb5143f19092fc46a481ff2 --- tensorflow/c/eager/c_api.cc | 2 +- .../core/common_runtime/eager/context.cc | 58 ++++++------------- .../core/common_runtime/eager/context.h | 18 +----- .../core/common_runtime/eager/execute.cc | 26 +-------- .../common_runtime/eager/tensor_handle.cc | 7 +-- 5 files changed, 25 insertions(+), 86 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index a9892499cf0..66a2a4aaa3c 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1045,7 +1045,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( void (*deallocator)(void* data, size_t len, void* arg), void* deallocator_arg, TF_Status* status) { tensorflow::Device* device; - status->status = ctx->context->FindDeviceByName(device_name, &device); + status->status = ctx->context->FindDeviceFromName(device_name, &device); if (!status->status.ok()) { deallocator(data, len, deallocator_arg); return nullptr; diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index b8dd8d8dcd1..a58122b05bb 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -79,7 +79,7 @@ EagerContext::EagerContext( : default_device_placement_policy_(default_device_placement_policy), default_mirroring_policy_(default_mirroring_policy), local_device_manager_(device_mgr, device_mgr_owned), - devices_(device_mgr->ListDevices()), + host_cpu_device_(device_mgr->ListDevices()[0]), rendezvous_(rendezvous), thread_pool_(NewThreadPoolFromSessionOptions(opts)), custom_kernel_creator_(custom_kernel_creator), @@ -102,7 +102,7 @@ EagerContext::EagerContext( // currently a no-op. eager_context_created->GetCell()->Set(true); monitoring::StartExporter(); - InitDeviceMapAndAsync(); + InitPrioritizedDeviceTypeList(); runner_ = [this](std::function closure) { this->thread_pool_->Schedule(std::move(closure)); }; @@ -140,24 +140,17 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env, } } -void EagerContext::InitDeviceMapAndAsync() { - for (auto* device : devices_) { - devices_map_[device->name()] = device; - } - - if (remote_device_mgr() != nullptr) { - for (auto* device : remote_device_mgr()->ListDevices()) { - if (devices_map_.find(device->name()) == devices_map_.end()) { - devices_map_[device->name()] = device; - devices_.push_back(device); - } - } - } - +void EagerContext::InitPrioritizedDeviceTypeList() { DeviceSet ds; - for (Device* d : devices_) { + for (Device* d : local_device_mgr()->ListDevices()) { ds.AddDevice(d); } + auto remote_device_manager = remote_device_mgr(); + if (remote_device_manager != nullptr) { + for (Device* d : remote_device_manager->ListDevices()) { + ds.AddDevice(d); + } + } prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList(); } @@ -391,17 +384,6 @@ std::vector EagerContext::ListRegisteredFunctions() { return result; } -// TODO(gjn): Delete in favour of FindDeviceFromName -Status EagerContext::FindDeviceByName(const string& name, - Device** result) const { - auto it = devices_map_.find(name); - if (it == devices_map_.end()) { - return errors::InvalidArgument(name, " unknown device."); - } - *result = it->second; - return Status::OK(); -} - void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); } void EagerContext::StartStep() { @@ -634,7 +616,7 @@ Status EagerContext::CPUDeviceOnTask(const Device* device, TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( device->name(), &cpu_device_name)); - return FindDeviceByName(cpu_device_name, cpu_device); + return FindDeviceFromName(cpu_device_name.c_str(), cpu_device); } namespace { @@ -718,11 +700,9 @@ Status EagerContext::StoreCollectiveOpsServer( collective_executor_mgr_.Reset(rpc_collective_executor_mgr); local_device_manager_.Reset(device_mgr); + host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0]; - devices_ = local_device_manager_.Get()->ListDevices(); - devices_map_.clear(); - - InitDeviceMapAndAsync(); + InitPrioritizedDeviceTypeList(); ClearCaches(); default_executor_.ClearError(); { @@ -860,9 +840,7 @@ Status EagerContext::SetMasterContextState( ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true); local_device_manager_.Reset(local_device_mgr); - - devices_ = local_device_manager_.Get()->ListDevices(); - devices_map_.clear(); + host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0]; if (rendezvous_ != nullptr) rendezvous_->Unref(); rendezvous_ = r; @@ -893,7 +871,7 @@ Status EagerContext::SetMasterContextState( DCHECK(remote_device_manager_.Owned()); ResetClusterFLR(cluster_flr); - InitDeviceMapAndAsync(); + InitPrioritizedDeviceTypeList(); ClearCaches(); default_executor_.ClearError(); @@ -1009,7 +987,7 @@ Status EagerContext::InitializeRemoteWorker( ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION, &func_lib_def_, config->graph_options().optimizer_options(), thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_); - InitDeviceMapAndAsync(); + InitPrioritizedDeviceTypeList(); ClearCaches(); default_executor_.ClearError(); @@ -1048,9 +1026,7 @@ Status EagerContext::UpdateRemoteWorker( ResetClusterFLR(cluster_flr); remote_device_manager_.Reset(remote_device_mgr); - devices_ = worker_session_device_mgr->ListDevices(); - devices_map_.clear(); - InitDeviceMapAndAsync(); + InitPrioritizedDeviceTypeList(); ClearCaches(); default_executor_.ClearError(); diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 93fbd8947fe..6807e0a9d5a 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -141,13 +141,6 @@ class EagerContext : public core::RefCounted { // Specify a executor for this thread. void SetExecutorForThread(EagerExecutor* executor); - // TODO(apassos) make this return a constant reference - gtl::FlatMap* device_map() { - return &devices_map_; - } - - // TODO(apassos) make this return a constant reference - std::vector* devices() { return &devices_; } const std::vector& prioritized_device_type_list() { return prioritized_device_type_list_; } @@ -178,9 +171,7 @@ class EagerContext : public core::RefCounted { const FunctionDef* FindFunctionDef(const string& name); - Status FindDeviceByName(const string& name, Device** result) const; - - Device* HostCPU() const { return devices_[0]; } + Device* HostCPU() const { return host_cpu_device_; } Device* CanonicalDevice(Device* d) const { return HostCPU() == d ? nullptr : d; } @@ -386,7 +377,7 @@ class EagerContext : public core::RefCounted { Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; private: - void InitDeviceMapAndAsync(); + void InitPrioritizedDeviceTypeList(); Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); Status RegisterExistingFunctionsOnRemoteWorkers( const std::vector& function_defs, @@ -453,11 +444,8 @@ class EagerContext : public core::RefCounted { // multi-device function on remote worker. OwnedOrUnownedHelper remote_device_manager_; - // Devices owned by device_manager - std::vector devices_; + Device* host_cpu_device_; // Owned by device_manager std::vector prioritized_device_type_list_; - // All devices are not owned. - gtl::FlatMap devices_map_; Rendezvous* rendezvous_; std::function rendezvous_creator_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 1ad83fd7f26..9584056295c 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -85,30 +85,6 @@ std::vector DevicesToString(const std::vector devices) { return v; } -// Initializes the step stats if needed. -void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) { - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices()->size()) { - int device_idx = step_stats->dev_stats_size(); - auto* dev_stats = step_stats->add_dev_stats(); - dev_stats->set_device(ctx->devices()->at(device_idx)->name()); - } -} - -int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx, - Device* device) { - // Find the current device's index. - for (int i = 0; i < ctx->devices()->size(); ++i) { - if (ctx->devices()->at(i) == device || - ctx->devices()->at(i)->name() == device->name()) { - return i; - } - } - // TODO(apassos) do not fall back to host CPU if device is unknown. - return 0; -} - const string& DeviceNameOrUnspecified(Device* device) { static string* unspecified_string = new string(""); return (device == nullptr) ? *unspecified_string : device->name(); @@ -716,7 +692,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, if (op->Device() == nullptr) { tensorflow::Device* device = nullptr; string device_name = op->GetDeviceName(); - TF_RETURN_IF_ERROR(ctx->FindDeviceByName(device_name, &device)); + TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name.c_str(), &device)); op->SetDevice(device); } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 717ec586eef..cc3e4a754a9 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -657,13 +657,12 @@ Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) { if (ctx == nullptr) { return nullptr; } - const auto& map = *ctx->device_map(); - auto it = map.find(handle.device()); - if (it == map.end()) { + Device* device = nullptr; + if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) { LOG(ERROR) << "Cannot find resource device: " << handle.device() << "."; return nullptr; } - return it->second; + return device; } string TensorHandle::DebugString() const {