Avoid segmentation fault when instantiating functions on removed devices.
PiperOrigin-RevId: 315958679 Change-Id: Ic5c58926883368fb14bed8ef0b9c51dc97cc6768
This commit is contained in:
parent
d8881eb71d
commit
d2a0ab1c5b
|
@ -111,6 +111,9 @@ class StaticDeviceMgr : public DeviceMgr {
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
|
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Number of slots in the stale device buffer, which temporarily stores
// removed devices. Declared constexpr because it is a compile-time constant.
static constexpr size_t kStaleDeviceBufferSize = 8192;
|
||||||
|
|
||||||
// Represents a dynamic set of devices
|
// Represents a dynamic set of devices
|
||||||
class DynamicDeviceMgr : public DeviceMgr {
|
class DynamicDeviceMgr : public DeviceMgr {
|
||||||
public:
|
public:
|
||||||
|
@ -157,6 +160,28 @@ class DynamicDeviceMgr : public DeviceMgr {
|
||||||
|
|
||||||
mutable Device* cpu_device_ TF_GUARDED_BY(devices_mu_);
|
mutable Device* cpu_device_ TF_GUARDED_BY(devices_mu_);
|
||||||
|
|
||||||
|
class DeviceCircularBuffer {
|
||||||
|
public:
|
||||||
|
DeviceCircularBuffer() : index_(0) {
|
||||||
|
devices_.resize(kStaleDeviceBufferSize);
|
||||||
|
}
|
||||||
|
void add(std::unique_ptr<Device> device) {
|
||||||
|
devices_[index_] = std::move(device);
|
||||||
|
index_ = (index_ + 1) % kStaleDeviceBufferSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int index_;
|
||||||
|
std::vector<std::unique_ptr<Device>> devices_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Buffer to temporarily store the removed devices. Raw device pointers are
|
||||||
|
// accessible to DeviceSet, and if the function instantiation process directly
|
||||||
|
// accesses fields through the device set, the underlying device object must
|
||||||
|
// still be available to avoid segmentation fault. We keep the devices in this
|
||||||
|
// buffer only for that purpose.
|
||||||
|
DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr);
|
TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr);
|
||||||
};
|
};
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
@ -178,6 +178,7 @@ Status DynamicDeviceMgr::RemoveDevices(std::vector<Device*> devices) {
|
||||||
}
|
}
|
||||||
device_type_counts_[d->device_type()]--;
|
device_type_counts_[d->device_type()]--;
|
||||||
device_incarnation_set_.erase(d->attributes().incarnation());
|
device_incarnation_set_.erase(d->attributes().incarnation());
|
||||||
|
stale_devices_.add(std::move(it->second));
|
||||||
dynamic_devices_.erase(it);
|
dynamic_devices_.erase(it);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -18,7 +18,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/notification.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
|
@ -26,16 +30,36 @@ namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Return a fake device with the specified type and name.
|
// Return a fake device with the specified type and name.
|
||||||
static Device* CreateDevice(const char* type, const char* name) {
|
static Device* CreateDevice(const char* type, const char* name,
|
||||||
|
Notification* n = nullptr) {
|
||||||
class FakeDevice : public Device {
|
class FakeDevice : public Device {
|
||||||
public:
|
public:
|
||||||
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
|
||||||
Status Sync() override { return Status::OK(); }
|
Status Sync() override { return Status::OK(); }
|
||||||
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class FakeDeviceWithDestructorNotification : public FakeDevice {
|
||||||
|
public:
|
||||||
|
FakeDeviceWithDestructorNotification(const DeviceAttributes& attr,
|
||||||
|
Notification* n)
|
||||||
|
: FakeDevice(attr), n_(n) {}
|
||||||
|
~FakeDeviceWithDestructorNotification() override { n_->Notify(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Notification* n_;
|
||||||
|
};
|
||||||
|
|
||||||
DeviceAttributes attr;
|
DeviceAttributes attr;
|
||||||
attr.set_name(name);
|
attr.set_name(name);
|
||||||
attr.set_device_type(type);
|
attr.set_device_type(type);
|
||||||
|
do {
|
||||||
|
attr.set_incarnation(random::New64());
|
||||||
|
} while (attr.incarnation() == 0);
|
||||||
|
|
||||||
|
if (n) {
|
||||||
|
return new FakeDeviceWithDestructorNotification(attr, n);
|
||||||
|
}
|
||||||
return new FakeDevice(attr);
|
return new FakeDevice(attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,6 +81,7 @@ TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgr) {
|
||||||
std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0"));
|
std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0"));
|
||||||
std::unique_ptr<Device> d1(CreateDevice("CPU", "/device:CPU:1"));
|
std::unique_ptr<Device> d1(CreateDevice("CPU", "/device:CPU:1"));
|
||||||
Device* d1_ptr = d1.get();
|
Device* d1_ptr = d1.get();
|
||||||
|
const int64 d1_incarnation = d1->attributes().incarnation();
|
||||||
|
|
||||||
auto dm = MakeUnique<DynamicDeviceMgr>();
|
auto dm = MakeUnique<DynamicDeviceMgr>();
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> devices;
|
||||||
|
@ -68,6 +93,38 @@ TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgr) {
|
||||||
std::vector<Device*> removed_devices{d1_ptr};
|
std::vector<Device*> removed_devices{d1_ptr};
|
||||||
TF_CHECK_OK(dm->RemoveDevices(removed_devices));
|
TF_CHECK_OK(dm->RemoveDevices(removed_devices));
|
||||||
EXPECT_EQ(dm->ListDevices().size(), 1);
|
EXPECT_EQ(dm->ListDevices().size(), 1);
|
||||||
|
EXPECT_FALSE(dm->ContainsDevice(d1_incarnation));
|
||||||
|
|
||||||
|
// The device stays accessible through its raw pointer shortly after removal.
|
||||||
|
EXPECT_EQ(d1_ptr->name(), "/device:CPU:1");
|
||||||
|
EXPECT_EQ(d1_ptr->device_type(), "CPU");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks that a removed device is kept alive by the stale device buffer and is
// destroyed once kStaleDeviceBufferSize further devices have been removed.
TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgrBuffer) {
  // Create a device whose destructor will send a notification.
  Notification n;
  std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0", &n));
  Device* d0_ptr = d0.get();
  std::vector<std::unique_ptr<Device>> added_devices;
  added_devices.emplace_back(std::move(d0));
  auto dm = MakeUnique<DynamicDeviceMgr>();
  TF_CHECK_OK(dm->AddDevices(std::move(added_devices)));
  std::vector<Device*> removed_devices{d0_ptr};
  TF_CHECK_OK(dm->RemoveDevices(removed_devices));
  // d0 has been removed from the manager but must not have been destroyed yet.
  EXPECT_FALSE(n.HasBeenNotified());

  // Repeatedly add and remove devices to fill up the stale devices buffer.
  // The counter is size_t to match the type of kStaleDeviceBufferSize.
  for (size_t i = 0; i < kStaleDeviceBufferSize; ++i) {
    // `added_devices` was moved from above; clear() puts it back into a known
    // empty state before it is reused.
    added_devices.clear();
    removed_devices.clear();
    std::unique_ptr<Device> d(CreateDevice("CPU", "/device:CPU:0"));
    Device* d_ptr = d.get();
    added_devices.emplace_back(std::move(d));
    TF_CHECK_OK(dm->AddDevices(std::move(added_devices)));
    removed_devices.emplace_back(d_ptr);
    TF_CHECK_OK(dm->RemoveDevices(removed_devices));
  }
  // Verify that d0's destructor has been called now that the buffer is full.
  // Checking the flag (instead of blocking on the notification) makes a
  // regression fail the test rather than hang it.
  EXPECT_TRUE(n.HasBeenNotified());
}
|
||||||
|
|
||||||
TEST(DynamicDeviceMgrTest, RemoveDeviceByNameFromMgr) {
|
TEST(DynamicDeviceMgrTest, RemoveDeviceByNameFromMgr) {
|
||||||
|
|
|
@ -871,12 +871,12 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||||
Status* status = &instantiate_status[i];
|
Status* status = &instantiate_status[i];
|
||||||
string unique_name = name_generator.GetName();
|
string unique_name = name_generator.GetName();
|
||||||
ComponentFunctionData* comp_data = &data->glue_[pair.first];
|
ComponentFunctionData* comp_data = &data->glue_[pair.first];
|
||||||
runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
|
runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
|
||||||
&options, status, &counter, &data] {
|
&control_ret, &options, status, &counter, &data] {
|
||||||
const string& target = pair.first;
|
const string& target = pair.first;
|
||||||
|
|
||||||
const string& device_type =
|
const string& device_type =
|
||||||
device_set()->FindDeviceByName(target)->device_type();
|
dev_set->FindDeviceByName(target)->device_type();
|
||||||
Graph* subgraph = pair.second.get();
|
Graph* subgraph = pair.second.get();
|
||||||
|
|
||||||
status->Update(UpdateArgAndRetvalMetadata(
|
status->Update(UpdateArgAndRetvalMetadata(
|
||||||
|
|
|
@ -95,15 +95,26 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||||
ProcessFunctionLibraryRuntimeTest() {
|
ProcessFunctionLibraryRuntimeTest() {
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
auto* device_count = options.config.mutable_device_count();
|
auto* device_count = options.config.mutable_device_count();
|
||||||
device_count->insert({"CPU", 2});
|
device_count->insert({"CPU", 3});
|
||||||
std::vector<std::unique_ptr<Device>> devices;
|
std::vector<std::unique_ptr<Device>> created_devices;
|
||||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
|
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
|
||||||
&devices));
|
&created_devices));
|
||||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
// Do not add CPU:2 to device manager. Used for removed device testing.
|
||||||
|
device2_ = std::move(created_devices[2]);
|
||||||
|
created_devices.erase(created_devices.begin() + 2);
|
||||||
|
|
||||||
|
device_mgr_ = std::make_unique<DynamicDeviceMgr>();
|
||||||
|
TF_CHECK_OK(device_mgr_->AddDevices(std::move(created_devices)));
|
||||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||||
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
|
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
|
||||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||||
"/job:a/replica:0/task:0/device:CPU:1", &device1_));
|
"/job:a/replica:0/task:0/device:CPU:1", &device1_));
|
||||||
|
Device* device2_ptr = nullptr;
|
||||||
|
EXPECT_NE(
|
||||||
|
error::OK,
|
||||||
|
device_mgr_
|
||||||
|
->LookupDevice("/job:a/replica:0/task:0/device:CPU:2", &device2_ptr)
|
||||||
|
.code());
|
||||||
// If no GPU is available, gpu_device_ will remain nullptr.
|
// If no GPU is available, gpu_device_ will remain nullptr.
|
||||||
Status status = device_mgr_->LookupDevice(
|
Status status = device_mgr_->LookupDevice(
|
||||||
"/job:a/replica:0/task:0/device:GPU:0", &gpu_device_);
|
"/job:a/replica:0/task:0/device:GPU:0", &gpu_device_);
|
||||||
|
@ -301,9 +312,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
std::unique_ptr<DynamicDeviceMgr> device_mgr_;
|
||||||
Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
||||||
Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
||||||
|
std::unique_ptr<Device> device2_;
|
||||||
// Remains as nullptr if no GPU is available.
|
// Remains as nullptr if no GPU is available.
|
||||||
Device* gpu_device_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
Device* gpu_device_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
||||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||||
|
@ -445,6 +457,28 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Instantiating a multi-device function must not crash while the process FLR
// device set still refers to a device that was removed from the device
// manager.
TEST_F(ProcessFunctionLibraryRuntimeTest, InstantiateFunctionOnRemovedDevice) {
  // Hand device2 over to the device manager, keeping a raw pointer so it can
  // be removed again below.
  Device* device2_ptr = device2_.get();
  std::vector<std::unique_ptr<Device>> to_add;
  to_add.emplace_back(std::move(device2_));
  TF_CHECK_OK(device_mgr_->AddDevices(std::move(to_add)));

  Init({test::function::FindDevice()});

  std::vector<Device*> to_remove{device2_ptr};
  TF_CHECK_OK(device_mgr_->RemoveDevices(std::move(to_remove)));

  // The process FLR device set has not been updated yet, so it still holds
  // the raw pointer to device2. Make sure that function instantiation with
  // device2 does not lead to a segfault.
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.is_multi_device_function = true;
  instantiate_opts.target = "/job:a/replica:0/task:0/device:CPU:1";
  FunctionLibraryRuntime::Handle h;
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:2"}},
                          instantiate_opts, &h));
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
FunctionLibraryRuntime::Options opts;
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
|
Loading…
Reference in New Issue