From c2b5bebd70ca3cc69411ce0385ee4d42d887f795 Mon Sep 17 00:00:00 2001
From: Ran Chen <crccw@google.com>
Date: Wed, 21 Oct 2020 12:47:34 -0700
Subject: [PATCH] Set a timeout to check health RPC

fail_fast is ignored by default in OSS, so we need to set timeout to the RPC to avoid it taking forever if a remote worker is down.

PiperOrigin-RevId: 338320405
Change-Id: I154e3ebbcd3a224a0caa3a39a0a7f9927e3ea220
---
 tensorflow/c/c_api_experimental.cc               |  8 ++++----
 tensorflow/c/c_api_experimental.h                |  6 +++---
 .../core/common_runtime/collective_rma_local.cc  |  1 +
 .../core/common_runtime/collective_rma_local.h   |  2 +-
 .../common_runtime/collective_rma_local_test.cc  |  9 +++++----
 tensorflow/core/distributed_runtime/BUILD        |  1 +
 ...collective_param_resolver_distributed_test.cc |  2 +-
 .../collective_rma_distributed.cc                | 15 ++++++++++-----
 .../collective_rma_distributed.h                 |  2 +-
 .../collective_rma_distributed_test.cc           | 12 ++++++------
 .../core/distributed_runtime/remote_device.cc    |  3 ++-
 tensorflow/core/distributed_runtime/rpc/BUILD    |  1 +
 .../rpc/grpc_remote_worker.cc                    |  5 +++--
 .../core/distributed_runtime/rpc/grpc_state.h    |  3 ++-
 tensorflow/core/distributed_runtime/test_utils.h |  2 +-
 tensorflow/core/distributed_runtime/worker.cc    |  2 +-
 tensorflow/core/distributed_runtime/worker.h     |  2 +-
 .../core/distributed_runtime/worker_interface.h  |  5 +++--
 tensorflow/core/framework/collective.h           |  2 +-
 .../distribute/collective_all_reduce_strategy.py |  9 ++++++---
 .../python/distribute/integration_test/BUILD     |  1 -
 .../integration_test/mwms_peer_failure_test.py   |  2 ++
 tensorflow/python/eager/context.py               |  6 ++++--
 .../collective_ops_multi_worker_test.py          | 16 ++++++++--------
 tensorflow/python/tfe_wrapper.cc                 |  4 ++--
 25 files changed, 70 insertions(+), 51 deletions(-)

diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 81fb9d1a2b8..0d188aa5ee0 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status) {
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status) {
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
   tensorflow::Notification done;
   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
         status->status = s;
         done.Notify();
       });
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index c9c74f4e874..90e074d232f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 // Checks the health of collective ops peers. Explicit health check is needed in
 // multi worker collective ops to detect failures in the cluster.  If a peer is
 // down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status);
 
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index 64ec7f03cdb..44175a042a7 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -107,6 +107,7 @@ void CollectiveRemoteAccessLocal::PostToPeer(
 }
 
 void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  int64 timeout_in_ms,
                                                   const StatusCallback& done) {
   // Assume local devices are always healthy.
   done(errors::Internal(
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index dbb10f41a2d..fb4ddf178e5 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -55,7 +55,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
                   CancellationManager* cancellation_manager,
                   const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index be15e0eeff1..30f6e372606 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -156,10 +156,11 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
 TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
   Status status;
   Notification done;
-  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
-    status = s;
-    done.Notify();
-  });
+  rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0,
+                        [&status, &done](const Status& s) {
+                          status = s;
+                          done.Notify();
+                        });
   done.WaitForNotification();
   EXPECT_TRUE(errors::IsInternal(status));
 }
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index d3030f4ca0b..a8681af3f7f 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -531,6 +531,7 @@ cc_library(
     srcs = ["collective_rma_distributed.cc"],
     hdrs = ["collective_rma_distributed.h"],
     deps = [
+        ":call_options",
         ":cancellable_call",
         ":request_id",
         ":worker_cache",
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index f08f7a3275d..1c62b17fe54 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface {
              CollectiveParamResolverDistributed* cpres)
       : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     std::vector<DeviceAttributes> dev_attr;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index f0e30b2cbe5..0f35af48e61 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -179,7 +180,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
-    const string& peer_task, const StatusCallback& done) {
+    const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) {
   if (peer_task == task_name_) {
     // Fast path if the peer is the worker itself.
     done(Status::OK());
@@ -196,13 +197,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
                                  "valid form is /job:xxx/replica:0/task:N"));
     return;
   }
+  auto opts = new CallOptions();
+  opts->SetTimeout(timeout_in_ms);
   auto req = new GetStatusRequest();
   auto resp = new GetStatusResponse();
-  // We're not using Cancellable call because GetStatusAsync doesn't support
-  // cancellation yet.
+  // Note that fail_fast is not always respected, so we set a timeout as well.
+  // We're not using CancellableCall since check health shouldn't need to be
+  // cancelled.
   wi->GetStatusAsync(
-      req, resp, /*fail_fast*/ true,
-      [this, req, resp, wi, peer_task, done](Status s) {
+      opts, req, resp, /*fail_fast*/ true,
+      [this, opts, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
           s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
@@ -227,6 +231,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
           // first collective.
           s = Status::OK();
         }
+        delete opts;
         delete req;
         delete resp;
         worker_cache_->ReleaseWorker(peer_task, wi);
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index 58065284f72..f17c6360cf7 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -45,7 +45,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   void StartAbort(const Status& s) override;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 8ccdde064de..74282beff1f 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface {
   // worker is supposed to have.
   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     if (is_failed_) {
@@ -459,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -472,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -488,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -505,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -520,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index 05a9072894e..bb9b074858a 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
       }
     }
   };
-  wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb);
+  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
+                     /*fail_fast=*/false, cb);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 14a358e8ac2..97dc8257750 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -94,6 +94,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/distributed_runtime:tensor_coding",
         "//tensorflow/core/distributed_runtime:worker_cache_logger",
         "//tensorflow/core/distributed_runtime:worker_interface",
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index d529abef36c..986ae6adf78 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/generic/generic_stub.h"
 #include "grpcpp/grpcpp.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface {
 
   ~GrpcRemoteWorker() override {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
-    IssueRequest(request, response, getstatus_, std::move(done), nullptr,
+    IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                  fail_fast);
   }
 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index c6e08f9c1a7..d0e67cdcd57 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -84,7 +84,8 @@ class RPCState : public GrpcClientCQTag {
                 return false;
               }
             }(),
-            /*timeout_in_ms=*/0, max_retries, target) {
+            (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
+            target) {
   }
 
   template <typename Request>
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index cec09775469..dc9badfedef 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -30,7 +30,7 @@ namespace tensorflow {
 // testing.
 class TestWorkerInterface : public WorkerInterface {
  public:
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     done(errors::Unimplemented("GetStatusAsync"));
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index c4dc51ce47d..be14a58ca49 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
   StatusGroup::ConfigureLogHistory();
 }
 
-void Worker::GetStatusAsync(const GetStatusRequest* request,
+void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
   const DeviceMgr* dm = env_->device_mgr;
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index 273335ff36f..e280cf2447d 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -45,7 +45,7 @@ class Worker : public WorkerInterface {
   Worker(WorkerEnv* env);
   virtual ~Worker() {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override;
 
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index 9492d1cd31b..7b759eef95b 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -36,7 +36,8 @@ class TensorResponse;
 // Interface for talking with the TensorFlow Worker service.
 class WorkerInterface {
  public:
-  virtual void GetStatusAsync(const GetStatusRequest* request,
+  virtual void GetStatusAsync(CallOptions* opts,
+                              const GetStatusRequest* request,
                               GetStatusResponse* response, bool fail_fast,
                               StatusCallback done) = 0;
 
@@ -132,7 +133,7 @@ class WorkerInterface {
                    GetStatusResponse* response) {
     Status ret;
     Notification n;
-    GetStatusAsync(request, response, /*fail_fast=*/true,
+    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                    [&ret, &n](const Status& s) {
                      ret = s;
                      n.Notify();
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index b9a05b64b09..cd4c28e1d2f 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -283,7 +283,7 @@ class CollectiveRemoteAccess {
   // Checks the health of a collective peer. It probes the peer to see if it is
   // alive. Note that if a peer has restarted, it's considered a different one,
   // so CheckPeerHealth fails.
-  virtual void CheckPeerHealth(const string& peer_task,
+  virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                                const StatusCallback& done) = 0;
 
   virtual BufRendezvous* buf_rendezvous() = 0;
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index d08df3ed7fe..37a440bf46e 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -309,6 +309,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   _check_health_initial_timeout = 0
   # Times to retry before considering the peer is down.
   _check_health_retry_limit = 3
+  # Timeout in seconds the each check health.
+  _check_health_timeout = 10
 
   def __init__(self, container_strategy, cluster_resolver,
                communication_options):
@@ -780,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
           while True:
             attempts += 1
             try:
-              context.context().check_collective_ops_peer_health(peer)
+              context.context().check_collective_ops_peer_health(
+                  peer, timeout_in_ms=self._check_health_timeout * 1000)
               # If check_collective_ops_peer_health doesn't raise an Exception,
               # the peer is healthy.
               break
-            except (errors.UnavailableError,
-                    errors.FailedPreconditionError) as e:
+            except (errors.UnavailableError, errors.FailedPreconditionError,
+                    errors.DeadlineExceededError) as e:
               # TODO(b/151232436): Always raise UnavailableError when a peer
               # fails. Now there could be many kinds of errors:
               # - Unavailable: when the peer is not reachable, e.g. it's down.
diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD
index 16736501d1d..b19710e17b3 100644
--- a/tensorflow/python/distribute/integration_test/BUILD
+++ b/tensorflow/python/distribute/integration_test/BUILD
@@ -36,7 +36,6 @@ cuda_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
-        "no_oss",  # TODO(b/170838851): UnavailableError:  Connection reset by peer
     ],
     deps = [
         "//tensorflow:tensorflow_py",
diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
index 02dee6f6adb..283a76300a9 100644
--- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
+++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
@@ -37,6 +37,8 @@ from tensorflow.python.eager import test
 mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
 mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
+# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
+mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1
 
 
 def get_attempt(strategy, attempts):
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 026dbce321d..ac7f7e2cfa1 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -748,7 +748,7 @@ class Context(object):
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
-  def check_collective_ops_peer_health(self, task):
+  def check_collective_ops_peer_health(self, task, timeout_in_ms):
     """Check collective peer health.
 
     This probes each task to see if they're still alive. Note that restarted
@@ -758,6 +758,7 @@ class Context(object):
 
     Args:
       task: a task string, must be in the format of /job:xxx/replica:0/task:N.
+      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
 
     Raises:
       tf.errors.UnavailableError: when a peer is down.
@@ -766,7 +767,8 @@ class Context(object):
       tf.errors.InvalidArgumentError: when the task string is invalid.
     """
     self.ensure_initialized()
-    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
+                                                timeout_in_ms)
 
   @property
   def _handle(self):
diff --git a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
index 47e296a384b..9dbedc771e1 100644
--- a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
@@ -60,7 +60,8 @@ class CollectiveOpTest(test.TestCase):
               "/job:worker/replica:0/task:0",
               "/job:worker/replica:0/task:1",
           ]:
-            context.context().check_collective_ops_peer_health(task)
+            context.context().check_collective_ops_peer_health(
+                task, timeout_in_ms=1000)
         except errors.UnavailableError:
           continue
         break
@@ -73,18 +74,16 @@ class CollectiveOpTest(test.TestCase):
 
   def testCheckHealthPeerDown(self):
 
-    if multi_process_runner.is_oss():
-      self.skipTest("TODO(b/170838845): Failing in OSS")
-
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
       context.context().check_collective_ops_peer_health(
-          "/job:worker/replica:0/task:1",)
+          "/job:worker/replica:0/task:1", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
     mpr.start_single_process("worker", 0)
-    with self.assertRaises(errors.UnavailableError):
+    with self.assertRaises(
+        (errors.UnavailableError, errors.DeadlineExceededError)):
       mpr.join()
 
   def testCheckHealthPeerRestart(self):
@@ -112,7 +111,7 @@ class CollectiveOpTest(test.TestCase):
           time.sleep(1)
           try:
             context.context().check_collective_ops_peer_health(
-                "/job:worker/replica:0/task:0",)
+                "/job:worker/replica:0/task:0", timeout_in_ms=1000)
           except errors.UnavailableError:
             pass
           except errors.FailedPreconditionError:
@@ -129,7 +128,8 @@ class CollectiveOpTest(test.TestCase):
 
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
-      context.context().check_collective_ops_peer_health("localhost:12345",)
+      context.context().check_collective_ops_peer_health(
+          "localhost:12345", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 980695f28bb..1c46c228cf0 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -1054,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
   });
   m.def("TFE_CollectiveOpsCheckPeerHealth",
-        [](const py::handle& ctx, const char* task) {
+        [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
           tensorflow::Safe_TF_StatusPtr status =
               tensorflow::make_safe(TF_NewStatus());
           TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
-                                           task, status.get());
+                                           task, timeout_in_ms, status.get());
           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
         });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);