Set a timeout on the check health RPC
fail_fast is ignored by default in OSS, so set a timeout on the RPC to keep it from hanging indefinitely when a remote worker is down.

PiperOrigin-RevId: 338320405
Change-Id: I154e3ebbcd3a224a0caa3a39a0a7f9927e3ea220
commit c2b5bebd70 (parent 9fe37b34f5)
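For context on how the new argument is meant to be used: after this change, check_collective_ops_peer_health takes a timeout_in_ms parameter, and a probe that exceeds it surfaces as DeadlineExceededError in addition to the existing UnavailableError / FailedPreconditionError cases. The sketch below is illustrative and not part of the diff; it assumes collective ops have already been enabled on the calling worker, and the task string, timeout value, and helper name are made up.

from tensorflow.python.eager import context
from tensorflow.python.framework import errors


def peer_is_healthy(task="/job:worker/replica:0/task:1", timeout_ms=5000):
  """Returns True if `task` answers the check health RPC within `timeout_ms`."""
  try:
    # Raises if the peer is down, has restarted, or the RPC times out.
    context.context().check_collective_ops_peer_health(
        task, timeout_in_ms=timeout_ms)
    return True
  except (errors.UnavailableError, errors.FailedPreconditionError,
          errors.DeadlineExceededError):
    return False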
@@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
     TF_Status* status) {
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
   tensorflow::Notification done;
   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
         status->status = s;
         done.Notify();
       });
@@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 // Checks the health of collective ops peers. Explicit health check is needed in
 // multi worker collective ops to detect failures in the cluster. If a peer is
 // down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
     TF_Status* status);
 
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {
@@ -107,6 +107,7 @@ void CollectiveRemoteAccessLocal::PostToPeer(
 }
 
 void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  int64 timeout_in_ms,
                                                   const StatusCallback& done) {
   // Assume local devices are always healthy.
   done(errors::Internal(
@@ -55,7 +55,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
@@ -156,10 +156,11 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
 TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
   Status status;
   Notification done;
-  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
-    status = s;
-    done.Notify();
-  });
+  rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0,
+                        [&status, &done](const Status& s) {
+                          status = s;
+                          done.Notify();
+                        });
   done.WaitForNotification();
   EXPECT_TRUE(errors::IsInternal(status));
 }
@@ -531,6 +531,7 @@ cc_library(
     srcs = ["collective_rma_distributed.cc"],
     hdrs = ["collective_rma_distributed.h"],
     deps = [
+        ":call_options",
         ":cancellable_call",
         ":request_id",
         ":worker_cache",
@@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface {
              CollectiveParamResolverDistributed* cpres)
       : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     std::vector<DeviceAttributes> dev_attr;
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -179,7 +180,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
-    const string& peer_task, const StatusCallback& done) {
+    const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) {
   if (peer_task == task_name_) {
     // Fast path if the peer is the worker itself.
     done(Status::OK());
@@ -196,13 +197,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
                  "valid form is /job:xxx/replica:0/task:N"));
     return;
   }
+  auto opts = new CallOptions();
+  opts->SetTimeout(timeout_in_ms);
   auto req = new GetStatusRequest();
   auto resp = new GetStatusResponse();
-  // We're not using Cancellable call because GetStatusAsync doesn't support
-  // cancellation yet.
+  // Note that fail_fast is not always respected, so we set a timeout as well.
+  // We're not using CancellableCall since check health shouldn't need to be
+  // cancelled.
   wi->GetStatusAsync(
-      req, resp, /*fail_fast*/ true,
-      [this, req, resp, wi, peer_task, done](Status s) {
+      opts, req, resp, /*fail_fast*/ true,
+      [this, opts, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
           s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
@@ -227,6 +231,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
           // first collective.
           s = Status::OK();
         }
+        delete opts;
        delete req;
        delete resp;
        worker_cache_->ReleaseWorker(peer_task, wi);
@@ -45,7 +45,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   void StartAbort(const Status& s) override;
@@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface {
   // worker is supposed to have.
   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     if (is_failed_) {
@@ -459,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -472,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -488,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -505,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -520,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
       }
     }
   };
-  wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb);
+  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
+                     /*fail_fast=*/false, cb);
 }
 
 }  // namespace tensorflow
@@ -94,6 +94,7 @@ cc_library(
         "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime:call_options",
        "//tensorflow/core/distributed_runtime:tensor_coding",
        "//tensorflow/core/distributed_runtime:worker_cache_logger",
        "//tensorflow/core/distributed_runtime:worker_interface",
@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/generic/generic_stub.h"
 #include "grpcpp/grpcpp.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface {
 
   ~GrpcRemoteWorker() override {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
-    IssueRequest(request, response, getstatus_, std::move(done), nullptr,
+    IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                  fail_fast);
   }
 
@@ -84,7 +84,8 @@ class RPCState : public GrpcClientCQTag {
                      return false;
                    }
                  }(),
-                 /*timeout_in_ms=*/0, max_retries, target) {
+                 (call_opts != nullptr ? call_opts->GetTimeout() : 0),
+                 max_retries, target) {
   }
 
   template <typename Request>
@@ -30,7 +30,7 @@ namespace tensorflow {
 // testing.
 class TestWorkerInterface : public WorkerInterface {
  public:
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     done(errors::Unimplemented("GetStatusAsync"));
@@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
   StatusGroup::ConfigureLogHistory();
 }
 
-void Worker::GetStatusAsync(const GetStatusRequest* request,
+void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
   const DeviceMgr* dm = env_->device_mgr;
@@ -45,7 +45,7 @@ class Worker : public WorkerInterface {
   Worker(WorkerEnv* env);
   virtual ~Worker() {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override;
 
@@ -36,7 +36,8 @@ class TensorResponse;
 // Interface for talking with the TensorFlow Worker service.
 class WorkerInterface {
  public:
-  virtual void GetStatusAsync(const GetStatusRequest* request,
+  virtual void GetStatusAsync(CallOptions* opts,
+                              const GetStatusRequest* request,
                               GetStatusResponse* response, bool fail_fast,
                               StatusCallback done) = 0;
 
@@ -132,7 +133,7 @@ class WorkerInterface {
                    GetStatusResponse* response) {
     Status ret;
     Notification n;
-    GetStatusAsync(request, response, /*fail_fast=*/true,
+    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                    [&ret, &n](const Status& s) {
                      ret = s;
                      n.Notify();
@@ -283,7 +283,7 @@ class CollectiveRemoteAccess {
   // Checks the health of a collective peer. It probes the peer to see if it is
   // alive. Note that if a peer has restarted, it's considered a different one,
   // so CheckPeerHealth fails.
-  virtual void CheckPeerHealth(const string& peer_task,
+  virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                                const StatusCallback& done) = 0;
 
   virtual BufRendezvous* buf_rendezvous() = 0;
@@ -309,6 +309,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   _check_health_initial_timeout = 0
   # Times to retry before considering the peer is down.
   _check_health_retry_limit = 3
+  # Timeout in seconds the each check health.
+  _check_health_timeout = 10
 
   def __init__(self, container_strategy, cluster_resolver,
                communication_options):
@@ -780,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     while True:
       attempts += 1
       try:
-        context.context().check_collective_ops_peer_health(peer)
+        context.context().check_collective_ops_peer_health(
+            peer, timeout_in_ms=self._check_health_timeout * 1000)
         # If check_collective_ops_peer_health doesn't raise an Exception,
         # the peer is healthy.
         break
-      except (errors.UnavailableError,
-              errors.FailedPreconditionError) as e:
+      except (errors.UnavailableError, errors.FailedPreconditionError,
+              errors.DeadlineExceededError) as e:
        # TODO(b/151232436): Always raise UnavailableError when a peer
        # fails. Now there could be many kinds of errors:
        # - Unavailable: when the peer is not reachable, e.g. it's down.
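The strategy change above threads the new timeout through its existing health-check retry loop. A simplified standalone sketch of that pattern follows; the helper name, retry limit, and intervals are illustrative and not part of the commit, and a timed-out probe is treated the same as an unreachable peer.

import time

from tensorflow.python.eager import context
from tensorflow.python.framework import errors


def wait_for_peer(task, retry_limit=3, interval_s=3, timeout_s=10):
  """Returns True once `task` responds, False after `retry_limit` failed probes."""
  attempts = 0
  while True:
    attempts += 1
    try:
      context.context().check_collective_ops_peer_health(
          task, timeout_in_ms=timeout_s * 1000)
      return True  # The peer answered the check health RPC.
    except (errors.UnavailableError, errors.FailedPreconditionError,
            errors.DeadlineExceededError):
      if attempts >= retry_limit:
        return False  # Consider the peer down.
      time.sleep(interval_s)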
@@ -36,7 +36,6 @@ cuda_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
-        "no_oss",  # TODO(b/170838851): UnavailableError: Connection reset by peer
     ],
     deps = [
         "//tensorflow:tensorflow_py",
@@ -37,6 +37,8 @@ from tensorflow.python.eager import test
 mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
 mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
+# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
+mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1
 
 
 def get_attempt(strategy, attempts):
@@ -748,7 +748,7 @@ class Context(object):
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
-  def check_collective_ops_peer_health(self, task):
+  def check_collective_ops_peer_health(self, task, timeout_in_ms):
     """Check collective peer health.
 
     This probes each task to see if they're still alive. Note that restarted
@@ -758,6 +758,7 @@ class Context(object):
 
     Args:
       task: a task string, must be in the format of /job:xxx/replica:0/task:N.
+      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
 
     Raises:
       tf.errors.UnavailableError: when a peer is down.
@@ -766,7 +767,8 @@ class Context(object):
       tf.errors.InvalidArgumentError: when the task string is invalid.
     """
     self.ensure_initialized()
-    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
+                                                timeout_in_ms)
 
   @property
   def _handle(self):
@@ -60,7 +60,8 @@ class CollectiveOpTest(test.TestCase):
             "/job:worker/replica:0/task:0",
             "/job:worker/replica:0/task:1",
         ]:
-          context.context().check_collective_ops_peer_health(task)
+          context.context().check_collective_ops_peer_health(
+              task, timeout_in_ms=1000)
       except errors.UnavailableError:
         continue
       break
@@ -73,18 +74,16 @@ class CollectiveOpTest(test.TestCase):
 
   def testCheckHealthPeerDown(self):
 
-    if multi_process_runner.is_oss():
-      self.skipTest("TODO(b/170838845): Failing in OSS")
-
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
       context.context().check_collective_ops_peer_health(
-          "/job:worker/replica:0/task:1",)
+          "/job:worker/replica:0/task:1", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
     mpr.start_single_process("worker", 0)
-    with self.assertRaises(errors.UnavailableError):
+    with self.assertRaises(
+        (errors.UnavailableError, errors.DeadlineExceededError)):
       mpr.join()
 
   def testCheckHealthPeerRestart(self):
@@ -112,7 +111,7 @@ class CollectiveOpTest(test.TestCase):
         time.sleep(1)
         try:
           context.context().check_collective_ops_peer_health(
-              "/job:worker/replica:0/task:0",)
+              "/job:worker/replica:0/task:0", timeout_in_ms=1000)
         except errors.UnavailableError:
           pass
         except errors.FailedPreconditionError:
@@ -129,7 +128,8 @@ class CollectiveOpTest(test.TestCase):
 
     def worker_fn():
      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
-      context.context().check_collective_ops_peer_health("localhost:12345",)
+      context.context().check_collective_ops_peer_health(
+          "localhost:12345", timeout_in_ms=1000)
 
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
@@ -1054,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
         TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
       });
   m.def("TFE_CollectiveOpsCheckPeerHealth",
-        [](const py::handle& ctx, const char* task) {
+        [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
-                                           task, status.get());
+                                           task, timeout_in_ms, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);