Set a timeout on the check health RPC

fail_fast is ignored by default in OSS builds, so we need to set a timeout on the RPC to avoid it hanging forever when a remote worker is down.
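
For reference, this is how the new parameter is meant to be used from Python once collective ops are enabled on the context; a minimal sketch based on the tests in this change (the task name and the 1-second timeout are illustrative):

from tensorflow.python.eager import context
from tensorflow.python.framework import errors

# Probe a peer with an explicit deadline. With fail_fast ignored (as in OSS
# builds), the RPC now fails with DEADLINE_EXCEEDED instead of hanging.
try:
  context.context().check_collective_ops_peer_health(
      "/job:worker/replica:0/task:1", timeout_in_ms=1000)
except (errors.UnavailableError, errors.DeadlineExceededError):
  pass  # The peer is down or unreachable.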

PiperOrigin-RevId: 338320405
Change-Id: I154e3ebbcd3a224a0caa3a39a0a7f9927e3ea220
Author: Ran Chen
Committed by: TensorFlower Gardener
Date: 2020-10-21 12:47:34 -07:00
Commit: c2b5bebd70 (parent 9fe37b34f5)
25 changed files with 70 additions and 51 deletions

@@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status) {
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status) {
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
   tensorflow::Notification done;
   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
         status->status = s;
         done.Notify();
       });

@@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 // Checks the health of collective ops peers. Explicit health check is needed in
 // multi worker collective ops to detect failures in the cluster.  If a peer is
 // down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status);
 
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {

@@ -107,6 +107,7 @@ void CollectiveRemoteAccessLocal::PostToPeer(
 }
 
 void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  int64 timeout_in_ms,
                                                   const StatusCallback& done) {
   // Assume local devices are always healthy.
   done(errors::Internal(

@@ -55,7 +55,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }

@@ -156,10 +156,11 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
 TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
   Status status;
   Notification done;
-  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
-    status = s;
-    done.Notify();
-  });
+  rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0,
+                        [&status, &done](const Status& s) {
+                          status = s;
+                          done.Notify();
+                        });
   done.WaitForNotification();
   EXPECT_TRUE(errors::IsInternal(status));
 }

@@ -531,6 +531,7 @@ cc_library(
     srcs = ["collective_rma_distributed.cc"],
     hdrs = ["collective_rma_distributed.h"],
     deps = [
+        ":call_options",
         ":cancellable_call",
         ":request_id",
         ":worker_cache",

@@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface {
              CollectiveParamResolverDistributed* cpres)
       : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     std::vector<DeviceAttributes> dev_attr;

@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -179,7 +180,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
-    const string& peer_task, const StatusCallback& done) {
+    const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) {
   if (peer_task == task_name_) {
     // Fast path if the peer is the worker itself.
     done(Status::OK());
@@ -196,13 +197,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
                         "valid form is /job:xxx/replica:0/task:N"));
     return;
   }
+  auto opts = new CallOptions();
+  opts->SetTimeout(timeout_in_ms);
   auto req = new GetStatusRequest();
   auto resp = new GetStatusResponse();
-  // We're not using Cancellable call because GetStatusAsync doesn't support
-  // cancellation yet.
+  // Note that fail_fast is not always respected, so we set a timeout as well.
+  // We're not using CancellableCall since check health shouldn't need to be
+  // cancelled.
   wi->GetStatusAsync(
-      req, resp, /*fail_fast*/ true,
-      [this, req, resp, wi, peer_task, done](Status s) {
+      opts, req, resp, /*fail_fast*/ true,
+      [this, opts, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
          s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
@@ -227,6 +231,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
           // first collective.
           s = Status::OK();
         }
+        delete opts;
         delete req;
         delete resp;
         worker_cache_->ReleaseWorker(peer_task, wi);

@@ -45,7 +45,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   void StartAbort(const Status& s) override;

@@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface {
   // worker is supposed to have.
   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     if (is_failed_) {
@@ -459,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -472,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -488,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -505,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -520,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();

@@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
       }
     }
   };
-  wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb);
+  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
+                     /*fail_fast=*/false, cb);
 }
 
 }  // namespace tensorflow

@@ -94,6 +94,7 @@ cc_library(
         "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime:call_options",
        "//tensorflow/core/distributed_runtime:tensor_coding",
        "//tensorflow/core/distributed_runtime:worker_cache_logger",
        "//tensorflow/core/distributed_runtime:worker_interface",

@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/generic/generic_stub.h"
 #include "grpcpp/grpcpp.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface {
   ~GrpcRemoteWorker() override {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
-    IssueRequest(request, response, getstatus_, std::move(done), nullptr,
+    IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                  fail_fast);
   }

@@ -84,7 +84,8 @@ class RPCState : public GrpcClientCQTag {
                 return false;
              }
            }(),
-            /*timeout_in_ms=*/0, max_retries, target) {
+            (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
+            target) {
   }
 
   template <typename Request>

@@ -30,7 +30,7 @@ namespace tensorflow {
 // testing.
 class TestWorkerInterface : public WorkerInterface {
  public:
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     done(errors::Unimplemented("GetStatusAsync"));

@@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
   StatusGroup::ConfigureLogHistory();
 }
 
-void Worker::GetStatusAsync(const GetStatusRequest* request,
+void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
   const DeviceMgr* dm = env_->device_mgr;

@@ -45,7 +45,7 @@ class Worker : public WorkerInterface {
   Worker(WorkerEnv* env);
   virtual ~Worker() {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override;

@@ -36,7 +36,8 @@ class TensorResponse;
 // Interface for talking with the TensorFlow Worker service.
 class WorkerInterface {
  public:
-  virtual void GetStatusAsync(const GetStatusRequest* request,
+  virtual void GetStatusAsync(CallOptions* opts,
+                              const GetStatusRequest* request,
                               GetStatusResponse* response, bool fail_fast,
                               StatusCallback done) = 0;
@@ -132,7 +133,7 @@ class WorkerInterface {
                    GetStatusResponse* response) {
     Status ret;
     Notification n;
-    GetStatusAsync(request, response, /*fail_fast=*/true,
+    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                    [&ret, &n](const Status& s) {
                      ret = s;
                      n.Notify();

@@ -283,7 +283,7 @@ class CollectiveRemoteAccess {
   // Checks the health of a collective peer. It probes the peer to see if it is
   // alive. Note that if a peer has restarted, it's considered a different one,
   // so CheckPeerHealth fails.
-  virtual void CheckPeerHealth(const string& peer_task,
+  virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                                const StatusCallback& done) = 0;
 
   virtual BufRendezvous* buf_rendezvous() = 0;

@@ -309,6 +309,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   _check_health_initial_timeout = 0
   # Times to retry before considering the peer is down.
   _check_health_retry_limit = 3
+  # Timeout in seconds for each check health RPC.
+  _check_health_timeout = 10
 
   def __init__(self, container_strategy, cluster_resolver,
                communication_options):
@@ -780,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       while True:
         attempts += 1
         try:
-          context.context().check_collective_ops_peer_health(peer)
+          context.context().check_collective_ops_peer_health(
+              peer, timeout_in_ms=self._check_health_timeout * 1000)
           # If check_collective_ops_peer_health doesn't raise an Exception,
           # the peer is healthy.
           break
-        except (errors.UnavailableError,
-                errors.FailedPreconditionError) as e:
+        except (errors.UnavailableError, errors.FailedPreconditionError,
+                errors.DeadlineExceededError) as e:
           # TODO(b/151232436): Always raise UnavailableError when a peer
          # fails. Now there could be many kinds of errors:
          # - Unavailable: when the peer is not reachable, e.g. it's down.

@@ -36,7 +36,6 @@ cuda_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
-        "no_oss",  # TODO(b/170838851): UnavailableError: Connection reset by peer
     ],
     deps = [
         "//tensorflow:tensorflow_py",

@@ -37,6 +37,8 @@ from tensorflow.python.eager import test
 mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
 mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
+# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
+mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1
 
 
 def get_attempt(strategy, attempts):

@@ -748,7 +748,7 @@ class Context(object):
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
-  def check_collective_ops_peer_health(self, task):
+  def check_collective_ops_peer_health(self, task, timeout_in_ms):
     """Check collective peer health.
 
     This probes each task to see if they're still alive. Note that restarted
@@ -758,6 +758,7 @@ class Context(object):
     Args:
       task: a task string, must be in the format of /job:xxx/replica:0/task:N.
+      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
 
     Raises:
       tf.errors.UnavailableError: when a peer is down.
@@ -766,7 +767,8 @@ class Context(object):
       tf.errors.InvalidArgumentError: when the task string is invalid.
     """
     self.ensure_initialized()
-    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
+                                                timeout_in_ms)
 
   @property
   def _handle(self):

@@ -60,7 +60,8 @@ class CollectiveOpTest(test.TestCase):
               "/job:worker/replica:0/task:0",
               "/job:worker/replica:0/task:1",
           ]:
-            context.context().check_collective_ops_peer_health(task)
+            context.context().check_collective_ops_peer_health(
+                task, timeout_in_ms=1000)
         except errors.UnavailableError:
           continue
         break
@@ -73,18 +74,16 @@ class CollectiveOpTest(test.TestCase):
   def testCheckHealthPeerDown(self):
-    if multi_process_runner.is_oss():
-      self.skipTest("TODO(b/170838845): Failing in OSS")
 
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
       context.context().check_collective_ops_peer_health(
-          "/job:worker/replica:0/task:1",)
+          "/job:worker/replica:0/task:1", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
     mpr.start_single_process("worker", 0)
-    with self.assertRaises(errors.UnavailableError):
+    with self.assertRaises(
+        (errors.UnavailableError, errors.DeadlineExceededError)):
       mpr.join()
 
   def testCheckHealthPeerRestart(self):
@@ -112,7 +111,7 @@ class CollectiveOpTest(test.TestCase):
         time.sleep(1)
         try:
           context.context().check_collective_ops_peer_health(
-              "/job:worker/replica:0/task:0",)
+              "/job:worker/replica:0/task:0", timeout_in_ms=1000)
         except errors.UnavailableError:
           pass
         except errors.FailedPreconditionError:
@@ -129,7 +128,8 @@ class CollectiveOpTest(test.TestCase):
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
-      context.context().check_collective_ops_peer_health("localhost:12345",)
+      context.context().check_collective_ops_peer_health(
+          "localhost:12345", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)

@@ -1054,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
   });
   m.def("TFE_CollectiveOpsCheckPeerHealth",
-        [](const py::handle& ctx, const char* task) {
+        [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
-                                           task, status.get());
+                                           task, timeout_in_ms, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);