Share ownership of UnboundedWorkQueue
between collective executor and
executor manager. Before this change, the lifetime of the single `UnboundedWorkQueue` that backed collective execution was tied to `CollectiveExecutorMgr`. However, it is possible for the `CollectiveRemoteAccessLocal` instances created by the executor manager to outlive the manager. This could lead to the executors enqueuing work on a queue that had already been destroyed. This change converts instances of `UnboundedWorkQueue` in the collective op implementation to shared pointers. Each `CollectiveExecutor` that is created by the executor manager also shares ownership of the work queue. The work queue's reference is released in two places: in the destructors of `CollectiveExecutorMgr` and `CollectiveRemoteAccessLocal`. PiperOrigin-RevId: 261775938
This commit is contained in:
parent
26824feb8d
commit
17ce384df7
@ -66,12 +66,11 @@ class RecvBufCall : public CancellableCall {
|
||||
|
||||
class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
|
||||
public:
|
||||
CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
|
||||
DeviceResolverInterface* dev_resolver,
|
||||
UnboundedWorkQueue* work_queue,
|
||||
WorkerCacheInterface* worker_cache,
|
||||
int64 step_id,
|
||||
RemoteMemoryManager* remote_memory_manager)
|
||||
CollectiveRemoteAccessDistributed(
|
||||
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue,
|
||||
WorkerCacheInterface* worker_cache, int64 step_id,
|
||||
RemoteMemoryManager* remote_memory_manager)
|
||||
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
|
||||
worker_cache_(worker_cache),
|
||||
remote_memory_manager_(remote_memory_manager) {}
|
||||
@ -154,8 +153,8 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
|
||||
CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
|
||||
CollectiveRemoteAccessDistributed* rma =
|
||||
new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
|
||||
&work_queue_, worker_cache_,
|
||||
step_id, remote_memory_manager_);
|
||||
work_queue_, worker_cache_, step_id,
|
||||
remote_memory_manager_);
|
||||
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
|
||||
&gpu_ring_order_);
|
||||
}
|
||||
|
@ -32,7 +32,8 @@ CollectiveExecutorMgr::CollectiveExecutorMgr(
|
||||
param_resolver_(std::move(param_resolver)),
|
||||
gpu_ring_order_(
|
||||
config.gpu_options().experimental().collective_ring_order()),
|
||||
work_queue_(Env::Default(), "collective_ops") {}
|
||||
work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
|
||||
"collective_ops")) {}
|
||||
|
||||
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
|
||||
for (auto iter : executor_table_) {
|
||||
@ -58,7 +59,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
|
||||
|
||||
CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
|
||||
CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
|
||||
dev_mgr_, dev_resolver_.get(), &work_queue_, step_id);
|
||||
dev_mgr_, dev_resolver_.get(), work_queue_, step_id);
|
||||
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
|
||||
&gpu_ring_order_);
|
||||
}
|
||||
|
@ -65,8 +65,9 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
|
||||
std::unique_ptr<ParamResolverInterface> param_resolver_;
|
||||
string gpu_ring_order_;
|
||||
// Unbounded work queue for scheduling potentially-blocking work during
|
||||
// collective op execution.
|
||||
UnboundedWorkQueue work_queue_;
|
||||
// collective op execution. Ownership is shared between `this` and
|
||||
// `CollectiveRemoteAccessLocal`.
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
|
||||
private:
|
||||
mutex exec_mu_;
|
||||
|
@ -28,14 +28,15 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
|
||||
public:
|
||||
CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
|
||||
DeviceResolverInterface* dev_resolver,
|
||||
UnboundedWorkQueue* work_queue, int64 step_id)
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue,
|
||||
int64 step_id)
|
||||
: dev_mgr_(dev_mgr),
|
||||
dev_resolver_(dev_resolver),
|
||||
work_queue_(work_queue),
|
||||
work_queue_(std::move(work_queue)),
|
||||
buf_rendezvous_(step_id, dev_mgr),
|
||||
step_id_(step_id) {}
|
||||
|
||||
virtual ~CollectiveRemoteAccessLocal() {}
|
||||
~CollectiveRemoteAccessLocal() override = default;
|
||||
|
||||
void StartAbort(const Status& s) override;
|
||||
|
||||
@ -95,7 +96,9 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
|
||||
protected:
|
||||
const DeviceMgr* dev_mgr_; // not owned
|
||||
DeviceResolverInterface* dev_resolver_; // not owned
|
||||
UnboundedWorkQueue* work_queue_; // not owned
|
||||
// Ownership of `work_queue_` is shared between `this` and
|
||||
// `CollectiveExecutorMgr`.
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
BufRendezvous buf_rendezvous_;
|
||||
int64 step_id_;
|
||||
};
|
||||
|
@ -39,7 +39,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
|
||||
const string kTaskName = "/job:localhost/replica:0/task:0";
|
||||
|
||||
CollectiveRemoteAccessLocalTest() {
|
||||
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
ConfigProto cp;
|
||||
SessionOptions options;
|
||||
auto* device_count = options.config.mutable_device_count();
|
||||
@ -51,10 +51,12 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
|
||||
prl_ = absl::make_unique<CollectiveParamResolverLocal>(
|
||||
cp, device_mgr_.get(), drl_.get(), kTaskName);
|
||||
rma_ = absl::make_unique<CollectiveRemoteAccessLocal>(
|
||||
device_mgr_.get(), drl_.get(), work_queue_.get(), kStepId);
|
||||
device_mgr_.get(), drl_.get(), work_queue_, kStepId);
|
||||
}
|
||||
|
||||
std::unique_ptr<UnboundedWorkQueue> work_queue_;
|
||||
~CollectiveRemoteAccessLocalTest() override = default;
|
||||
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
std::unique_ptr<DeviceResolverLocal> drl_;
|
||||
std::unique_ptr<CollectiveParamResolverLocal> prl_;
|
||||
|
@ -138,7 +138,8 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
|
||||
class FailTestRMA : public CollectiveRemoteAccessLocal {
|
||||
public:
|
||||
FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
|
||||
UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
|
||||
int fail_after)
|
||||
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
|
||||
fail_after_(fail_after) {}
|
||||
|
||||
@ -252,9 +253,9 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
gpu_ring_order_ = absl::make_unique<string>();
|
||||
}
|
||||
dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
|
||||
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
|
||||
work_queue_.get(), kStepId, fail_after);
|
||||
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
|
||||
kStepId, fail_after);
|
||||
col_exec_ = new BaseCollectiveExecutor(
|
||||
&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
|
||||
col_params_.name = "test_collective";
|
||||
@ -719,7 +720,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
|
||||
CollectiveExecutor* col_exec_ = nullptr;
|
||||
CollectiveRemoteAccessLocal* rma_;
|
||||
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
|
||||
std::unique_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::vector<DeviceInstance*> instances_;
|
||||
CollectiveParams col_params_;
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
|
||||
|
@ -46,7 +46,8 @@ namespace tensorflow {
|
||||
class FailTestRMA : public CollectiveRemoteAccessLocal {
|
||||
public:
|
||||
FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
|
||||
UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
|
||||
int fail_after)
|
||||
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
|
||||
fail_after_(fail_after) {}
|
||||
|
||||
@ -172,9 +173,9 @@ class RingGathererTest : public ::testing::Test {
|
||||
gpu_ring_order_ = absl::make_unique<string>();
|
||||
}
|
||||
dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
|
||||
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
|
||||
work_queue_.get(), kStepId, fail_after);
|
||||
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
|
||||
kStepId, fail_after);
|
||||
col_exec_ = new BaseCollectiveExecutor(
|
||||
&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
|
||||
col_params_.name = "test_collective";
|
||||
@ -523,7 +524,7 @@ class RingGathererTest : public ::testing::Test {
|
||||
CollectiveExecutor* col_exec_;
|
||||
CollectiveRemoteAccessLocal* rma_;
|
||||
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
|
||||
std::unique_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::vector<DeviceInstance*> instances_;
|
||||
CollectiveParams col_params_;
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
|
||||
|
@ -46,7 +46,8 @@ namespace tensorflow {
|
||||
class FailTestRMA : public CollectiveRemoteAccessLocal {
|
||||
public:
|
||||
FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
|
||||
UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
|
||||
int fail_after)
|
||||
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
|
||||
fail_after_(fail_after) {}
|
||||
|
||||
@ -194,9 +195,9 @@ class RingReducerTest : public ::testing::Test {
|
||||
gpu_ring_order_ = absl::make_unique<string>();
|
||||
}
|
||||
dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
|
||||
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
|
||||
work_queue_.get(), kStepId, fail_after);
|
||||
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
|
||||
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
|
||||
kStepId, fail_after);
|
||||
col_exec_ = new BaseCollectiveExecutor(
|
||||
&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
|
||||
col_params_.name = "test_collective";
|
||||
@ -550,7 +551,7 @@ class RingReducerTest : public ::testing::Test {
|
||||
CollectiveExecutor* col_exec_;
|
||||
CollectiveRemoteAccessLocal* rma_;
|
||||
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
|
||||
std::unique_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::vector<DeviceInstance*> instances_;
|
||||
CollectiveParams col_params_;
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
|
||||
|
@ -25,11 +25,10 @@ class WorkerCacheInterface;
|
||||
// Extend CollectiveRemoteAccessLocal with access to remote peers.
|
||||
class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
|
||||
public:
|
||||
CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
|
||||
DeviceResolverInterface* dev_resolver,
|
||||
UnboundedWorkQueue* work_queue,
|
||||
WorkerCacheInterface* worker_cache,
|
||||
int64 step_id)
|
||||
CollectiveRemoteAccessDistributed(
|
||||
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue,
|
||||
WorkerCacheInterface* worker_cache, int64 step_id)
|
||||
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
|
||||
worker_cache_(worker_cache) {}
|
||||
|
||||
|
@ -170,7 +170,9 @@ class FakeCache : public TestWorkerCache {
|
||||
|
||||
class CollRMADistTest : public ::testing::Test {
|
||||
protected:
|
||||
CollRMADistTest() : work_queue_(Env::Default(), "test") {}
|
||||
CollRMADistTest()
|
||||
: work_queue_(
|
||||
std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}
|
||||
|
||||
~CollRMADistTest() override {
|
||||
for (DeviceMgr* dm : device_mgrs_) {
|
||||
@ -198,7 +200,7 @@ class CollRMADistTest : public ::testing::Test {
|
||||
}
|
||||
// All tests simulate requests from worker 0 to worker 1.
|
||||
rma_.reset(new CollectiveRemoteAccessDistributed(
|
||||
device_mgrs_[0], dev_resolvers_[dev0_worker_name], &work_queue_, &wc_,
|
||||
device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
|
||||
kStepId));
|
||||
|
||||
const int kNumElts = 8;
|
||||
@ -258,7 +260,7 @@ class CollRMADistTest : public ::testing::Test {
|
||||
std::vector<DeviceMgr*> device_mgrs_;
|
||||
std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
|
||||
std::unordered_map<string, std::vector<string>> dev_by_task_;
|
||||
UnboundedWorkQueue work_queue_;
|
||||
std::shared_ptr<UnboundedWorkQueue> work_queue_;
|
||||
std::vector<FakeWorker*> workers_;
|
||||
std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
|
||||
mutex mu_;
|
||||
|
@ -48,7 +48,7 @@ RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
|
||||
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
|
||||
CollectiveRemoteAccessDistributed* rma =
|
||||
new CollectiveRemoteAccessDistributed(
|
||||
dev_mgr_, dev_resolver_.get(), &work_queue_, worker_cache_, step_id);
|
||||
dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
|
||||
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
|
||||
&gpu_ring_order_);
|
||||
}
|
||||
|
@ -2486,9 +2486,6 @@ tf_py_test(
|
||||
":framework_for_generated_wrappers",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
tags = [
|
||||
"no_oss", # TODO(b/138811357): re-enable after fixing flakiness.
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
Loading…
Reference in New Issue
Block a user