Set a timeout to check health RPC
fail_fast is ignored by default in OSS builds, so we set a timeout on the health-check RPC to keep it from hanging indefinitely when a remote worker is down.

PiperOrigin-RevId: 338320405
Change-Id: I154e3ebbcd3a224a0caa3a39a0a7f9927e3ea220
Parent: 9fe37b34f5
Commit: c2b5bebd70
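
For readers of this change, a minimal usage sketch in Python. The helper name peer_is_healthy is hypothetical; the API call, the timeout_in_ms keyword, and the error classes are taken from the diff below.

    # Hypothetical helper, not part of this commit: probe one peer with a
    # deadline and treat a missed deadline the same as an unreachable peer.
    from tensorflow.python.eager import context
    from tensorflow.python.framework import errors

    def peer_is_healthy(task, timeout_in_ms=10 * 1000):
      """Returns True if `task` answers the health-check RPC within the deadline."""
      try:
        context.context().check_collective_ops_peer_health(
            task, timeout_in_ms=timeout_in_ms)
        return True
      except (errors.UnavailableError, errors.FailedPreconditionError,
              errors.DeadlineExceededError):
        return False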
@@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status) {
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status) {
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
   tensorflow::Notification done;
   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
         status->status = s;
         done.Notify();
       });
@@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 // Checks the health of collective ops peers. Explicit health check is needed in
 // multi worker collective ops to detect failures in the cluster. If a peer is
 // down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status);
 
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {
@@ -107,6 +107,7 @@ void CollectiveRemoteAccessLocal::PostToPeer(
 }
 
 void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  int64 timeout_in_ms,
                                                   const StatusCallback& done) {
   // Assume local devices are always healthy.
   done(errors::Internal(
@@ -55,7 +55,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }
@@ -156,10 +156,11 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
 TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
   Status status;
   Notification done;
-  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
-    status = s;
-    done.Notify();
-  });
+  rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0,
+                        [&status, &done](const Status& s) {
+                          status = s;
+                          done.Notify();
+                        });
   done.WaitForNotification();
   EXPECT_TRUE(errors::IsInternal(status));
 }
@@ -531,6 +531,7 @@ cc_library(
     srcs = ["collective_rma_distributed.cc"],
     hdrs = ["collective_rma_distributed.h"],
     deps = [
+        ":call_options",
         ":cancellable_call",
         ":request_id",
         ":worker_cache",
@@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface {
              CollectiveParamResolverDistributed* cpres)
       : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     std::vector<DeviceAttributes> dev_attr;
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -179,7 +180,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
-    const string& peer_task, const StatusCallback& done) {
+    const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) {
   if (peer_task == task_name_) {
     // Fast path if the peer is the worker itself.
     done(Status::OK());
@@ -196,13 +197,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
                                  "valid form is /job:xxx/replica:0/task:N"));
     return;
   }
+  auto opts = new CallOptions();
+  opts->SetTimeout(timeout_in_ms);
   auto req = new GetStatusRequest();
   auto resp = new GetStatusResponse();
-  // We're not using Cancellable call because GetStatusAsync doesn't support
-  // cancellation yet.
+  // Note that fail_fast is not always respected, so we set a timeout as well.
+  // We're not using CancellableCall since check health shouldn't need to be
+  // cancelled.
   wi->GetStatusAsync(
-      req, resp, /*fail_fast*/ true,
-      [this, req, resp, wi, peer_task, done](Status s) {
+      opts, req, resp, /*fail_fast*/ true,
+      [this, opts, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
           s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
@@ -227,6 +231,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
           // first collective.
           s = Status::OK();
         }
+        delete opts;
         delete req;
         delete resp;
         worker_cache_->ReleaseWorker(peer_task, wi);
@@ -45,7 +45,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   void StartAbort(const Status& s) override;
@@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface {
   // worker is supposed to have.
   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     if (is_failed_) {
@@ -459,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -472,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
         check_health_status = s;
         check_health_done.Notify();
@@ -488,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -505,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -520,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
       }
     }
   };
-  wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb);
+  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
+                     /*fail_fast=*/false, cb);
 }
 
 }  // namespace tensorflow
@@ -94,6 +94,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime:call_options",
         "//tensorflow/core/distributed_runtime:tensor_coding",
         "//tensorflow/core/distributed_runtime:worker_cache_logger",
         "//tensorflow/core/distributed_runtime:worker_interface",
@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/generic/generic_stub.h"
 #include "grpcpp/grpcpp.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface {
 
   ~GrpcRemoteWorker() override {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
-    IssueRequest(request, response, getstatus_, std::move(done), nullptr,
+    IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                  fail_fast);
   }
 
@@ -84,7 +84,8 @@ class RPCState : public GrpcClientCQTag {
                 return false;
               }
             }(),
-            /*timeout_in_ms=*/0, max_retries, target) {
+            (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
+            target) {
   }
 
   template <typename Request>
@@ -30,7 +30,7 @@ namespace tensorflow {
 // testing.
 class TestWorkerInterface : public WorkerInterface {
  public:
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     done(errors::Unimplemented("GetStatusAsync"));
@@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
   StatusGroup::ConfigureLogHistory();
 }
 
-void Worker::GetStatusAsync(const GetStatusRequest* request,
+void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
   const DeviceMgr* dm = env_->device_mgr;
@@ -45,7 +45,7 @@ class Worker : public WorkerInterface {
   Worker(WorkerEnv* env);
   virtual ~Worker() {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override;
 
@@ -36,7 +36,8 @@ class TensorResponse;
 // Interface for talking with the TensorFlow Worker service.
 class WorkerInterface {
  public:
-  virtual void GetStatusAsync(const GetStatusRequest* request,
+  virtual void GetStatusAsync(CallOptions* opts,
+                              const GetStatusRequest* request,
                               GetStatusResponse* response, bool fail_fast,
                               StatusCallback done) = 0;
 
@@ -132,7 +133,7 @@ class WorkerInterface {
                    GetStatusResponse* response) {
     Status ret;
     Notification n;
-    GetStatusAsync(request, response, /*fail_fast=*/true,
+    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                    [&ret, &n](const Status& s) {
                      ret = s;
                      n.Notify();
@@ -283,7 +283,7 @@ class CollectiveRemoteAccess {
   // Checks the health of a collective peer. It probes the peer to see if it is
   // alive. Note that if a peer has restarted, it's considered a different one,
   // so CheckPeerHealth fails.
-  virtual void CheckPeerHealth(const string& peer_task,
+  virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                                const StatusCallback& done) = 0;
 
   virtual BufRendezvous* buf_rendezvous() = 0;
@@ -309,6 +309,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   _check_health_initial_timeout = 0
   # Times to retry before considering the peer is down.
   _check_health_retry_limit = 3
+  # Timeout in seconds the each check health.
+  _check_health_timeout = 10
 
   def __init__(self, container_strategy, cluster_resolver,
                communication_options):
@@ -780,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     while True:
       attempts += 1
       try:
-        context.context().check_collective_ops_peer_health(peer)
+        context.context().check_collective_ops_peer_health(
+            peer, timeout_in_ms=self._check_health_timeout * 1000)
         # If check_collective_ops_peer_health doesn't raise an Exception,
         # the peer is healthy.
         break
-      except (errors.UnavailableError,
-              errors.FailedPreconditionError) as e:
+      except (errors.UnavailableError, errors.FailedPreconditionError,
+              errors.DeadlineExceededError) as e:
         # TODO(b/151232436): Always raise UnavailableError when a peer
         # fails. Now there could be many kinds of errors:
         # - Unavailable: when the peer is not reachable, e.g. it's down.
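For context, a rough sketch of how the class attributes above interact with this loop. Assumptions: the values _check_health_timeout = 10 and _check_health_retry_limit = 3 shown earlier in this diff; the helper probe_with_retries only approximates the strategy's internal _check_health logic and is not a copy of it.

    from tensorflow.python.eager import context
    from tensorflow.python.framework import errors

    def probe_with_retries(peer, timeout_sec=10, retry_limit=3):
      # Retry the health probe a few times before declaring the peer down;
      # a DeadlineExceededError (RPC timeout) counts as a failed attempt.
      attempts = 0
      while True:
        attempts += 1
        try:
          context.context().check_collective_ops_peer_health(
              peer, timeout_in_ms=timeout_sec * 1000)
          return True  # Peer responded within the deadline.
        except (errors.UnavailableError, errors.FailedPreconditionError,
                errors.DeadlineExceededError):
          if attempts >= retry_limit:
            return False  # Treat the peer as down after exhausting retries.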
@@ -36,7 +36,6 @@ cuda_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
-        "no_oss",  # TODO(b/170838851): UnavailableError: Connection reset by peer
     ],
     deps = [
         "//tensorflow:tensorflow_py",
@@ -37,6 +37,8 @@ from tensorflow.python.eager import test
 mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
 mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
+# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
+mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1
 
 
 def get_attempt(strategy, attempts):
@@ -748,7 +748,7 @@ class Context(object):
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
-  def check_collective_ops_peer_health(self, task):
+  def check_collective_ops_peer_health(self, task, timeout_in_ms):
     """Check collective peer health.
 
     This probes each task to see if they're still alive. Note that restarted
@@ -758,6 +758,7 @@ class Context(object):
 
     Args:
       task: a task string, must be in the format of /job:xxx/replica:0/task:N.
+      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
 
     Raises:
       tf.errors.UnavailableError: when a peer is down.
@@ -766,7 +767,8 @@ class Context(object):
       tf.errors.InvalidArgumentError: when the task string is invalid.
     """
     self.ensure_initialized()
-    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
+                                                timeout_in_ms)
 
   @property
   def _handle(self):
@@ -60,7 +60,8 @@ class CollectiveOpTest(test.TestCase):
             "/job:worker/replica:0/task:0",
             "/job:worker/replica:0/task:1",
         ]:
-          context.context().check_collective_ops_peer_health(task)
+          context.context().check_collective_ops_peer_health(
+              task, timeout_in_ms=1000)
       except errors.UnavailableError:
         continue
       break
@@ -73,18 +74,16 @@ class CollectiveOpTest(test.TestCase):
 
   def testCheckHealthPeerDown(self):
 
-    if multi_process_runner.is_oss():
-      self.skipTest("TODO(b/170838845): Failing in OSS")
-
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
       context.context().check_collective_ops_peer_health(
-          "/job:worker/replica:0/task:1",)
+          "/job:worker/replica:0/task:1", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
     mpr.start_single_process("worker", 0)
-    with self.assertRaises(errors.UnavailableError):
+    with self.assertRaises(
+        (errors.UnavailableError, errors.DeadlineExceededError)):
       mpr.join()
 
   def testCheckHealthPeerRestart(self):
@@ -112,7 +111,7 @@ class CollectiveOpTest(test.TestCase):
         time.sleep(1)
         try:
           context.context().check_collective_ops_peer_health(
-              "/job:worker/replica:0/task:0",)
+              "/job:worker/replica:0/task:0", timeout_in_ms=1000)
         except errors.UnavailableError:
           pass
         except errors.FailedPreconditionError:
@@ -129,7 +128,8 @@ class CollectiveOpTest(test.TestCase):
 
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
-      context.context().check_collective_ops_peer_health("localhost:12345",)
+      context.context().check_collective_ops_peer_health(
+          "localhost:12345", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
@@ -1054,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
           TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
         });
   m.def("TFE_CollectiveOpsCheckPeerHealth",
-        [](const py::handle& ctx, const char* task) {
+        [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
           tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
-                                           task, status.get());
+                                           task, timeout_in_ms, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);