Avoid segmentation fault when instantiating functions on removed devices.

PiperOrigin-RevId: 315958679
Change-Id: Ic5c58926883368fb14bed8ef0b9c51dc97cc6768
This commit is contained in:
Haoyu Zhang 2020-06-11 12:53:27 -07:00 committed by TensorFlower Gardener
parent d8881eb71d
commit d2a0ab1c5b
5 changed files with 126 additions and 9 deletions

View File

@ -111,6 +111,9 @@ class StaticDeviceMgr : public DeviceMgr {
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr); TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
}; };
// Size of stale device buffer for temporary storage of removed devices.
static const size_t kStaleDeviceBufferSize = 8192;
// Represents a dynamic set of devices // Represents a dynamic set of devices
class DynamicDeviceMgr : public DeviceMgr { class DynamicDeviceMgr : public DeviceMgr {
public: public:
@ -157,6 +160,28 @@ class DynamicDeviceMgr : public DeviceMgr {
mutable Device* cpu_device_ TF_GUARDED_BY(devices_mu_); mutable Device* cpu_device_ TF_GUARDED_BY(devices_mu_);
class DeviceCircularBuffer {
public:
DeviceCircularBuffer() : index_(0) {
devices_.resize(kStaleDeviceBufferSize);
}
void add(std::unique_ptr<Device> device) {
devices_[index_] = std::move(device);
index_ = (index_ + 1) % kStaleDeviceBufferSize;
}
private:
int index_;
std::vector<std::unique_ptr<Device>> devices_;
};
// Buffer to temporarily store the removed devices. Raw device pointers are
// accessible to DeviceSet, and if the function instantiation process directly
// access fields through the device set, the underlying device object must
// still be available to avoid segmentation fault. We keep the devices in this
// buffer only for that purpose.
DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr); TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr);
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -178,6 +178,7 @@ Status DynamicDeviceMgr::RemoveDevices(std::vector<Device*> devices) {
} }
device_type_counts_[d->device_type()]--; device_type_counts_[d->device_type()]--;
device_incarnation_set_.erase(d->attributes().incarnation()); device_incarnation_set_.erase(d->attributes().incarnation());
stale_devices_.add(std::move(it->second));
dynamic_devices_.erase(it); dynamic_devices_.erase(it);
} }
return Status::OK(); return Status::OK();

View File

@ -18,7 +18,11 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/ptr_util.h"
@ -26,16 +30,36 @@ namespace tensorflow {
namespace { namespace {
// Return a fake device with the specified type and name. // Return a fake device with the specified type and name.
static Device* CreateDevice(const char* type, const char* name) { static Device* CreateDevice(const char* type, const char* name,
Notification* n = nullptr) {
class FakeDevice : public Device { class FakeDevice : public Device {
public: public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
}; };
class FakeDeviceWithDestructorNotification : public FakeDevice {
public:
FakeDeviceWithDestructorNotification(const DeviceAttributes& attr,
Notification* n)
: FakeDevice(attr), n_(n) {}
~FakeDeviceWithDestructorNotification() override { n_->Notify(); }
private:
Notification* n_;
};
DeviceAttributes attr; DeviceAttributes attr;
attr.set_name(name); attr.set_name(name);
attr.set_device_type(type); attr.set_device_type(type);
do {
attr.set_incarnation(random::New64());
} while (attr.incarnation() == 0);
if (n) {
return new FakeDeviceWithDestructorNotification(attr, n);
}
return new FakeDevice(attr); return new FakeDevice(attr);
} }
@ -57,6 +81,7 @@ TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgr) {
std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0")); std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0"));
std::unique_ptr<Device> d1(CreateDevice("CPU", "/device:CPU:1")); std::unique_ptr<Device> d1(CreateDevice("CPU", "/device:CPU:1"));
Device* d1_ptr = d1.get(); Device* d1_ptr = d1.get();
const int64 d1_incarnation = d1->attributes().incarnation();
auto dm = MakeUnique<DynamicDeviceMgr>(); auto dm = MakeUnique<DynamicDeviceMgr>();
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> devices;
@ -68,6 +93,38 @@ TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgr) {
std::vector<Device*> removed_devices{d1_ptr}; std::vector<Device*> removed_devices{d1_ptr};
TF_CHECK_OK(dm->RemoveDevices(removed_devices)); TF_CHECK_OK(dm->RemoveDevices(removed_devices));
EXPECT_EQ(dm->ListDevices().size(), 1); EXPECT_EQ(dm->ListDevices().size(), 1);
EXPECT_FALSE(dm->ContainsDevice(d1_incarnation));
// Device still accessible shortly through the raw pointer after removal.
EXPECT_EQ(d1_ptr->name(), "/device:CPU:1");
EXPECT_EQ(d1_ptr->device_type(), "CPU");
}
TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgrBuffer) {
// Create a device whose destructor will send a notification.
Notification n;
std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0", &n));
Device* d0_ptr = d0.get();
std::vector<std::unique_ptr<Device>> added_devices;
added_devices.emplace_back(std::move(d0));
auto dm = MakeUnique<DynamicDeviceMgr>();
TF_CHECK_OK(dm->AddDevices(std::move(added_devices)));
std::vector<Device*> removed_devices{d0_ptr};
TF_CHECK_OK(dm->RemoveDevices(removed_devices));
// Repeatedly add and remove devices to fill up the stale devices buffer.
for (int i = 0; i < kStaleDeviceBufferSize; i++) {
added_devices.clear();
removed_devices.clear();
std::unique_ptr<Device> d(CreateDevice("CPU", "/device:CPU:0"));
Device* d_ptr = d.get();
added_devices.emplace_back(std::move(d));
TF_CHECK_OK(dm->AddDevices(std::move(added_devices)));
removed_devices.emplace_back(d_ptr);
TF_CHECK_OK(dm->RemoveDevices(removed_devices));
}
// Verify that d0 destructor is called after the buffer is full.
n.WaitForNotification();
} }
TEST(DynamicDeviceMgrTest, RemoveDeviceByNameFromMgr) { TEST(DynamicDeviceMgrTest, RemoveDeviceByNameFromMgr) {

View File

@ -871,12 +871,12 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
Status* status = &instantiate_status[i]; Status* status = &instantiate_status[i];
string unique_name = name_generator.GetName(); string unique_name = name_generator.GetName();
ComponentFunctionData* comp_data = &data->glue_[pair.first]; ComponentFunctionData* comp_data = &data->glue_[pair.first];
runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret, runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
&options, status, &counter, &data] { &control_ret, &options, status, &counter, &data] {
const string& target = pair.first; const string& target = pair.first;
const string& device_type = const string& device_type =
device_set()->FindDeviceByName(target)->device_type(); dev_set->FindDeviceByName(target)->device_type();
Graph* subgraph = pair.second.get(); Graph* subgraph = pair.second.get();
status->Update(UpdateArgAndRetvalMetadata( status->Update(UpdateArgAndRetvalMetadata(

View File

@ -95,15 +95,26 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
ProcessFunctionLibraryRuntimeTest() { ProcessFunctionLibraryRuntimeTest() {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 2}); device_count->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices; std::vector<std::unique_ptr<Device>> created_devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0", TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
&devices)); &created_devices));
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices)); // Do not add CPU:2 to device manager. Used for removed device testing.
device2_ = std::move(created_devices[2]);
created_devices.erase(created_devices.begin() + 2);
device_mgr_ = std::make_unique<DynamicDeviceMgr>();
TF_CHECK_OK(device_mgr_->AddDevices(std::move(created_devices)));
TF_CHECK_OK(device_mgr_->LookupDevice( TF_CHECK_OK(device_mgr_->LookupDevice(
"/job:a/replica:0/task:0/device:CPU:0", &device0_)); "/job:a/replica:0/task:0/device:CPU:0", &device0_));
TF_CHECK_OK(device_mgr_->LookupDevice( TF_CHECK_OK(device_mgr_->LookupDevice(
"/job:a/replica:0/task:0/device:CPU:1", &device1_)); "/job:a/replica:0/task:0/device:CPU:1", &device1_));
Device* device2_ptr = nullptr;
EXPECT_NE(
error::OK,
device_mgr_
->LookupDevice("/job:a/replica:0/task:0/device:CPU:2", &device2_ptr)
.code());
// If no GPU is available, gpu_device_ will remain nullptr. // If no GPU is available, gpu_device_ will remain nullptr.
Status status = device_mgr_->LookupDevice( Status status = device_mgr_->LookupDevice(
"/job:a/replica:0/task:0/device:GPU:0", &gpu_device_); "/job:a/replica:0/task:0/device:GPU:0", &gpu_device_);
@ -301,9 +312,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK(); return Status::OK();
} }
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DynamicDeviceMgr> device_mgr_;
Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.) Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.) Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
std::unique_ptr<Device> device2_;
// Remains as nullptr if no GPU is available. // Remains as nullptr if no GPU is available.
Device* gpu_device_ = nullptr; // Not owned. (Owned by device_mgr_.) Device* gpu_device_ = nullptr; // Not owned. (Owned by device_mgr_.)
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
@ -445,6 +457,28 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
TensorShape({}))); TensorShape({})));
} }
TEST_F(ProcessFunctionLibraryRuntimeTest, InstantiateFunctionOnRemovedDevice) {
std::vector<std::unique_ptr<Device>> devices;
Device* device2_ptr = device2_.get();
devices.emplace_back(std::move(device2_));
TF_CHECK_OK(device_mgr_->AddDevices(std::move(devices)));
Init({test::function::FindDevice()});
std::vector<Device*> remove_devices{device2_ptr};
TF_CHECK_OK(device_mgr_->RemoveDevices(std::move(remove_devices)));
// Since the process FLR device set is not updated yet, it still holds the
// raw pointer to device2. Make sure that function instantion with device2
// will not lead to segfault.
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
FunctionLibraryRuntime::Handle h;
instantiate_opts.target = "/job:a/replica:0/task:0/device:CPU:1";
instantiate_opts.is_multi_device_function = true;
TF_CHECK_OK(Instantiate("FindDevice",
{{"_target", "/job:b/replica:0/task:0/device:CPU:2"}},
instantiate_opts, &h));
}
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) { TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
Init({test::function::FindDevice()}); Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts; FunctionLibraryRuntime::Options opts;