Avoid segmentation fault when instantiating functions on removed devices.
PiperOrigin-RevId: 315958679 Change-Id: Ic5c58926883368fb14bed8ef0b9c51dc97cc6768
This commit is contained in:
parent
d8881eb71d
commit
d2a0ab1c5b
|
@ -111,6 +111,9 @@ class StaticDeviceMgr : public DeviceMgr {
|
|||
TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
|
||||
};
|
||||
|
||||
// Number of slots in the buffer that temporarily holds removed devices.
|
||||
static const size_t kStaleDeviceBufferSize = 8192;
|
||||
|
||||
// Represents a dynamic set of devices
|
||||
class DynamicDeviceMgr : public DeviceMgr {
|
||||
public:
|
||||
|
@ -157,6 +160,28 @@ class DynamicDeviceMgr : public DeviceMgr {
|
|||
|
||||
mutable Device* cpu_device_ TF_GUARDED_BY(devices_mu_);
|
||||
|
||||
class DeviceCircularBuffer {
|
||||
public:
|
||||
DeviceCircularBuffer() : index_(0) {
|
||||
devices_.resize(kStaleDeviceBufferSize);
|
||||
}
|
||||
void add(std::unique_ptr<Device> device) {
|
||||
devices_[index_] = std::move(device);
|
||||
index_ = (index_ + 1) % kStaleDeviceBufferSize;
|
||||
}
|
||||
|
||||
private:
|
||||
int index_;
|
||||
std::vector<std::unique_ptr<Device>> devices_;
|
||||
};
|
||||
|
||||
// Buffer to temporarily store the removed devices. Raw device pointers are
|
||||
// accessible to DeviceSet, and if the function instantiation process directly
|
||||
// accesses fields through the device set, the underlying device object must
|
||||
// still be available to avoid segmentation fault. We keep the devices in this
|
||||
// buffer only for that purpose.
|
||||
DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr);
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -178,6 +178,7 @@ Status DynamicDeviceMgr::RemoveDevices(std::vector<Device*> devices) {
|
|||
}
|
||||
device_type_counts_[d->device_type()]--;
|
||||
device_incarnation_set_.erase(d->attributes().incarnation());
|
||||
stale_devices_.add(std::move(it->second));
|
||||
dynamic_devices_.erase(it);
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -18,7 +18,11 @@ limitations under the License.
|
|||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
|
@ -26,16 +30,36 @@ namespace tensorflow {
|
|||
namespace {
|
||||
|
||||
// Return a fake device with the specified type and name.
|
||||
static Device* CreateDevice(const char* type, const char* name) {
|
||||
static Device* CreateDevice(const char* type, const char* name,
                            Notification* n = nullptr) {
  // Returns a heap-allocated fake device as a raw pointer. Its attributes
  // carry `name`, `type` and a random non-zero incarnation. When `n` is
  // non-null, the returned device calls n->Notify() from its destructor.

  // Minimal Device: Sync() always reports OK and there is no allocator.
  class FakeDevice : public Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attributes)
        : Device(nullptr, attributes) {}
    Status Sync() override { return Status::OK(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
  };

  // FakeDevice that signals the given Notification when it is destroyed.
  class FakeDeviceWithDestructorNotification : public FakeDevice {
   public:
    FakeDeviceWithDestructorNotification(const DeviceAttributes& attributes,
                                         Notification* notification)
        : FakeDevice(attributes), notification_(notification) {}
    ~FakeDeviceWithDestructorNotification() override {
      notification_->Notify();
    }

   private:
    Notification* notification_;
  };

  DeviceAttributes attributes;
  attributes.set_device_type(type);
  attributes.set_name(name);

  // Keep drawing until the random incarnation is non-zero.
  auto incarnation = random::New64();
  while (incarnation == 0) {
    incarnation = random::New64();
  }
  attributes.set_incarnation(incarnation);

  if (n == nullptr) {
    return new FakeDevice(attributes);
  }
  return new FakeDeviceWithDestructorNotification(attributes, n);
}
|
||||
|
||||
|
@ -57,6 +81,7 @@ TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgr) {
|
|||
std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0"));
|
||||
std::unique_ptr<Device> d1(CreateDevice("CPU", "/device:CPU:1"));
|
||||
Device* d1_ptr = d1.get();
|
||||
const int64 d1_incarnation = d1->attributes().incarnation();
|
||||
|
||||
auto dm = MakeUnique<DynamicDeviceMgr>();
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
|
@ -68,6 +93,38 @@ TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgr) {
|
|||
std::vector<Device*> removed_devices{d1_ptr};
|
||||
TF_CHECK_OK(dm->RemoveDevices(removed_devices));
|
||||
EXPECT_EQ(dm->ListDevices().size(), 1);
|
||||
EXPECT_FALSE(dm->ContainsDevice(d1_incarnation));
|
||||
|
||||
// Device remains accessible through the raw pointer shortly after removal.
|
||||
EXPECT_EQ(d1_ptr->name(), "/device:CPU:1");
|
||||
EXPECT_EQ(d1_ptr->device_type(), "CPU");
|
||||
}
|
||||
|
||||
// Checks that a removed device is eventually destroyed: it is parked in the
// manager's stale-device buffer and freed once enough later removals have
// cycled through the buffer to overwrite its slot.
TEST(DynamicDeviceMgrTest, RemoveDeviceFromMgrBuffer) {
  // Create a device whose destructor will send a notification.
  Notification n;
  std::unique_ptr<Device> d0(CreateDevice("CPU", "/device:CPU:0", &n));
  Device* d0_ptr = d0.get();
  std::vector<std::unique_ptr<Device>> added_devices;
  added_devices.emplace_back(std::move(d0));
  auto dm = MakeUnique<DynamicDeviceMgr>();
  TF_CHECK_OK(dm->AddDevices(std::move(added_devices)));
  std::vector<Device*> removed_devices{d0_ptr};
  TF_CHECK_OK(dm->RemoveDevices(removed_devices));

  // Repeatedly add and remove devices to fill up the stale devices buffer.
  // The counter is size_t to match kStaleDeviceBufferSize and avoid a
  // signed/unsigned comparison. Fresh vectors are built on every iteration so
  // no moved-from container is reused.
  for (size_t i = 0; i < kStaleDeviceBufferSize; ++i) {
    std::unique_ptr<Device> d(CreateDevice("CPU", "/device:CPU:0"));
    Device* d_ptr = d.get();
    std::vector<std::unique_ptr<Device>> to_add;
    to_add.emplace_back(std::move(d));
    TF_CHECK_OK(dm->AddDevices(std::move(to_add)));
    std::vector<Device*> to_remove{d_ptr};
    TF_CHECK_OK(dm->RemoveDevices(to_remove));
  }
  // Verify that d0 destructor is called after the buffer is full.
  n.WaitForNotification();
}
|
||||
|
||||
TEST(DynamicDeviceMgrTest, RemoveDeviceByNameFromMgr) {
|
||||
|
|
|
@ -871,12 +871,12 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
|||
Status* status = &instantiate_status[i];
|
||||
string unique_name = name_generator.GetName();
|
||||
ComponentFunctionData* comp_data = &data->glue_[pair.first];
|
||||
runner([this, &pair, comp_data, unique_name, data_lib_def, &control_ret,
|
||||
&options, status, &counter, &data] {
|
||||
runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
|
||||
&control_ret, &options, status, &counter, &data] {
|
||||
const string& target = pair.first;
|
||||
|
||||
const string& device_type =
|
||||
device_set()->FindDeviceByName(target)->device_type();
|
||||
dev_set->FindDeviceByName(target)->device_type();
|
||||
Graph* subgraph = pair.second.get();
|
||||
|
||||
status->Update(UpdateArgAndRetvalMetadata(
|
||||
|
|
|
@ -95,15 +95,26 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||
ProcessFunctionLibraryRuntimeTest() {
|
||||
SessionOptions options;
|
||||
auto* device_count = options.config.mutable_device_count();
|
||||
device_count->insert({"CPU", 2});
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
device_count->insert({"CPU", 3});
|
||||
std::vector<std::unique_ptr<Device>> created_devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
|
||||
&devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
&created_devices));
|
||||
// Do not add CPU:2 to device manager. Used for removed device testing.
|
||||
device2_ = std::move(created_devices[2]);
|
||||
created_devices.erase(created_devices.begin() + 2);
|
||||
|
||||
device_mgr_ = std::make_unique<DynamicDeviceMgr>();
|
||||
TF_CHECK_OK(device_mgr_->AddDevices(std::move(created_devices)));
|
||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||
"/job:a/replica:0/task:0/device:CPU:0", &device0_));
|
||||
TF_CHECK_OK(device_mgr_->LookupDevice(
|
||||
"/job:a/replica:0/task:0/device:CPU:1", &device1_));
|
||||
Device* device2_ptr = nullptr;
|
||||
EXPECT_NE(
|
||||
error::OK,
|
||||
device_mgr_
|
||||
->LookupDevice("/job:a/replica:0/task:0/device:CPU:2", &device2_ptr)
|
||||
.code());
|
||||
// If no GPU is available, gpu_device_ will remain nullptr.
|
||||
Status status = device_mgr_->LookupDevice(
|
||||
"/job:a/replica:0/task:0/device:GPU:0", &gpu_device_);
|
||||
|
@ -301,9 +312,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
std::unique_ptr<DynamicDeviceMgr> device_mgr_;
|
||||
Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
||||
Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
||||
std::unique_ptr<Device> device2_;
|
||||
// Remains as nullptr if no GPU is available.
|
||||
Device* gpu_device_ = nullptr; // Not owned. (Owned by device_mgr_.)
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||
|
@ -445,6 +457,28 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
|||
TensorShape({})));
|
||||
}
|
||||
|
||||
// Instantiates a multi-device function that targets a device which has
// already been removed from the device manager, and expects success.
TEST_F(ProcessFunctionLibraryRuntimeTest, InstantiateFunctionOnRemovedDevice) {
  // Hand CPU:2 (held back by the fixture) over to the device manager.
  Device* device2_ptr = device2_.get();
  std::vector<std::unique_ptr<Device>> to_add;
  to_add.emplace_back(std::move(device2_));
  TF_CHECK_OK(device_mgr_->AddDevices(std::move(to_add)));

  Init({test::function::FindDevice()});

  // Remove CPU:2 again after Init().
  std::vector<Device*> to_remove{device2_ptr};
  TF_CHECK_OK(device_mgr_->RemoveDevices(std::move(to_remove)));

  // Since the process FLR device set is not updated yet, it still holds the
  // raw pointer to device2. Make sure that function instantiation with
  // device2 will not lead to segfault.
  FunctionLibraryRuntime::Handle handle;
  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.is_multi_device_function = true;
  opts.target = "/job:a/replica:0/task:0/device:CPU:1";
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:2"}},
                          opts, &handle));
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
|
||||
Init({test::function::FindDevice()});
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
|
|
Loading…
Reference in New Issue