Return an Unavailable error when the requested server context does not exist, since this indicates that the remote server has restarted.

PiperOrigin-RevId: 338367526
Change-Id: If9a9a09fd09a55a3f4734518a90905bf0764b669
This commit is contained in:
Haoyu Zhang 2020-10-21 16:51:10 -07:00 committed by TensorFlower Gardener
parent d6831dff73
commit 6d4f1d5c09
7 changed files with 66 additions and 73 deletions

View File

@ -97,6 +97,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_interface",
"//tensorflow/core:gpu_runtime", "//tensorflow/core:gpu_runtime",
] + internal_tfrt_deps(), ] + internal_tfrt_deps(),
alwayslink = 1, alwayslink = 1,

View File

@ -70,6 +70,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_def_util.h"
@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#else // !defined(IS_MOBILE_PLATFORM) #else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
// TODO(yuefengz): support partially specified `worker_name`. tensorflow::GrpcServer* grpc_server =
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client; dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
status->status = context->GetClient(worker_name, &eager_client); if (grpc_server == nullptr) {
if (!status->status.ok()) { status->status =
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
return false;
}
tensorflow::WorkerInterface* wi =
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
if (wi == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Unable to find worker interface corresponding to task ", worker_name);
return false; return false;
} }
// Send a rpc request to the worker to check aliveness. tensorflow::GetStatusRequest request;
tensorflow::eager::KeepAliveRequest request; tensorflow::GetStatusResponse response;
request.set_context_id(context->GetContextId()); tensorflow::Status remote_status;
tensorflow::eager::KeepAliveResponse response;
tensorflow::Status keep_alive_status;
tensorflow::Notification done; tensorflow::Notification done;
eager_client->KeepAliveAsync( wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
&request, &response, [&remote_status, &done](const tensorflow::Status& s) {
[&keep_alive_status, &done](const tensorflow::Status& s) { remote_status = s;
keep_alive_status = s;
done.Notify(); done.Notify();
}); });
done.WaitForNotification(); done.WaitForNotification();
// We set OK status so the call does not raise any exceptions. Instead, the
// caller uses the return value to tell if the remote worker is alive.
status->status = tensorflow::Status::OK(); status->status = tensorflow::Status::OK();
// If `context_id` doesn't exist on the remote worker, an InvalidArgument if (remote_status.ok()) {
// error will return. But this still indicates that the remote worker is
// alive.
if (keep_alive_status.ok() ||
keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
return true; return true;
} else {
LOG(INFO) << "Remote worker " << worker_name
<< " is not alive: " << keep_alive_status.error_message();
return false;
} }
LOG(INFO) << "Remote worker " << worker_name
<< " is not alive: " << remote_status.error_message();
return false;
#endif // !IS_MOBILE_PLATFORM #endif // !IS_MOBILE_PLATFORM
} }

View File

@ -752,7 +752,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext(
auto iter = contexts_.find(context_id); auto iter = contexts_.find(context_id);
if (iter == contexts_.end()) { if (iter == contexts_.end()) {
*server_context = nullptr; *server_context = nullptr;
return errors::InvalidArgument(strings::Printf( return errors::Unavailable(strings::Printf(
"Unable to find a context_id matching the specified one " "Unable to find a context_id matching the specified one "
"(%llu). Perhaps the worker was restarted, or the context was GC'd?", "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
static_cast<unsigned long long>(context_id))); static_cast<unsigned long long>(context_id)));

View File

@ -1248,7 +1248,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
// Unable to handle the request since there is no eager context. // Unable to handle the request since there is no eager context.
Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
&remote_enqueue_response); &remote_enqueue_response);
EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); EXPECT_EQ(error::UNAVAILABLE, status.code());
EXPECT_TRUE(absl::StrContains( EXPECT_TRUE(absl::StrContains(
status.error_message(), status.error_message(),
"Unable to find a context_id matching the specified one")); "Unable to find a context_id matching the specified one"));
@ -1285,7 +1285,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
Status status = Status status =
eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_EQ(status.code(), error::UNAVAILABLE);
EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
status.error_message()); status.error_message());

View File

@ -1343,10 +1343,8 @@ def _extract_failed_ps_instances(err_msg):
def _is_ps_failure(error): def _is_ps_failure(error):
"""Whether the error is considered a parameter server failure.""" """Whether the error is considered a parameter server failure."""
if (_RPC_ERROR_FROM_PS in str(error) or return (isinstance(error, errors.UnavailableError) and
(isinstance(error, errors.InvalidArgumentError) and _RPC_ERROR_FROM_PS in str(error))
"/job:ps" in str(error))):
return True
def _is_worker_failure(error): def _is_worker_failure(error):
@ -1366,8 +1364,7 @@ def _is_worker_failure(error):
# failure. In that case, gRPC allows channel (which is different from a # failure. In that case, gRPC allows channel (which is different from a
# connection) to be reused for a replaced server listening to same address. # connection) to be reused for a replaced server listening to same address.
if isinstance(error, errors.InvalidArgumentError): if isinstance(error, errors.InvalidArgumentError):
if ("Unable to find a context_id" in str(error) or if ("unknown device" in str(error) or
"unknown device" in str(error) or
"Unable to find the relevant tensor remote_handle" in str(error)): "Unable to find the relevant tensor remote_handle" in str(error)):
# TODO(b/159961667): Fix "Unable to find the relevant tensor # TODO(b/159961667): Fix "Unable to find the relevant tensor
# remote_handle" part. # remote_handle" part.

View File

@ -234,21 +234,18 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
try: try:
self.thread_coord.join([run_thread]) self.thread_coord.join([run_thread])
except (errors.UnavailableError, errors.InvalidArgumentError) as e: except errors.UnavailableError as e:
logging.info("Got exception %r, error message is %s", e, e) logging.info("Got exception %r, error message is %s", e, e)
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
if isinstance(e, errors.UnavailableError):
self.assertTrue("failed to connect to all addresses" in str(e) or self.assertTrue("failed to connect to all addresses" in str(e) or
"Unable to find a context_id" in str(e) or
"Socket closed" in str(e) or "Socket closed" in str(e) or
"Connection reset by peer" in str(e) or "Connection reset by peer" in str(e) or
"Transport closed" in str(e)) "Transport closed" in str(e))
if isinstance(e, errors.InvalidArgumentError):
self.assertIn("Unable to find a context_id", str(e))
def testWorkerPreemptionErrorTypeWithPythonFunction(self): def testWorkerPreemptionErrorTypeWithPythonFunction(self):
def worker_train_fn(): def worker_train_fn():
@ -271,21 +268,18 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
try: try:
self.thread_coord.join([run_thread]) self.thread_coord.join([run_thread])
except (errors.UnavailableError, errors.InvalidArgumentError) as e: except errors.UnavailableError as e:
logging.info("Got exception %r, error message is %s", e, e) logging.info("Got exception %r, error message is %s", e, e)
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
if isinstance(e, errors.UnavailableError):
self.assertTrue("failed to connect to all addresses" in str(e) or self.assertTrue("failed to connect to all addresses" in str(e) or
"Unable to find a context_id" in str(e) or
"Socket closed" in str(e) or "Socket closed" in str(e) or
"Connection reset by peer" in str(e) or "Connection reset by peer" in str(e) or
"Transport closed" in str(e)) "Transport closed" in str(e))
if isinstance(e, errors.InvalidArgumentError):
self.assertIn("Unable to find a context_id", str(e))
def testPSPreemptionErrorType(self): def testPSPreemptionErrorType(self):
with ops.device("/job:ps/replica:0/task:0"): with ops.device("/job:ps/replica:0/task:0"):
@ -309,25 +303,23 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
run_thread = threading.Thread(target=run_fn) run_thread = threading.Thread(target=run_fn)
run_thread.start() run_thread.start()
time.sleep(1) # Let it run a couple steps. time.sleep(1) # Let it run a couple steps.
self._restart(5, "ps")
# Use a short restart delay to cover the case that RPC channel is reused
self._restart(1, "ps")
try: try:
self.thread_coord.join([run_thread]) self.thread_coord.join([run_thread])
except (errors.UnavailableError, errors.InvalidArgumentError, except (errors.UnavailableError, errors.AbortedError) as e:
errors.AbortedError) as e:
logging.info("Got exception %r, error message is %s", e, e) logging.info("Got exception %r, error message is %s", e, e)
self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except
self.assertIn(_RPC_ERROR_FROM_PS, str(e))
if isinstance(e, errors.UnavailableError): if isinstance(e, errors.UnavailableError):
self.assertTrue("failed to connect to all addresses" in str(e) or self.assertTrue("failed to connect to all addresses" in str(e) or
"Unable to find a context_id" in str(e) or
"Socket closed" in str(e) or "Socket closed" in str(e) or
"Connection reset by peer" in str(e) or "Connection reset by peer" in str(e) or
"Transport closed" in str(e)) "Transport closed" in str(e))
if isinstance(e, errors.InvalidArgumentError):
self.assertIn("Unable to find a context_id", str(e))
if isinstance(e, errors.AbortedError): if isinstance(e, errors.AbortedError):
self.assertIn("RecvTensor expects a different device incarnation", self.assertIn("RecvTensor expects a different device incarnation",
str(e)) str(e))

View File

@ -320,6 +320,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
t.start() t.start()
for _ in range(num_calls): for _ in range(num_calls):
@def_function.function @def_function.function
def worker_fn(i): def worker_fn(i):
return math_ops.matmul(i, i) return math_ops.matmul(i, i)
@ -389,10 +390,10 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
t1_results = [None] * num_calls t1_results = [None] * num_calls
t2_results = [None] * num_calls t2_results = [None] * num_calls
threads = [] threads = []
threads.append(threading.Thread(target=thread_fn, threads.append(
args=(self.device_t1, t1_results))) threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
threads.append(threading.Thread(target=thread_fn, threads.append(
args=(self.device_t2, t2_results))) threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
threads.append(threading.Thread(target=update_server_def_fn)) threads.append(threading.Thread(target=update_server_def_fn))
for t in threads: for t in threads:
t.start() t.start()
@ -535,6 +536,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
with ops.device(self.device_t2): with ops.device(self.device_t2):
add = mul + i add = mul + i
return add - i return add - i
worker_fn.get_concrete_function(x1) worker_fn.get_concrete_function(x1)
num_calls = 10 num_calls = 10
@ -551,13 +553,13 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
with self._coord.stop_on_exception(): with self._coord.stop_on_exception():
for i in range(num_calls): for i in range(num_calls):
context.update_server_def( context.update_server_def(
server_def=(self.server_def_s1_s2_s3 server_def=(self.server_def_s1_s2_s3 if i %
if i % 2 == 0 else self.server_def_s1_s2)) 2 == 0 else self.server_def_s1_s2))
results = [None] * num_calls results = [None] * num_calls
threads = [] threads = []
threads.append(threading.Thread(target=thread_fn, threads.append(
args=(self.device_t1, results))) threading.Thread(target=thread_fn, args=(self.device_t1, results)))
threads.append(threading.Thread(target=update_server_def_fn)) threads.append(threading.Thread(target=update_server_def_fn))
for t in threads: for t in threads:
t.start() t.start()
@ -630,9 +632,8 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0")) self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1")) self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))
with self.assertRaisesRegex( with self.assertRaisesRegex(errors.InvalidArgumentError,
errors.InvalidArgumentError, "Unable to find worker interface"):
"Client for target /job:remote_device/replica:0/task:10 not found."):
context.check_alive("/job:remote_device/replica:0/task:10") context.check_alive("/job:remote_device/replica:0/task:10")