Remove devices_ and devices_map_ from EagerContext.
We currently look up devices either through the device managers or through the devices_map_ field. To make the runtime fault-tolerant, the list of devices (especially remote ones) must be able to change dynamically, and accessing an inconsistent devices_map_ often leads to errors.

PiperOrigin-RevId: 285475152
Change-Id: I37016875d94e481c4cb5143f19092fc46a481ff2
This commit is contained in:
parent
3e72418d7c
commit
84b1a3498c
@ -1045,7 +1045,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
|||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg, TF_Status* status) {
|
void* deallocator_arg, TF_Status* status) {
|
||||||
tensorflow::Device* device;
|
tensorflow::Device* device;
|
||||||
status->status = ctx->context->FindDeviceByName(device_name, &device);
|
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
deallocator(data, len, deallocator_arg);
|
deallocator(data, len, deallocator_arg);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
@ -79,7 +79,7 @@ EagerContext::EagerContext(
|
|||||||
: default_device_placement_policy_(default_device_placement_policy),
|
: default_device_placement_policy_(default_device_placement_policy),
|
||||||
default_mirroring_policy_(default_mirroring_policy),
|
default_mirroring_policy_(default_mirroring_policy),
|
||||||
local_device_manager_(device_mgr, device_mgr_owned),
|
local_device_manager_(device_mgr, device_mgr_owned),
|
||||||
devices_(device_mgr->ListDevices()),
|
host_cpu_device_(device_mgr->ListDevices()[0]),
|
||||||
rendezvous_(rendezvous),
|
rendezvous_(rendezvous),
|
||||||
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
|
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
|
||||||
custom_kernel_creator_(custom_kernel_creator),
|
custom_kernel_creator_(custom_kernel_creator),
|
||||||
@ -102,7 +102,7 @@ EagerContext::EagerContext(
|
|||||||
// currently a no-op.
|
// currently a no-op.
|
||||||
eager_context_created->GetCell()->Set(true);
|
eager_context_created->GetCell()->Set(true);
|
||||||
monitoring::StartExporter();
|
monitoring::StartExporter();
|
||||||
InitDeviceMapAndAsync();
|
InitPrioritizedDeviceTypeList();
|
||||||
runner_ = [this](std::function<void()> closure) {
|
runner_ = [this](std::function<void()> closure) {
|
||||||
this->thread_pool_->Schedule(std::move(closure));
|
this->thread_pool_->Schedule(std::move(closure));
|
||||||
};
|
};
|
||||||
@ -140,24 +140,17 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void EagerContext::InitDeviceMapAndAsync() {
|
void EagerContext::InitPrioritizedDeviceTypeList() {
|
||||||
for (auto* device : devices_) {
|
|
||||||
devices_map_[device->name()] = device;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (remote_device_mgr() != nullptr) {
|
|
||||||
for (auto* device : remote_device_mgr()->ListDevices()) {
|
|
||||||
if (devices_map_.find(device->name()) == devices_map_.end()) {
|
|
||||||
devices_map_[device->name()] = device;
|
|
||||||
devices_.push_back(device);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DeviceSet ds;
|
DeviceSet ds;
|
||||||
for (Device* d : devices_) {
|
for (Device* d : local_device_mgr()->ListDevices()) {
|
||||||
ds.AddDevice(d);
|
ds.AddDevice(d);
|
||||||
}
|
}
|
||||||
|
auto remote_device_manager = remote_device_mgr();
|
||||||
|
if (remote_device_manager != nullptr) {
|
||||||
|
for (Device* d : remote_device_manager->ListDevices()) {
|
||||||
|
ds.AddDevice(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -391,17 +384,6 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(gjn): Delete in favour of FindDeviceFromName
|
|
||||||
Status EagerContext::FindDeviceByName(const string& name,
|
|
||||||
Device** result) const {
|
|
||||||
auto it = devices_map_.find(name);
|
|
||||||
if (it == devices_map_.end()) {
|
|
||||||
return errors::InvalidArgument(name, " unknown device.");
|
|
||||||
}
|
|
||||||
*result = it->second;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
||||||
|
|
||||||
void EagerContext::StartStep() {
|
void EagerContext::StartStep() {
|
||||||
@ -634,7 +616,7 @@ Status EagerContext::CPUDeviceOnTask(const Device* device,
|
|||||||
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
|
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
|
||||||
device->name(), &cpu_device_name));
|
device->name(), &cpu_device_name));
|
||||||
|
|
||||||
return FindDeviceByName(cpu_device_name, cpu_device);
|
return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -718,11 +700,9 @@ Status EagerContext::StoreCollectiveOpsServer(
|
|||||||
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
|
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
|
||||||
|
|
||||||
local_device_manager_.Reset(device_mgr);
|
local_device_manager_.Reset(device_mgr);
|
||||||
|
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
|
||||||
|
|
||||||
devices_ = local_device_manager_.Get()->ListDevices();
|
InitPrioritizedDeviceTypeList();
|
||||||
devices_map_.clear();
|
|
||||||
|
|
||||||
InitDeviceMapAndAsync();
|
|
||||||
ClearCaches();
|
ClearCaches();
|
||||||
default_executor_.ClearError();
|
default_executor_.ClearError();
|
||||||
{
|
{
|
||||||
@ -860,9 +840,7 @@ Status EagerContext::SetMasterContextState(
|
|||||||
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
|
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
|
||||||
|
|
||||||
local_device_manager_.Reset(local_device_mgr);
|
local_device_manager_.Reset(local_device_mgr);
|
||||||
|
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
|
||||||
devices_ = local_device_manager_.Get()->ListDevices();
|
|
||||||
devices_map_.clear();
|
|
||||||
|
|
||||||
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
||||||
rendezvous_ = r;
|
rendezvous_ = r;
|
||||||
@ -893,7 +871,7 @@ Status EagerContext::SetMasterContextState(
|
|||||||
DCHECK(remote_device_manager_.Owned());
|
DCHECK(remote_device_manager_.Owned());
|
||||||
ResetClusterFLR(cluster_flr);
|
ResetClusterFLR(cluster_flr);
|
||||||
|
|
||||||
InitDeviceMapAndAsync();
|
InitPrioritizedDeviceTypeList();
|
||||||
|
|
||||||
ClearCaches();
|
ClearCaches();
|
||||||
default_executor_.ClearError();
|
default_executor_.ClearError();
|
||||||
@ -1009,7 +987,7 @@ Status EagerContext::InitializeRemoteWorker(
|
|||||||
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
|
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
|
||||||
&func_lib_def_, config->graph_options().optimizer_options(),
|
&func_lib_def_, config->graph_options().optimizer_options(),
|
||||||
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
|
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
|
||||||
InitDeviceMapAndAsync();
|
InitPrioritizedDeviceTypeList();
|
||||||
|
|
||||||
ClearCaches();
|
ClearCaches();
|
||||||
default_executor_.ClearError();
|
default_executor_.ClearError();
|
||||||
@ -1048,9 +1026,7 @@ Status EagerContext::UpdateRemoteWorker(
|
|||||||
ResetClusterFLR(cluster_flr);
|
ResetClusterFLR(cluster_flr);
|
||||||
|
|
||||||
remote_device_manager_.Reset(remote_device_mgr);
|
remote_device_manager_.Reset(remote_device_mgr);
|
||||||
devices_ = worker_session_device_mgr->ListDevices();
|
InitPrioritizedDeviceTypeList();
|
||||||
devices_map_.clear();
|
|
||||||
InitDeviceMapAndAsync();
|
|
||||||
|
|
||||||
ClearCaches();
|
ClearCaches();
|
||||||
default_executor_.ClearError();
|
default_executor_.ClearError();
|
||||||
|
|||||||
@ -141,13 +141,6 @@ class EagerContext : public core::RefCounted {
|
|||||||
// Specify a executor for this thread.
|
// Specify a executor for this thread.
|
||||||
void SetExecutorForThread(EagerExecutor* executor);
|
void SetExecutorForThread(EagerExecutor* executor);
|
||||||
|
|
||||||
// TODO(apassos) make this return a constant reference
|
|
||||||
gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
|
|
||||||
return &devices_map_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(apassos) make this return a constant reference
|
|
||||||
std::vector<Device*>* devices() { return &devices_; }
|
|
||||||
const std::vector<DeviceType>& prioritized_device_type_list() {
|
const std::vector<DeviceType>& prioritized_device_type_list() {
|
||||||
return prioritized_device_type_list_;
|
return prioritized_device_type_list_;
|
||||||
}
|
}
|
||||||
@ -178,9 +171,7 @@ class EagerContext : public core::RefCounted {
|
|||||||
|
|
||||||
const FunctionDef* FindFunctionDef(const string& name);
|
const FunctionDef* FindFunctionDef(const string& name);
|
||||||
|
|
||||||
Status FindDeviceByName(const string& name, Device** result) const;
|
Device* HostCPU() const { return host_cpu_device_; }
|
||||||
|
|
||||||
Device* HostCPU() const { return devices_[0]; }
|
|
||||||
Device* CanonicalDevice(Device* d) const {
|
Device* CanonicalDevice(Device* d) const {
|
||||||
return HostCPU() == d ? nullptr : d;
|
return HostCPU() == d ? nullptr : d;
|
||||||
}
|
}
|
||||||
@ -386,7 +377,7 @@ class EagerContext : public core::RefCounted {
|
|||||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitDeviceMapAndAsync();
|
void InitPrioritizedDeviceTypeList();
|
||||||
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
||||||
Status RegisterExistingFunctionsOnRemoteWorkers(
|
Status RegisterExistingFunctionsOnRemoteWorkers(
|
||||||
const std::vector<const FunctionDef*>& function_defs,
|
const std::vector<const FunctionDef*>& function_defs,
|
||||||
@ -453,11 +444,8 @@ class EagerContext : public core::RefCounted {
|
|||||||
// multi-device function on remote worker.
|
// multi-device function on remote worker.
|
||||||
OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
|
OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
|
||||||
|
|
||||||
// Devices owned by device_manager
|
Device* host_cpu_device_; // Owned by device_manager
|
||||||
std::vector<Device*> devices_;
|
|
||||||
std::vector<DeviceType> prioritized_device_type_list_;
|
std::vector<DeviceType> prioritized_device_type_list_;
|
||||||
// All devices are not owned.
|
|
||||||
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
|
|
||||||
Rendezvous* rendezvous_;
|
Rendezvous* rendezvous_;
|
||||||
std::function<Rendezvous*(const int64)> rendezvous_creator_;
|
std::function<Rendezvous*(const int64)> rendezvous_creator_;
|
||||||
|
|
||||||
|
|||||||
@ -85,30 +85,6 @@ std::vector<string> DevicesToString(const std::vector<Device*> devices) {
|
|||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializes the step stats if needed.
|
|
||||||
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
|
|
||||||
// Lazily initialize the RunMetadata with information about all devices if
|
|
||||||
// this is the first call.
|
|
||||||
while (step_stats->dev_stats_size() < ctx->devices()->size()) {
|
|
||||||
int device_idx = step_stats->dev_stats_size();
|
|
||||||
auto* dev_stats = step_stats->add_dev_stats();
|
|
||||||
dev_stats->set_device(ctx->devices()->at(device_idx)->name());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
|
||||||
Device* device) {
|
|
||||||
// Find the current device's index.
|
|
||||||
for (int i = 0; i < ctx->devices()->size(); ++i) {
|
|
||||||
if (ctx->devices()->at(i) == device ||
|
|
||||||
ctx->devices()->at(i)->name() == device->name()) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// TODO(apassos) do not fall back to host CPU if device is unknown.
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
const string& DeviceNameOrUnspecified(Device* device) {
|
const string& DeviceNameOrUnspecified(Device* device) {
|
||||||
static string* unspecified_string = new string("<unspecified>");
|
static string* unspecified_string = new string("<unspecified>");
|
||||||
return (device == nullptr) ? *unspecified_string : device->name();
|
return (device == nullptr) ? *unspecified_string : device->name();
|
||||||
@ -716,7 +692,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
|||||||
if (op->Device() == nullptr) {
|
if (op->Device() == nullptr) {
|
||||||
tensorflow::Device* device = nullptr;
|
tensorflow::Device* device = nullptr;
|
||||||
string device_name = op->GetDeviceName();
|
string device_name = op->GetDeviceName();
|
||||||
TF_RETURN_IF_ERROR(ctx->FindDeviceByName(device_name, &device));
|
TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name.c_str(), &device));
|
||||||
op->SetDevice(device);
|
op->SetDevice(device);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -657,13 +657,12 @@ Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
|
|||||||
if (ctx == nullptr) {
|
if (ctx == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
const auto& map = *ctx->device_map();
|
Device* device = nullptr;
|
||||||
auto it = map.find(handle.device());
|
if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
|
||||||
if (it == map.end()) {
|
|
||||||
LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
|
LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return it->second;
|
return device;
|
||||||
}
|
}
|
||||||
|
|
||||||
string TensorHandle::DebugString() const {
|
string TensorHandle::DebugString() const {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user