Share ownership of UnboundedWorkQueue between collective executor and executor manager.

Before this change, the lifetime of the single `UnboundedWorkQueue` that backed
collective execution was tied to `CollectiveExecutorMgr`.  However, it is
possible for the `CollectiveRemoteAccessLocal`s created by the executor manager
to outlive the manager.  This could lead to the executors enqueuing work on a
queue that was destroyed.

This change converts instances of `UnboundedWorkQueue` in collective op
implementation to shared pointers.  Each `CollectiveExecutor` that is created
by the executor manager also shares ownership of the work queue.  The work
queue's reference count is decremented in two places: in the destructors of
`CollectiveExecutorMgr` and `CollectiveRemoteAccessLocal`.

PiperOrigin-RevId: 261775938
This commit is contained in:
Ayush Dubey 2019-08-05 14:54:41 -07:00 committed by TensorFlower Gardener
parent 26824feb8d
commit 17ce384df7
12 changed files with 53 additions and 46 deletions

View File

@ -66,12 +66,11 @@ class RecvBufCall : public CancellableCall {
class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
public:
CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver,
UnboundedWorkQueue* work_queue,
WorkerCacheInterface* worker_cache,
int64 step_id,
RemoteMemoryManager* remote_memory_manager)
CollectiveRemoteAccessDistributed(
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
std::shared_ptr<UnboundedWorkQueue> work_queue,
WorkerCacheInterface* worker_cache, int64 step_id,
RemoteMemoryManager* remote_memory_manager)
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
worker_cache_(worker_cache),
remote_memory_manager_(remote_memory_manager) {}
@ -154,8 +153,8 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
CollectiveRemoteAccessDistributed* rma =
new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
&work_queue_, worker_cache_,
step_id, remote_memory_manager_);
work_queue_, worker_cache_, step_id,
remote_memory_manager_);
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
&gpu_ring_order_);
}

View File

@ -32,7 +32,8 @@ CollectiveExecutorMgr::CollectiveExecutorMgr(
param_resolver_(std::move(param_resolver)),
gpu_ring_order_(
config.gpu_options().experimental().collective_ring_order()),
work_queue_(Env::Default(), "collective_ops") {}
work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
"collective_ops")) {}
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
for (auto iter : executor_table_) {
@ -58,7 +59,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
dev_mgr_, dev_resolver_.get(), &work_queue_, step_id);
dev_mgr_, dev_resolver_.get(), work_queue_, step_id);
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
&gpu_ring_order_);
}

View File

@ -65,8 +65,9 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
std::unique_ptr<ParamResolverInterface> param_resolver_;
string gpu_ring_order_;
// Unbounded work queue for scheduling potentially-blocking work during
// collective op execution.
UnboundedWorkQueue work_queue_;
// collective op execution. Ownership is shared between `this` and
// `CollectiveRemoteAccessLocal`.
std::shared_ptr<UnboundedWorkQueue> work_queue_;
private:
mutex exec_mu_;

View File

@ -28,14 +28,15 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
public:
CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver,
UnboundedWorkQueue* work_queue, int64 step_id)
std::shared_ptr<UnboundedWorkQueue> work_queue,
int64 step_id)
: dev_mgr_(dev_mgr),
dev_resolver_(dev_resolver),
work_queue_(work_queue),
work_queue_(std::move(work_queue)),
buf_rendezvous_(step_id, dev_mgr),
step_id_(step_id) {}
virtual ~CollectiveRemoteAccessLocal() {}
~CollectiveRemoteAccessLocal() override = default;
void StartAbort(const Status& s) override;
@ -95,7 +96,9 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
protected:
const DeviceMgr* dev_mgr_; // not owned
DeviceResolverInterface* dev_resolver_; // not owned
UnboundedWorkQueue* work_queue_; // not owned
// Ownership of `work_queue_` is shared between `this` and
// `CollectiveExecutorMgr`.
std::shared_ptr<UnboundedWorkQueue> work_queue_;
BufRendezvous buf_rendezvous_;
int64 step_id_;
};

View File

@ -39,7 +39,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
const string kTaskName = "/job:localhost/replica:0/task:0";
CollectiveRemoteAccessLocalTest() {
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
ConfigProto cp;
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
@ -51,10 +51,12 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
prl_ = absl::make_unique<CollectiveParamResolverLocal>(
cp, device_mgr_.get(), drl_.get(), kTaskName);
rma_ = absl::make_unique<CollectiveRemoteAccessLocal>(
device_mgr_.get(), drl_.get(), work_queue_.get(), kStepId);
device_mgr_.get(), drl_.get(), work_queue_, kStepId);
}
std::unique_ptr<UnboundedWorkQueue> work_queue_;
~CollectiveRemoteAccessLocalTest() override = default;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_;

View File

@ -138,7 +138,8 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
class FailTestRMA : public CollectiveRemoteAccessLocal {
public:
FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
int fail_after)
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
fail_after_(fail_after) {}
@ -252,9 +253,9 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
gpu_ring_order_ = absl::make_unique<string>();
}
dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
work_queue_.get(), kStepId, fail_after);
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
kStepId, fail_after);
col_exec_ = new BaseCollectiveExecutor(
&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
col_params_.name = "test_collective";
@ -719,7 +720,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
CollectiveExecutor* col_exec_ = nullptr;
CollectiveRemoteAccessLocal* rma_;
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::unique_ptr<UnboundedWorkQueue> work_queue_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;

View File

@ -46,7 +46,8 @@ namespace tensorflow {
class FailTestRMA : public CollectiveRemoteAccessLocal {
public:
FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
int fail_after)
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
fail_after_(fail_after) {}
@ -172,9 +173,9 @@ class RingGathererTest : public ::testing::Test {
gpu_ring_order_ = absl::make_unique<string>();
}
dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
work_queue_.get(), kStepId, fail_after);
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
kStepId, fail_after);
col_exec_ = new BaseCollectiveExecutor(
&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
col_params_.name = "test_collective";
@ -523,7 +524,7 @@ class RingGathererTest : public ::testing::Test {
CollectiveExecutor* col_exec_;
CollectiveRemoteAccessLocal* rma_;
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::unique_ptr<UnboundedWorkQueue> work_queue_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;

View File

@ -46,7 +46,8 @@ namespace tensorflow {
class FailTestRMA : public CollectiveRemoteAccessLocal {
public:
FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
int fail_after)
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
fail_after_(fail_after) {}
@ -194,9 +195,9 @@ class RingReducerTest : public ::testing::Test {
gpu_ring_order_ = absl::make_unique<string>();
}
dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
work_queue_.get(), kStepId, fail_after);
work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
kStepId, fail_after);
col_exec_ = new BaseCollectiveExecutor(
&col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
col_params_.name = "test_collective";
@ -550,7 +551,7 @@ class RingReducerTest : public ::testing::Test {
CollectiveExecutor* col_exec_;
CollectiveRemoteAccessLocal* rma_;
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::unique_ptr<UnboundedWorkQueue> work_queue_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;

View File

@ -25,11 +25,10 @@ class WorkerCacheInterface;
// Extend CollectiveRemoteAccessLocal with access to remote peers.
class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
public:
CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver,
UnboundedWorkQueue* work_queue,
WorkerCacheInterface* worker_cache,
int64 step_id)
CollectiveRemoteAccessDistributed(
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
std::shared_ptr<UnboundedWorkQueue> work_queue,
WorkerCacheInterface* worker_cache, int64 step_id)
: CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
worker_cache_(worker_cache) {}

View File

@ -170,7 +170,9 @@ class FakeCache : public TestWorkerCache {
class CollRMADistTest : public ::testing::Test {
protected:
CollRMADistTest() : work_queue_(Env::Default(), "test") {}
CollRMADistTest()
: work_queue_(
std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}
~CollRMADistTest() override {
for (DeviceMgr* dm : device_mgrs_) {
@ -198,7 +200,7 @@ class CollRMADistTest : public ::testing::Test {
}
// All tests simulate requests from worker 0 to worker 1.
rma_.reset(new CollectiveRemoteAccessDistributed(
device_mgrs_[0], dev_resolvers_[dev0_worker_name], &work_queue_, &wc_,
device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
kStepId));
const int kNumElts = 8;
@ -258,7 +260,7 @@ class CollRMADistTest : public ::testing::Test {
std::vector<DeviceMgr*> device_mgrs_;
std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
std::unordered_map<string, std::vector<string>> dev_by_task_;
UnboundedWorkQueue work_queue_;
std::shared_ptr<UnboundedWorkQueue> work_queue_;
std::vector<FakeWorker*> workers_;
std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
mutex mu_;

View File

@ -48,7 +48,7 @@ RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
CollectiveRemoteAccessDistributed* rma =
new CollectiveRemoteAccessDistributed(
dev_mgr_, dev_resolver_.get(), &work_queue_, worker_cache_, step_id);
dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
&gpu_ring_order_);
}

View File

@ -2486,9 +2486,6 @@ tf_py_test(
":framework_for_generated_wrappers",
"//third_party/py/numpy",
],
tags = [
"no_oss", # TODO(b/138811357): re-enable after fixing flakiness.
],
)
cuda_py_test(