From c2b5bebd70ca3cc69411ce0385ee4d42d887f795 Mon Sep 17 00:00:00 2001
From: Ran Chen <crccw@google.com>
Date: Wed, 21 Oct 2020 12:47:34 -0700
Subject: [PATCH] Set a timeout on the check health RPC

fail_fast is ignored by default in OSS, so we need to set a timeout on the
RPC to avoid it taking forever if a remote worker is down.

PiperOrigin-RevId: 338320405
Change-Id: I154e3ebbcd3a224a0caa3a39a0a7f9927e3ea220
---
 tensorflow/c/c_api_experimental.cc                 |  8 ++++----
 tensorflow/c/c_api_experimental.h                  |  6 +++---
 .../core/common_runtime/collective_rma_local.cc    |  1 +
 .../core/common_runtime/collective_rma_local.h     |  2 +-
 .../common_runtime/collective_rma_local_test.cc    |  9 +++++----
 tensorflow/core/distributed_runtime/BUILD          |  1 +
 ...collective_param_resolver_distributed_test.cc   |  2 +-
 .../collective_rma_distributed.cc                  | 15 ++++++++++-----
 .../collective_rma_distributed.h                   |  2 +-
 .../collective_rma_distributed_test.cc             | 12 ++++++------
 .../core/distributed_runtime/remote_device.cc      |  3 ++-
 tensorflow/core/distributed_runtime/rpc/BUILD      |  1 +
 .../rpc/grpc_remote_worker.cc                      |  5 +++--
 .../core/distributed_runtime/rpc/grpc_state.h      |  3 ++-
 tensorflow/core/distributed_runtime/test_utils.h   |  2 +-
 tensorflow/core/distributed_runtime/worker.cc      |  2 +-
 tensorflow/core/distributed_runtime/worker.h       |  2 +-
 .../core/distributed_runtime/worker_interface.h    |  5 +++--
 tensorflow/core/framework/collective.h             |  2 +-
 .../distribute/collective_all_reduce_strategy.py   |  9 ++++++---
 .../python/distribute/integration_test/BUILD       |  1 -
 .../integration_test/mwms_peer_failure_test.py     |  2 ++
 tensorflow/python/eager/context.py                 |  6 ++++--
 .../collective_ops_multi_worker_test.py            | 16 ++++++++--------
 tensorflow/python/tfe_wrapper.cc                   |  4 ++--
 25 files changed, 70 insertions(+), 51 deletions(-)

diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 81fb9d1a2b8..0d188aa5ee0 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status) {
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status) {
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
   tensorflow::Notification done;
   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
         status->status = s;
         done.Notify();
       });
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index c9c74f4e874..90e074d232f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 // Checks the health of collective ops peers. Explicit health check is needed in
 // multi worker collective ops to detect failures in the cluster. If a peer is
 // down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status);
 
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index 64ec7f03cdb..44175a042a7 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -107,6 +107,7 @@ void CollectiveRemoteAccessLocal::PostToPeer(
 }
 
 void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  int64 timeout_in_ms,
                                                   const StatusCallback& done) {
   // Assume local devices are always healthy.
   done(errors::Internal(
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index dbb10f41a2d..fb4ddf178e5 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -55,7 +55,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
                   CancellationManager* cancellation_manager,
                   const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index be15e0eeff1..30f6e372606 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -156,10 +156,11 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
 TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
   Status status;
   Notification done;
-  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
-    status = s;
-    done.Notify();
-  });
+  rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0,
+                        [&status, &done](const Status& s) {
+                          status = s;
+                          done.Notify();
+                        });
   done.WaitForNotification();
   EXPECT_TRUE(errors::IsInternal(status));
 }
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index d3030f4ca0b..a8681af3f7f 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -531,6 +531,7 @@ cc_library(
     srcs = ["collective_rma_distributed.cc"],
     hdrs = ["collective_rma_distributed.h"],
     deps = [
+        ":call_options",
         ":cancellable_call",
         ":request_id",
         ":worker_cache",
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index f08f7a3275d..1c62b17fe54 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface {
              CollectiveParamResolverDistributed* cpres)
       : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     std::vector<DeviceAttributes> dev_attr;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index f0e30b2cbe5..0f35af48e61 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -179,7 +180,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
-    const string& peer_task, const StatusCallback& done) {
+    const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) {
   if (peer_task == task_name_) {
     // Fast path if the peer is the worker itself.
     done(Status::OK());
@@ -196,13 +197,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
         "valid form is /job:xxx/replica:0/task:N"));
     return;
   }
+  auto opts = new CallOptions();
+  opts->SetTimeout(timeout_in_ms);
   auto req = new GetStatusRequest();
   auto resp = new GetStatusResponse();
-  // We're not using Cancellable call because GetStatusAsync doesn't support
-  // cancellation yet.
+  // Note that fail_fast is not always respected, so we set a timeout as well.
+  // We're not using CancellableCall since check health shouldn't need to be
+  // cancelled.
   wi->GetStatusAsync(
-      req, resp, /*fail_fast*/ true,
-      [this, req, resp, wi, peer_task, done](Status s) {
+      opts, req, resp, /*fail_fast*/ true,
+      [this, opts, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
           s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
@@ -227,6 +231,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
           // first collective.
           s = Status::OK();
         }
+        delete opts;
         delete req;
         delete resp;
         worker_cache_->ReleaseWorker(peer_task, wi);
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index 58065284f72..f17c6360cf7 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -45,7 +45,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   void StartAbort(const Status& s) override;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 8ccdde064de..74282beff1f 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface {
   // worker is supposed to have.
   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     if (is_failed_) {
@@ -459,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -472,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -488,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -505,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -520,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index 05a9072894e..bb9b074858a 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
       }
     }
   };
-  wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb);
+  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
+                     /*fail_fast=*/false, cb);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 14a358e8ac2..97dc8257750 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -94,6 +94,7 @@ cc_library(
         "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/distributed_runtime:tensor_coding",
         "//tensorflow/core/distributed_runtime:worker_cache_logger",
         "//tensorflow/core/distributed_runtime:worker_interface",
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index d529abef36c..986ae6adf78 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/generic/generic_stub.h"
 #include "grpcpp/grpcpp.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface {
 
   ~GrpcRemoteWorker() override {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
-    IssueRequest(request, response, getstatus_, std::move(done), nullptr,
+    IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                  fail_fast);
   }
 
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index c6e08f9c1a7..d0e67cdcd57 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -84,7 +84,8 @@ class RPCState : public GrpcClientCQTag {
               return false;
             }
           }(),
-          /*timeout_in_ms=*/0, max_retries, target) {
+          (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
+          target) {
   }
 
   template <typename Request>
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index cec09775469..dc9badfedef 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -30,7 +30,7 @@ namespace tensorflow {
 // testing.
 class TestWorkerInterface : public WorkerInterface {
  public:
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     done(errors::Unimplemented("GetStatusAsync"));
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index c4dc51ce47d..be14a58ca49 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
   StatusGroup::ConfigureLogHistory();
 }
 
-void Worker::GetStatusAsync(const GetStatusRequest* request,
+void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
   const DeviceMgr* dm = env_->device_mgr;
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index 273335ff36f..e280cf2447d 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -45,7 +45,7 @@ class Worker : public WorkerInterface {
   Worker(WorkerEnv* env);
   virtual ~Worker() {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override;
 
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index 9492d1cd31b..7b759eef95b 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -36,7 +36,8 @@ class TensorResponse;
 // Interface for talking with the TensorFlow Worker service.
 class WorkerInterface {
  public:
-  virtual void GetStatusAsync(const GetStatusRequest* request,
+  virtual void GetStatusAsync(CallOptions* opts,
+                              const GetStatusRequest* request,
                               GetStatusResponse* response, bool fail_fast,
                               StatusCallback done) = 0;
 
@@ -132,7 +133,7 @@ class WorkerInterface {
                         GetStatusResponse* response) {
     Status ret;
     Notification n;
-    GetStatusAsync(request, response, /*fail_fast=*/true,
+    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                    [&ret, &n](const Status& s) {
                      ret = s;
                      n.Notify();
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index b9a05b64b09..cd4c28e1d2f 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -283,7 +283,7 @@ class CollectiveRemoteAccess {
   // Checks the health of a collective peer. It probes the peer to see if it is
   // alive. Note that if a peer has restarted, it's considered a different one,
   // so CheckPeerHealth fails.
-  virtual void CheckPeerHealth(const string& peer_task,
+  virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                                const StatusCallback& done) = 0;
 
   virtual BufRendezvous* buf_rendezvous() = 0;
 
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index d08df3ed7fe..37a440bf46e 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -309,6 +309,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   _check_health_initial_timeout = 0
   # Times to retry before considering the peer is down.
   _check_health_retry_limit = 3
+  # Timeout in seconds for each check health RPC.
+  _check_health_timeout = 10
 
   def __init__(self, container_strategy, cluster_resolver,
                communication_options):
@@ -780,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     while True:
       attempts += 1
      try:
-        context.context().check_collective_ops_peer_health(peer)
+        context.context().check_collective_ops_peer_health(
+            peer, timeout_in_ms=self._check_health_timeout * 1000)
         # If check_collective_ops_peer_health doesn't raise an Exception,
         # the peer is healthy.
         break
-      except (errors.UnavailableError,
-              errors.FailedPreconditionError) as e:
+      except (errors.UnavailableError, errors.FailedPreconditionError,
+              errors.DeadlineExceededError) as e:
         # TODO(b/151232436): Always raise UnavailableError when a peer
         # fails. Now there could be many kinds of errors:
         # - Unavailable: when the peer is not reachable, e.g. it's down.
diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD
index 16736501d1d..b19710e17b3 100644
--- a/tensorflow/python/distribute/integration_test/BUILD
+++ b/tensorflow/python/distribute/integration_test/BUILD
@@ -36,7 +36,6 @@ cuda_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
-        "no_oss",  # TODO(b/170838851): UnavailableError: Connection reset by peer
     ],
     deps = [
         "//tensorflow:tensorflow_py",
diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
index 02dee6f6adb..283a76300a9 100644
--- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
+++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
@@ -37,6 +37,8 @@ from tensorflow.python.eager import test
 mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
 mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
+# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
+mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1
 
 
 def get_attempt(strategy, attempts):
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 026dbce321d..ac7f7e2cfa1 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -748,7 +748,7 @@ class Context(object):
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
-  def check_collective_ops_peer_health(self, task):
+  def check_collective_ops_peer_health(self, task, timeout_in_ms):
     """Check collective peer health.
 
     This probes each task to see if they're still alive. Note that restarted
@@ -758,6 +758,7 @@ class Context(object):
 
     Args:
       task: a task string, must be in the format of /job:xxx/replica:0/task:N.
+      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
 
     Raises:
       tf.errors.UnavailableError: when a peer is down.
@@ -766,7 +767,8 @@ class Context(object):
       tf.errors.InvalidArgumentError: when the task string is invalid.
     """
     self.ensure_initialized()
-    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
+                                                timeout_in_ms)
 
   @property
   def _handle(self):
diff --git a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
index 47e296a384b..9dbedc771e1 100644
--- a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
@@ -60,7 +60,8 @@ class CollectiveOpTest(test.TestCase):
               "/job:worker/replica:0/task:0",
               "/job:worker/replica:0/task:1",
           ]:
-            context.context().check_collective_ops_peer_health(task)
+            context.context().check_collective_ops_peer_health(
+                task, timeout_in_ms=1000)
         except errors.UnavailableError:
           continue
         break
@@ -73,18 +74,16 @@ class CollectiveOpTest(test.TestCase):
 
   def testCheckHealthPeerDown(self):
 
-    if multi_process_runner.is_oss():
-      self.skipTest("TODO(b/170838845): Failing in OSS")
-
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
       context.context().check_collective_ops_peer_health(
-          "/job:worker/replica:0/task:1",)
+          "/job:worker/replica:0/task:1", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
     mpr.start_single_process("worker", 0)
-    with self.assertRaises(errors.UnavailableError):
+    with self.assertRaises(
+        (errors.UnavailableError, errors.DeadlineExceededError)):
       mpr.join()
 
   def testCheckHealthPeerRestart(self):
@@ -112,7 +111,7 @@ class CollectiveOpTest(test.TestCase):
         time.sleep(1)
         try:
           context.context().check_collective_ops_peer_health(
-              "/job:worker/replica:0/task:0",)
+              "/job:worker/replica:0/task:0", timeout_in_ms=1000)
         except errors.UnavailableError:
           pass
         except errors.FailedPreconditionError:
@@ -129,7 +128,8 @@ class CollectiveOpTest(test.TestCase):
 
     def worker_fn():
      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
-      context.context().check_collective_ops_peer_health("localhost:12345",)
+      context.context().check_collective_ops_peer_health(
+          "localhost:12345", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 980695f28bb..1c46c228cf0 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -1054,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
          TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
        });
   m.def("TFE_CollectiveOpsCheckPeerHealth",
-        [](const py::handle& ctx, const char* task) {
+        [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
-                                           task, status.get());
+                                           task, timeout_in_ms, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
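
Usage note: the sketch below shows how the new timeout surfaces at the Python
layer, mirroring the retry loop this patch adds to
collective_all_reduce_strategy.py. It assumes a TensorFlow build that contains
this change and a context on which collective ops are already enabled;
wait_for_peer, its defaults, and the retry policy are illustrative only, and
the only API taken from the patch is
check_collective_ops_peer_health(task, timeout_in_ms).

# Sketch (not part of the patch): poll a peer with the new timeout, the way the
# check-health loop in collective_all_reduce_strategy.py does. Assumes
# collective ops have already been configured on this task; wait_for_peer and
# its defaults are illustrative names, not TensorFlow API.
import time

from tensorflow.python.eager import context
from tensorflow.python.framework import errors


def wait_for_peer(task, timeout_in_ms=10 * 1000, retries=3, interval_s=30):
  """Returns True once `task` answers a GetStatus RPC within the timeout."""
  for _ in range(retries):
    try:
      # With timeout_in_ms > 0 the RPC fails with DeadlineExceededError instead
      # of hanging when fail_fast is ignored (the OSS gRPC default).
      context.context().check_collective_ops_peer_health(task, timeout_in_ms)
      return True
    except (errors.UnavailableError, errors.FailedPreconditionError,
            errors.DeadlineExceededError):
      time.sleep(interval_s)
  return False


# Example: wait_for_peer("/job:worker/replica:0/task:1", timeout_in_ms=1000)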