From 17ce384df70e9fb69d881d7d60f1c802156a25bd Mon Sep 17 00:00:00 2001
From: Ayush Dubey <ayushd@google.com>
Date: Mon, 5 Aug 2019 14:54:41 -0700
Subject: [PATCH] Share ownership of `UnboundedWorkQueue` between collective
 executor and executor manager.

Before this change, the lifetime of the single `UnboundedWorkQueue` that backed
collective execution was tied to `CollectiveExecutorMgr`.  However, it is
possible for the `CollectiveRemoteAccessLocal`s created by the executor manager
to outlive the manager.  This could lead to the executors enqueuing work on a
queue that was destroyed.

This change converts instances of `UnboundedWorkQueue` in the collective op
implementation to shared pointers.  Each `CollectiveExecutor` that is created
by the executor manager also shares ownership of the work queue.  The last
reference to the work queue is released in one of two places: the destructor of
`CollectiveExecutorMgr` or the destructor of `CollectiveRemoteAccessLocal`,
whichever runs later.

PiperOrigin-RevId: 261775938
---
 .../contrib/gdr/gdr_collective_executor_mgr.cc    | 15 +++++++--------
 .../common_runtime/collective_executor_mgr.cc     |  5 +++--
 .../core/common_runtime/collective_executor_mgr.h |  5 +++--
 .../core/common_runtime/collective_rma_local.h    | 11 +++++++----
 .../common_runtime/collective_rma_local_test.cc   |  8 +++++---
 .../hierarchical_tree_broadcaster_test.cc         | 11 ++++++-----
 .../core/common_runtime/ring_gatherer_test.cc     | 11 ++++++-----
 .../core/common_runtime/ring_reducer_test.cc      | 11 ++++++-----
 .../collective_rma_distributed.h                  |  9 ++++-----
 .../collective_rma_distributed_test.cc            |  8 +++++---
 .../rpc_collective_executor_mgr.cc                |  2 +-
 tensorflow/python/BUILD                           |  3 ---
 12 files changed, 53 insertions(+), 46 deletions(-)

diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
index 619c5bb0294..4988ce6d2fe 100644
--- a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc
@@ -66,12 +66,11 @@ class RecvBufCall : public CancellableCall {
 
 class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
  public:
-  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
-                                    DeviceResolverInterface* dev_resolver,
-                                    UnboundedWorkQueue* work_queue,
-                                    WorkerCacheInterface* worker_cache,
-                                    int64 step_id,
-                                    RemoteMemoryManager* remote_memory_manager)
+  CollectiveRemoteAccessDistributed(
+      const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+      std::shared_ptr<UnboundedWorkQueue> work_queue,
+      WorkerCacheInterface* worker_cache, int64 step_id,
+      RemoteMemoryManager* remote_memory_manager)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         worker_cache_(worker_cache),
         remote_memory_manager_(remote_memory_manager) {}
@@ -154,8 +153,8 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
 CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessDistributed* rma =
       new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
-                                            &work_queue_, worker_cache_,
-                                            step_id, remote_memory_manager_);
+                                            work_queue_, worker_cache_, step_id,
+                                            remote_memory_manager_);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_);
 }
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index 105c400b6e3..e9e0082195d 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -32,7 +32,8 @@ CollectiveExecutorMgr::CollectiveExecutorMgr(
       param_resolver_(std::move(param_resolver)),
       gpu_ring_order_(
           config.gpu_options().experimental().collective_ring_order()),
-      work_queue_(Env::Default(), "collective_ops") {}
+      work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
+                                                       "collective_ops")) {}
 
 CollectiveExecutorMgr::~CollectiveExecutorMgr() {
   for (auto iter : executor_table_) {
@@ -58,7 +59,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
 
 CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
-      dev_mgr_, dev_resolver_.get(), &work_queue_, step_id);
+      dev_mgr_, dev_resolver_.get(), work_queue_, step_id);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_);
 }
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index ae5a67dbe7b..d4cef14c1d2 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -65,8 +65,9 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
   std::unique_ptr<ParamResolverInterface> param_resolver_;
   string gpu_ring_order_;
   // Unbounded work queue for scheduling potentially-blocking work during
-  // collective op execution.
-  UnboundedWorkQueue work_queue_;
+  // collective op execution.  Ownership is shared between `this` and
+  // `CollectiveRemoteAccessLocal`.
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
 
  private:
   mutex exec_mu_;
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 073c38e7bba..b5d02f4d2bd 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -28,14 +28,15 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
  public:
   CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr,
                               DeviceResolverInterface* dev_resolver,
-                              UnboundedWorkQueue* work_queue, int64 step_id)
+                              std::shared_ptr<UnboundedWorkQueue> work_queue,
+                              int64 step_id)
       : dev_mgr_(dev_mgr),
         dev_resolver_(dev_resolver),
-        work_queue_(work_queue),
+        work_queue_(std::move(work_queue)),
         buf_rendezvous_(step_id, dev_mgr),
         step_id_(step_id) {}
 
-  virtual ~CollectiveRemoteAccessLocal() {}
+  ~CollectiveRemoteAccessLocal() override = default;
 
   void StartAbort(const Status& s) override;
 
@@ -95,7 +96,9 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
  protected:
   const DeviceMgr* dev_mgr_;               // not owned
   DeviceResolverInterface* dev_resolver_;  // not owned
-  UnboundedWorkQueue* work_queue_;         // not owned
+  // Ownership of `work_queue_` is shared between `this` and
+  // `CollectiveExecutorMgr`.
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   BufRendezvous buf_rendezvous_;
   int64 step_id_;
 };
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index 57a497f9563..6024359643b 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -39,7 +39,7 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
   const string kTaskName = "/job:localhost/replica:0/task:0";
 
   CollectiveRemoteAccessLocalTest() {
-    work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
     ConfigProto cp;
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
@@ -51,10 +51,12 @@ class CollectiveRemoteAccessLocalTest : public ::testing::Test {
     prl_ = absl::make_unique<CollectiveParamResolverLocal>(
         cp, device_mgr_.get(), drl_.get(), kTaskName);
     rma_ = absl::make_unique<CollectiveRemoteAccessLocal>(
-        device_mgr_.get(), drl_.get(), work_queue_.get(), kStepId);
+        device_mgr_.get(), drl_.get(), work_queue_, kStepId);
   }
 
-  std::unique_ptr<UnboundedWorkQueue> work_queue_;
+  ~CollectiveRemoteAccessLocalTest() override = default;
+
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<DeviceResolverLocal> drl_;
   std::unique_ptr<CollectiveParamResolverLocal> prl_;
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 9253488b2c9..c00645a3ec3 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -138,7 +138,8 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
 class FailTestRMA : public CollectiveRemoteAccessLocal {
  public:
   FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
-              UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
+              std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
+              int fail_after)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         fail_after_(fail_after) {}
 
@@ -252,9 +253,9 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
       gpu_ring_order_ = absl::make_unique<string>();
     }
     dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
-    work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
-    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
-                           work_queue_.get(), kStepId, fail_after);
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
+                           kStepId, fail_after);
     col_exec_ = new BaseCollectiveExecutor(
         &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
     col_params_.name = "test_collective";
@@ -719,7 +720,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
   CollectiveExecutor* col_exec_ = nullptr;
   CollectiveRemoteAccessLocal* rma_;
   std::unique_ptr<DeviceResolverLocal> dev_resolver_;
-  std::unique_ptr<UnboundedWorkQueue> work_queue_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<DeviceInstance*> instances_;
   CollectiveParams col_params_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index f6cf8146ddd..a5648684906 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -46,7 +46,8 @@ namespace tensorflow {
 class FailTestRMA : public CollectiveRemoteAccessLocal {
  public:
   FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
-              UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
+              std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
+              int fail_after)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         fail_after_(fail_after) {}
 
@@ -172,9 +173,9 @@ class RingGathererTest : public ::testing::Test {
       gpu_ring_order_ = absl::make_unique<string>();
     }
     dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
-    work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
-    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
-                           work_queue_.get(), kStepId, fail_after);
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
+                           kStepId, fail_after);
     col_exec_ = new BaseCollectiveExecutor(
         &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
     col_params_.name = "test_collective";
@@ -523,7 +524,7 @@ class RingGathererTest : public ::testing::Test {
   CollectiveExecutor* col_exec_;
   CollectiveRemoteAccessLocal* rma_;
   std::unique_ptr<DeviceResolverLocal> dev_resolver_;
-  std::unique_ptr<UnboundedWorkQueue> work_queue_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<DeviceInstance*> instances_;
   CollectiveParams col_params_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 39a7b63ce93..6141d332dd0 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -46,7 +46,8 @@ namespace tensorflow {
 class FailTestRMA : public CollectiveRemoteAccessLocal {
  public:
   FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
-              UnboundedWorkQueue* work_queue, int64 step_id, int fail_after)
+              std::shared_ptr<UnboundedWorkQueue> work_queue, int64 step_id,
+              int fail_after)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         fail_after_(fail_after) {}
 
@@ -194,9 +195,9 @@ class RingReducerTest : public ::testing::Test {
       gpu_ring_order_ = absl::make_unique<string>();
     }
     dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
-    work_queue_ = absl::make_unique<UnboundedWorkQueue>(Env::Default(), "test");
-    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(),
-                           work_queue_.get(), kStepId, fail_after);
+    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
+    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), work_queue_,
+                           kStepId, fail_after);
     col_exec_ = new BaseCollectiveExecutor(
         &col_exec_mgr_, rma_, kStepId, dev_mgr_.get(), gpu_ring_order_.get());
     col_params_.name = "test_collective";
@@ -550,7 +551,7 @@ class RingReducerTest : public ::testing::Test {
   CollectiveExecutor* col_exec_;
   CollectiveRemoteAccessLocal* rma_;
   std::unique_ptr<DeviceResolverLocal> dev_resolver_;
-  std::unique_ptr<UnboundedWorkQueue> work_queue_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<DeviceInstance*> instances_;
   CollectiveParams col_params_;
   std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index 81d675a58f0..7d8fcc615cb 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -25,11 +25,10 @@ class WorkerCacheInterface;
 // Extend CollectiveRemoteAccessLocal with access to remote peers.
 class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
  public:
-  CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
-                                    DeviceResolverInterface* dev_resolver,
-                                    UnboundedWorkQueue* work_queue,
-                                    WorkerCacheInterface* worker_cache,
-                                    int64 step_id)
+  CollectiveRemoteAccessDistributed(
+      const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
+      std::shared_ptr<UnboundedWorkQueue> work_queue,
+      WorkerCacheInterface* worker_cache, int64 step_id)
       : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id),
         worker_cache_(worker_cache) {}
 
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 99e9d7f0492..d55465099b5 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -170,7 +170,9 @@ class FakeCache : public TestWorkerCache {
 
 class CollRMADistTest : public ::testing::Test {
  protected:
-  CollRMADistTest() : work_queue_(Env::Default(), "test") {}
+  CollRMADistTest()
+      : work_queue_(
+            std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}
 
   ~CollRMADistTest() override {
     for (DeviceMgr* dm : device_mgrs_) {
@@ -198,7 +200,7 @@ class CollRMADistTest : public ::testing::Test {
     }
     // All tests simulate requests from worker 0 to worker 1.
     rma_.reset(new CollectiveRemoteAccessDistributed(
-        device_mgrs_[0], dev_resolvers_[dev0_worker_name], &work_queue_, &wc_,
+        device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
         kStepId));
 
     const int kNumElts = 8;
@@ -258,7 +260,7 @@ class CollRMADistTest : public ::testing::Test {
   std::vector<DeviceMgr*> device_mgrs_;
   std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
   std::unordered_map<string, std::vector<string>> dev_by_task_;
-  UnboundedWorkQueue work_queue_;
+  std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<FakeWorker*> workers_;
   std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
   mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 61c8f477e03..0c3ef6ab075 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -48,7 +48,7 @@ RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
 CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
   CollectiveRemoteAccessDistributed* rma =
       new CollectiveRemoteAccessDistributed(
-          dev_mgr_, dev_resolver_.get(), &work_queue_, worker_cache_, step_id);
+          dev_mgr_, dev_resolver_.get(), work_queue_, worker_cache_, step_id);
   return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                     &gpu_ring_order_);
 }
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0d0da3c3c7a..a8050e7afb7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2486,9 +2486,6 @@ tf_py_test(
         ":framework_for_generated_wrappers",
         "//third_party/py/numpy",
     ],
-    tags = [
-        "no_oss",  # TODO(b/138811357): re-enable after fixing flakiness.
-    ],
 )
 
 cuda_py_test(