Set a timeout on the check health RPC

fail_fast is ignored by default in OSS, so we need to set a timeout on the RPC to keep it from hanging forever when a remote worker is down.

PiperOrigin-RevId: 338320405
Change-Id: I154e3ebbcd3a224a0caa3a39a0a7f9927e3ea220
Author: Ran Chen, 2020-10-21 12:47:34 -07:00 (committed by TensorFlower Gardener)
Parent: 9fe37b34f5
Commit: c2b5bebd70
25 changed files with 70 additions and 51 deletions
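
For reference, the short sketch below is not part of the change; it only illustrates how a caller might use the new timeout through the Python API extended in this commit. The helper name probe_peer and the retry/timeout values are invented for illustration, and it assumes collective ops are already enabled on a multi-worker cluster.

from tensorflow.python.eager import context
from tensorflow.python.framework import errors


def probe_peer(task, timeout_in_ms=10 * 1000, retries=3):
  """Returns True if `task` answers a health check within the timeout."""
  for attempt in range(1, retries + 1):
    try:
      # The timeout bounds the underlying GetStatus RPC, since fail_fast is
      # not always respected in OSS builds.
      context.context().check_collective_ops_peer_health(
          task, timeout_in_ms=timeout_in_ms)
      return True  # The peer responded; consider it healthy.
    except errors.DeadlineExceededError:
      # The RPC timed out; retry a few times before declaring the peer down.
      if attempt == retries:
        return False
    except (errors.UnavailableError, errors.FailedPreconditionError):
      # The peer is unreachable or has restarted; a restarted peer counts as a
      # different incarnation and fails the check.
      return False


# Example (assuming a two-worker cluster is configured):
# healthy = probe_peer("/job:worker/replica:0/task:1", timeout_in_ms=1000)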

View File

@@ -561,15 +561,15 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
   collective_executor_handle->get()->StartAbort(status->status);
 }
 
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status) {
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status) {
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   auto collective_executor_handle = context->GetCollectiveExecutorHandle();
   tensorflow::Notification done;
   collective_executor_handle->get()->remote_access()->CheckPeerHealth(
-      task, [&done, status](const Status& s) {
+      task, timeout_in_ms, [&done, status](const Status& s) {
         status->status = s;
         done.Notify();
       });

View File

@@ -241,9 +241,9 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
 // Checks the health of collective ops peers. Explicit health check is needed in
 // multi worker collective ops to detect failures in the cluster. If a peer is
 // down, collective ops may hang.
-TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
-                                                            const char* task,
-                                                            TF_Status* status);
+TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
+    TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
+    TF_Status* status);
 
 // Information about the shape of a Tensor and its type.
 struct TF_ShapeAndType {

View File

@@ -107,6 +107,7 @@ void CollectiveRemoteAccessLocal::PostToPeer(
 }
 
 void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
+                                                  int64 timeout_in_ms,
                                                   const StatusCallback& done) {
   // Assume local devices are always healthy.
   done(errors::Internal(

View File

@@ -55,7 +55,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess {
                   CancellationManager* cancellation_manager,
                   const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }

View File

@@ -156,10 +156,11 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
 TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) {
   Status status;
   Notification done;
-  rma_->CheckPeerHealth(kTaskName, [&status, &done](const Status& s) {
-    status = s;
-    done.Notify();
-  });
+  rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0,
+                        [&status, &done](const Status& s) {
+                          status = s;
+                          done.Notify();
+                        });
   done.WaitForNotification();
   EXPECT_TRUE(errors::IsInternal(status));
 }

View File

@@ -531,6 +531,7 @@ cc_library(
     srcs = ["collective_rma_distributed.cc"],
     hdrs = ["collective_rma_distributed.h"],
     deps = [
+        ":call_options",
         ":cancellable_call",
         ":request_id",
         ":worker_cache",

View File

@@ -53,7 +53,7 @@ class FakeWorker : public TestWorkerInterface {
              CollectiveParamResolverDistributed* cpres)
       : name_(name), device_mgr_(dev_mgr), param_resolver_(cpres) {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     std::vector<DeviceAttributes> dev_attr;

View File

@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -179,7 +180,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
-    const string& peer_task, const StatusCallback& done) {
+    const string& peer_task, int64 timeout_in_ms, const StatusCallback& done) {
   if (peer_task == task_name_) {
     // Fast path if the peer is the worker itself.
     done(Status::OK());
@@ -196,13 +197,16 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
                             "valid form is /job:xxx/replica:0/task:N"));
     return;
   }
+  auto opts = new CallOptions();
+  opts->SetTimeout(timeout_in_ms);
   auto req = new GetStatusRequest();
   auto resp = new GetStatusResponse();
-  // We're not using Cancellable call because GetStatusAsync doesn't support
-  // cancellation yet.
+  // Note that fail_fast is not always respected, so we set a timeout as well.
+  // We're not using CancellableCall since check health shouldn't need to be
+  // cancelled.
   wi->GetStatusAsync(
-      req, resp, /*fail_fast*/ true,
-      [this, req, resp, wi, peer_task, done](Status s) {
+      opts, req, resp, /*fail_fast*/ true,
+      [this, opts, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
           s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
@@ -227,6 +231,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
           // first collective.
           s = Status::OK();
         }
+        delete opts;
         delete req;
         delete resp;
         worker_cache_->ReleaseWorker(peer_task, wi);

View File

@@ -45,7 +45,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
                     CancellationManager* cancellation_manager,
                     const StatusCallback& done) override;
 
-  void CheckPeerHealth(const string& peer_task,
+  void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                        const StatusCallback& done) override;
 
   void StartAbort(const Status& s) override;

View File

@@ -74,7 +74,7 @@ class FakeWorker : public TestWorkerInterface {
   // worker is supposed to have.
   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     if (is_failed_) {
@@ -459,7 +459,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -472,7 +472,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -488,7 +488,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -505,7 +505,7 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
@@ -520,7 +520,7 @@ TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
-      "/job:worker/replica:0/task:1",
+      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
       [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();

View File

@@ -143,7 +143,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
       }
     }
   };
-  wi->GetStatusAsync(&call->req, &call->resp, /*fail_fast=*/false, cb);
+  wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
+                     /*fail_fast=*/false, cb);
 }
 
 }  // namespace tensorflow

View File

@@ -94,6 +94,7 @@ cc_library(
         "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/distributed_runtime:call_options",
        "//tensorflow/core/distributed_runtime:tensor_coding",
        "//tensorflow/core/distributed_runtime:worker_cache_logger",
        "//tensorflow/core/distributed_runtime:worker_interface",

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "grpcpp/generic/generic_stub.h"
 #include "grpcpp/grpcpp.h"
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -70,10 +71,10 @@ class GrpcRemoteWorker : public WorkerInterface {
   ~GrpcRemoteWorker() override {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
-    IssueRequest(request, response, getstatus_, std::move(done), nullptr,
+    IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                  fail_fast);
   }

View File

@@ -84,7 +84,8 @@ class RPCState : public GrpcClientCQTag {
                 return false;
               }
             }(),
-            /*timeout_in_ms=*/0, max_retries, target) {
+            (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
+            target) {
   }
 
   template <typename Request>

View File

@@ -30,7 +30,7 @@ namespace tensorflow {
 // testing.
 class TestWorkerInterface : public WorkerInterface {
  public:
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override {
     done(errors::Unimplemented("GetStatusAsync"));

View File

@@ -35,7 +35,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
   StatusGroup::ConfigureLogHistory();
 }
 
-void Worker::GetStatusAsync(const GetStatusRequest* request,
+void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
   const DeviceMgr* dm = env_->device_mgr;

View File

@@ -45,7 +45,7 @@ class Worker : public WorkerInterface {
   Worker(WorkerEnv* env);
   virtual ~Worker() {}
 
-  void GetStatusAsync(const GetStatusRequest* request,
+  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                       GetStatusResponse* response, bool fail_fast,
                       StatusCallback done) override;

View File

@@ -36,7 +36,8 @@ class TensorResponse;
 // Interface for talking with the TensorFlow Worker service.
 class WorkerInterface {
  public:
-  virtual void GetStatusAsync(const GetStatusRequest* request,
+  virtual void GetStatusAsync(CallOptions* opts,
+                              const GetStatusRequest* request,
                               GetStatusResponse* response, bool fail_fast,
                               StatusCallback done) = 0;
@@ -132,7 +133,7 @@ class WorkerInterface {
                    GetStatusResponse* response) {
     Status ret;
     Notification n;
-    GetStatusAsync(request, response, /*fail_fast=*/true,
+    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                    [&ret, &n](const Status& s) {
                      ret = s;
                      n.Notify();

View File

@@ -283,7 +283,7 @@ class CollectiveRemoteAccess {
   // Checks the health of a collective peer. It probes the peer to see if it is
   // alive. Note that if a peer has restarted, it's considered a different one,
   // so CheckPeerHealth fails.
-  virtual void CheckPeerHealth(const string& peer_task,
+  virtual void CheckPeerHealth(const string& peer_task, int64 timeout_in_ms,
                                const StatusCallback& done) = 0;
 
   virtual BufRendezvous* buf_rendezvous() = 0;

View File

@@ -309,6 +309,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
   _check_health_initial_timeout = 0
   # Times to retry before considering the peer is down.
   _check_health_retry_limit = 3
+  # Timeout in seconds for each check health attempt.
+  _check_health_timeout = 10
 
   def __init__(self, container_strategy, cluster_resolver,
                communication_options):
@@ -780,12 +782,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       while True:
         attempts += 1
         try:
-          context.context().check_collective_ops_peer_health(peer)
+          context.context().check_collective_ops_peer_health(
+              peer, timeout_in_ms=self._check_health_timeout * 1000)
           # If check_collective_ops_peer_health doesn't raise an Exception,
           # the peer is healthy.
           break
-        except (errors.UnavailableError,
-                errors.FailedPreconditionError) as e:
+        except (errors.UnavailableError, errors.FailedPreconditionError,
+                errors.DeadlineExceededError) as e:
           # TODO(b/151232436): Always raise UnavailableError when a peer
           # fails. Now there could be many kinds of errors:
           # - Unavailable: when the peer is not reachable, e.g. it's down.
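
As a usage note, not part of the diff: the new _check_health_timeout attribute can be overridden before the strategy is constructed, just as the test below sets it to 1 second for OSS. A minimal sketch, assuming the collective_all_reduce_strategy module is imported as mwms_lib as in that test:

from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib

# Illustrative only: shorten the per-probe timeout from the 10-second default
# added above to 5 seconds before creating the strategy (private knob, subject
# to change).
mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 5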

View File

@@ -36,7 +36,6 @@ cuda_py_test(
     shard_count = 2,
     tags = [
         "multi_and_single_gpu",
-        "no_oss",  # TODO(b/170838851): UnavailableError: Connection reset by peer
     ],
     deps = [
         "//tensorflow:tensorflow_py",

View File

@@ -37,6 +37,8 @@ from tensorflow.python.eager import test
 mwms_lib.CollectiveAllReduceExtended._enable_check_health = True
 mwms_lib.CollectiveAllReduceExtended._check_health_interval = 3
 mwms_lib.CollectiveAllReduceExtended._check_health_initial_timeout = 0
+# This is needed for OSS, which issues all RPCs with fail_fast=false by default.
+mwms_lib.CollectiveAllReduceExtended._check_health_timeout = 1
 
 
 def get_attempt(strategy, attempts):

View File

@@ -748,7 +748,7 @@ class Context(object):
     self.ensure_initialized()
     pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
 
-  def check_collective_ops_peer_health(self, task):
+  def check_collective_ops_peer_health(self, task, timeout_in_ms):
     """Check collective peer health.
 
     This probes each task to see if they're still alive. Note that restarted
@@ -758,6 +758,7 @@ class Context(object):
     Args:
       task: a task string, must be in the format of /job:xxx/replica:0/task:N.
+      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
 
     Raises:
       tf.errors.UnavailableError: when a peer is down.
@@ -766,7 +767,8 @@ class Context(object):
       tf.errors.InvalidArgumentError: when the task string is invalid.
     """
     self.ensure_initialized()
-    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task)
+    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
+                                                timeout_in_ms)
 
   @property
   def _handle(self):

View File

@@ -60,7 +60,8 @@ class CollectiveOpTest(test.TestCase):
             "/job:worker/replica:0/task:0",
            "/job:worker/replica:0/task:1",
        ]:
-          context.context().check_collective_ops_peer_health(task)
+          context.context().check_collective_ops_peer_health(
+              task, timeout_in_ms=1000)
       except errors.UnavailableError:
         continue
       break
@@ -73,18 +74,16 @@
   def testCheckHealthPeerDown(self):
 
-    if multi_process_runner.is_oss():
-      self.skipTest("TODO(b/170838845): Failing in OSS")
-
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
       context.context().check_collective_ops_peer_health(
-          "/job:worker/replica:0/task:1",)
+          "/job:worker/replica:0/task:1", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
     mpr.start_single_process("worker", 0)
-    with self.assertRaises(errors.UnavailableError):
+    with self.assertRaises(
+        (errors.UnavailableError, errors.DeadlineExceededError)):
       mpr.join()
 
   def testCheckHealthPeerRestart(self):
@@ -112,7 +111,7 @@ class CollectiveOpTest(test.TestCase):
         time.sleep(1)
         try:
           context.context().check_collective_ops_peer_health(
-              "/job:worker/replica:0/task:0",)
+              "/job:worker/replica:0/task:0", timeout_in_ms=1000)
         except errors.UnavailableError:
           pass
         except errors.FailedPreconditionError:
@@ -129,7 +128,8 @@ class CollectiveOpTest(test.TestCase):
     def worker_fn():
       enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
-      context.context().check_collective_ops_peer_health("localhost:12345",)
+      context.context().check_collective_ops_peer_health(
+          "localhost:12345", timeout_in_ms=1000)
 
     cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
     mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)

View File

@@ -1054,11 +1054,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
         TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
       });
   m.def("TFE_CollectiveOpsCheckPeerHealth",
-        [](const py::handle& ctx, const char* task) {
+        [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
           tensorflow::Safe_TF_StatusPtr status =
               tensorflow::make_safe(TF_NewStatus());
           TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
-                                           task, status.get());
+                                           task, timeout_in_ms, status.get());
           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
         });
   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);