Remove devices_ and devices_map_ from EagerContext.
We currently look up devices either through device managers or the devices_map_ field. To make the runtime fault tolerant, the list of devices (especially remote ones) must change dynamically, and accessing the inconsistent devices_map_ often leads to errors. PiperOrigin-RevId: 285475152 Change-Id: I37016875d94e481c4cb5143f19092fc46a481ff2
This commit is contained in:
parent
3e72418d7c
commit
84b1a3498c
@ -1045,7 +1045,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device;
|
||||
status->status = ctx->context->FindDeviceByName(device_name, &device);
|
||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
|
||||
@ -79,7 +79,7 @@ EagerContext::EagerContext(
|
||||
: default_device_placement_policy_(default_device_placement_policy),
|
||||
default_mirroring_policy_(default_mirroring_policy),
|
||||
local_device_manager_(device_mgr, device_mgr_owned),
|
||||
devices_(device_mgr->ListDevices()),
|
||||
host_cpu_device_(device_mgr->ListDevices()[0]),
|
||||
rendezvous_(rendezvous),
|
||||
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
|
||||
custom_kernel_creator_(custom_kernel_creator),
|
||||
@ -102,7 +102,7 @@ EagerContext::EagerContext(
|
||||
// currently a no-op.
|
||||
eager_context_created->GetCell()->Set(true);
|
||||
monitoring::StartExporter();
|
||||
InitDeviceMapAndAsync();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
runner_ = [this](std::function<void()> closure) {
|
||||
this->thread_pool_->Schedule(std::move(closure));
|
||||
};
|
||||
@ -140,24 +140,17 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
|
||||
}
|
||||
}
|
||||
|
||||
void EagerContext::InitDeviceMapAndAsync() {
|
||||
for (auto* device : devices_) {
|
||||
devices_map_[device->name()] = device;
|
||||
}
|
||||
|
||||
if (remote_device_mgr() != nullptr) {
|
||||
for (auto* device : remote_device_mgr()->ListDevices()) {
|
||||
if (devices_map_.find(device->name()) == devices_map_.end()) {
|
||||
devices_map_[device->name()] = device;
|
||||
devices_.push_back(device);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EagerContext::InitPrioritizedDeviceTypeList() {
|
||||
DeviceSet ds;
|
||||
for (Device* d : devices_) {
|
||||
for (Device* d : local_device_mgr()->ListDevices()) {
|
||||
ds.AddDevice(d);
|
||||
}
|
||||
auto remote_device_manager = remote_device_mgr();
|
||||
if (remote_device_manager != nullptr) {
|
||||
for (Device* d : remote_device_manager->ListDevices()) {
|
||||
ds.AddDevice(d);
|
||||
}
|
||||
}
|
||||
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
||||
}
|
||||
|
||||
@ -391,17 +384,6 @@ std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
|
||||
return result;
|
||||
}
|
||||
|
||||
// TODO(gjn): Delete in favour of FindDeviceFromName
|
||||
Status EagerContext::FindDeviceByName(const string& name,
|
||||
Device** result) const {
|
||||
auto it = devices_map_.find(name);
|
||||
if (it == devices_map_.end()) {
|
||||
return errors::InvalidArgument(name, " unknown device.");
|
||||
}
|
||||
*result = it->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
||||
|
||||
void EagerContext::StartStep() {
|
||||
@ -634,7 +616,7 @@ Status EagerContext::CPUDeviceOnTask(const Device* device,
|
||||
TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
|
||||
device->name(), &cpu_device_name));
|
||||
|
||||
return FindDeviceByName(cpu_device_name, cpu_device);
|
||||
return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -718,11 +700,9 @@ Status EagerContext::StoreCollectiveOpsServer(
|
||||
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
|
||||
|
||||
local_device_manager_.Reset(device_mgr);
|
||||
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
|
||||
|
||||
devices_ = local_device_manager_.Get()->ListDevices();
|
||||
devices_map_.clear();
|
||||
|
||||
InitDeviceMapAndAsync();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
ClearCaches();
|
||||
default_executor_.ClearError();
|
||||
{
|
||||
@ -860,9 +840,7 @@ Status EagerContext::SetMasterContextState(
|
||||
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
|
||||
|
||||
local_device_manager_.Reset(local_device_mgr);
|
||||
|
||||
devices_ = local_device_manager_.Get()->ListDevices();
|
||||
devices_map_.clear();
|
||||
host_cpu_device_ = local_device_manager_.Get()->ListDevices()[0];
|
||||
|
||||
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
||||
rendezvous_ = r;
|
||||
@ -893,7 +871,7 @@ Status EagerContext::SetMasterContextState(
|
||||
DCHECK(remote_device_manager_.Owned());
|
||||
ResetClusterFLR(cluster_flr);
|
||||
|
||||
InitDeviceMapAndAsync();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
|
||||
ClearCaches();
|
||||
default_executor_.ClearError();
|
||||
@ -1009,7 +987,7 @@ Status EagerContext::InitializeRemoteWorker(
|
||||
ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
|
||||
&func_lib_def_, config->graph_options().optimizer_options(),
|
||||
thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
|
||||
InitDeviceMapAndAsync();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
|
||||
ClearCaches();
|
||||
default_executor_.ClearError();
|
||||
@ -1048,9 +1026,7 @@ Status EagerContext::UpdateRemoteWorker(
|
||||
ResetClusterFLR(cluster_flr);
|
||||
|
||||
remote_device_manager_.Reset(remote_device_mgr);
|
||||
devices_ = worker_session_device_mgr->ListDevices();
|
||||
devices_map_.clear();
|
||||
InitDeviceMapAndAsync();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
|
||||
ClearCaches();
|
||||
default_executor_.ClearError();
|
||||
|
||||
@ -141,13 +141,6 @@ class EagerContext : public core::RefCounted {
|
||||
// Specify a executor for this thread.
|
||||
void SetExecutorForThread(EagerExecutor* executor);
|
||||
|
||||
// TODO(apassos) make this return a constant reference
|
||||
gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
|
||||
return &devices_map_;
|
||||
}
|
||||
|
||||
// TODO(apassos) make this return a constant reference
|
||||
std::vector<Device*>* devices() { return &devices_; }
|
||||
const std::vector<DeviceType>& prioritized_device_type_list() {
|
||||
return prioritized_device_type_list_;
|
||||
}
|
||||
@ -178,9 +171,7 @@ class EagerContext : public core::RefCounted {
|
||||
|
||||
const FunctionDef* FindFunctionDef(const string& name);
|
||||
|
||||
Status FindDeviceByName(const string& name, Device** result) const;
|
||||
|
||||
Device* HostCPU() const { return devices_[0]; }
|
||||
Device* HostCPU() const { return host_cpu_device_; }
|
||||
Device* CanonicalDevice(Device* d) const {
|
||||
return HostCPU() == d ? nullptr : d;
|
||||
}
|
||||
@ -386,7 +377,7 @@ class EagerContext : public core::RefCounted {
|
||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||
|
||||
private:
|
||||
void InitDeviceMapAndAsync();
|
||||
void InitPrioritizedDeviceTypeList();
|
||||
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
|
||||
Status RegisterExistingFunctionsOnRemoteWorkers(
|
||||
const std::vector<const FunctionDef*>& function_defs,
|
||||
@ -453,11 +444,8 @@ class EagerContext : public core::RefCounted {
|
||||
// multi-device function on remote worker.
|
||||
OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
|
||||
|
||||
// Devices owned by device_manager
|
||||
std::vector<Device*> devices_;
|
||||
Device* host_cpu_device_; // Owned by device_manager
|
||||
std::vector<DeviceType> prioritized_device_type_list_;
|
||||
// All devices are not owned.
|
||||
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
|
||||
Rendezvous* rendezvous_;
|
||||
std::function<Rendezvous*(const int64)> rendezvous_creator_;
|
||||
|
||||
|
||||
@ -85,30 +85,6 @@ std::vector<string> DevicesToString(const std::vector<Device*> devices) {
|
||||
return v;
|
||||
}
|
||||
|
||||
// Initializes the step stats if needed.
|
||||
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
|
||||
// Lazily initialize the RunMetadata with information about all devices if
|
||||
// this is the first call.
|
||||
while (step_stats->dev_stats_size() < ctx->devices()->size()) {
|
||||
int device_idx = step_stats->dev_stats_size();
|
||||
auto* dev_stats = step_stats->add_dev_stats();
|
||||
dev_stats->set_device(ctx->devices()->at(device_idx)->name());
|
||||
}
|
||||
}
|
||||
|
||||
int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
||||
Device* device) {
|
||||
// Find the current device's index.
|
||||
for (int i = 0; i < ctx->devices()->size(); ++i) {
|
||||
if (ctx->devices()->at(i) == device ||
|
||||
ctx->devices()->at(i)->name() == device->name()) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
// TODO(apassos) do not fall back to host CPU if device is unknown.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const string& DeviceNameOrUnspecified(Device* device) {
|
||||
static string* unspecified_string = new string("<unspecified>");
|
||||
return (device == nullptr) ? *unspecified_string : device->name();
|
||||
@ -716,7 +692,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
if (op->Device() == nullptr) {
|
||||
tensorflow::Device* device = nullptr;
|
||||
string device_name = op->GetDeviceName();
|
||||
TF_RETURN_IF_ERROR(ctx->FindDeviceByName(device_name, &device));
|
||||
TF_RETURN_IF_ERROR(ctx->FindDeviceFromName(device_name.c_str(), &device));
|
||||
op->SetDevice(device);
|
||||
}
|
||||
|
||||
|
||||
@ -657,13 +657,12 @@ Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto& map = *ctx->device_map();
|
||||
auto it = map.find(handle.device());
|
||||
if (it == map.end()) {
|
||||
Device* device = nullptr;
|
||||
if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
|
||||
LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
return device;
|
||||
}
|
||||
|
||||
string TensorHandle::DebugString() const {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user