Use Unavailable error for a non-existent server context, since this indicates the remote server has restarted.

PiperOrigin-RevId: 338367526
Change-Id: If9a9a09fd09a55a3f4734518a90905bf0764b669
Author: Haoyu Zhang, 2020-10-21 16:51:10 -07:00 (committed by TensorFlower Gardener)
Parent: d6831dff73
Commit: 6d4f1d5c09
7 changed files with 66 additions and 73 deletions
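The user-visible effect of this change: when a remote worker restarts and its eager context id no longer exists, the error that reaches the client is now errors.UnavailableError (carrying the "Unable to find a context_id" message) rather than errors.InvalidArgumentError, so it is classified alongside other connectivity failures. A minimal sketch of how a caller might rely on that, assuming hypothetical run_step and rebuild_remote_state helpers that are not part of this commit:

from tensorflow.python.framework import errors


def run_with_restart_recovery(run_step, rebuild_remote_state, max_attempts=3):
  """Illustrative retry wrapper keyed on the new error type."""
  for attempt in range(max_attempts):
    try:
      return run_step()
    except errors.UnavailableError:
      # Covers gRPC connectivity failures and, after this change, the
      # "Unable to find a context_id" error from a restarted remote worker.
      if attempt + 1 == max_attempts:
        raise
      rebuild_remote_state()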


@@ -97,6 +97,7 @@ tf_cuda_library(
         "//tensorflow/core/distributed_runtime:remote_device",
         "//tensorflow/core/distributed_runtime:server_lib",
         "//tensorflow/core/distributed_runtime:worker_env",
+        "//tensorflow/core/distributed_runtime:worker_interface",
         "//tensorflow/core:gpu_runtime",
     ] + internal_tfrt_deps(),
     alwayslink = 1,


@@ -70,6 +70,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
 #endif  // !IS_MOBILE_PLATFORM
 #include "tensorflow/core/framework/node_def_util.h"
@@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
 #else   // !defined(IS_MOBILE_PLATFORM)
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   // TODO(yuefengz): support partially specified `worker_name`.
-  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
-  status->status = context->GetClient(worker_name, &eager_client);
-  if (!status->status.ok()) {
+  tensorflow::GrpcServer* grpc_server =
+      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
+  if (grpc_server == nullptr) {
+    status->status =
+        tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
     return false;
   }
+  tensorflow::WorkerInterface* wi =
+      grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
+  if (wi == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Unable to find worker interface corresponding to task ", worker_name);
+    return false;
+  }
-  // Send a rpc request to the worker to check aliveness.
-  tensorflow::eager::KeepAliveRequest request;
-  request.set_context_id(context->GetContextId());
-  tensorflow::eager::KeepAliveResponse response;
-  tensorflow::Status keep_alive_status;
+  tensorflow::GetStatusRequest request;
+  tensorflow::GetStatusResponse response;
+  tensorflow::Status remote_status;
   tensorflow::Notification done;
-  eager_client->KeepAliveAsync(
-      &request, &response,
-      [&keep_alive_status, &done](const tensorflow::Status& s) {
-        keep_alive_status = s;
-        done.Notify();
-      });
+  wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
+                     [&remote_status, &done](const tensorflow::Status& s) {
+                       remote_status = s;
+                       done.Notify();
+                     });
   done.WaitForNotification();
   // We set OK status so the call does not raise any exceptions. Instead, caller
   // users the return value to tell if the remote worker is alive.
   status->status = tensorflow::Status::OK();
-  // If `context_id` doesn't exist on the remote worker, an InvalidArgument
-  // error will return. But this still indicates that the remote worker is
-  // alive.
-  if (keep_alive_status.ok() ||
-      keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
+  if (remote_status.ok()) {
     return true;
-  } else {
-    LOG(INFO) << "Remote worker " << worker_name
-              << " is not alive: " << keep_alive_status.error_message();
-    return false;
   }
+  LOG(INFO) << "Remote worker " << worker_name
+            << " is not alive: " << remote_status.error_message();
+  return false;
 #endif  // !IS_MOBILE_PLATFORM
 }
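The rewritten TFE_ContextCheckAlive above is what backs context.check_alive() in Python (exercised by the DynamicClusterTest hunk at the end of this diff): it now probes the worker cache and sends a GetStatus RPC with fail_fast set. A short usage sketch; the job and task names are illustrative:

from tensorflow.python.eager import context
from tensorflow.python.framework import errors

# True if the remote task answers GetStatus, False if the RPC fails.
alive = context.check_alive("/job:worker/replica:0/task:0")

try:
  context.check_alive("/job:worker/replica:0/task:99")  # task not in cluster
except errors.InvalidArgumentError as e:
  # With this change the lookup failure reads
  # "Unable to find worker interface corresponding to task ...".
  print(e)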


@@ -752,7 +752,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext(
   auto iter = contexts_.find(context_id);
   if (iter == contexts_.end()) {
     *server_context = nullptr;
-    return errors::InvalidArgument(strings::Printf(
+    return errors::Unavailable(strings::Printf(
         "Unable to find a context_id matching the specified one "
         "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
         static_cast<unsigned long long>(context_id)));


@@ -1248,7 +1248,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
   // Unable to handle the request since there is no eager context.
   Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
                                              &remote_enqueue_response);
-  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
+  EXPECT_EQ(error::UNAVAILABLE, status.code());
   EXPECT_TRUE(absl::StrContains(
       status.error_message(),
       "Unable to find a context_id matching the specified one"));
@@ -1285,7 +1285,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
   Status status =
       eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
-  EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
+  EXPECT_EQ(status.code(), error::UNAVAILABLE);
   EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                       status.error_message());


@@ -1343,10 +1343,8 @@ def _extract_failed_ps_instances(err_msg):
 def _is_ps_failure(error):
   """Whether the error is considered a parameter server failure."""
-  if (_RPC_ERROR_FROM_PS in str(error) or
-      (isinstance(error, errors.InvalidArgumentError) and
-       "/job:ps" in str(error))):
-    return True
+  return (isinstance(error, errors.UnavailableError) and
+          _RPC_ERROR_FROM_PS in str(error))

 def _is_worker_failure(error):
@@ -1366,8 +1364,7 @@ def _is_worker_failure(error):
   # failure. In that case, gRPC allows channel (which is different from a
   # connection) to be reused for a replaced server listening to same address.
   if isinstance(error, errors.InvalidArgumentError):
-    if ("Unable to find a context_id" in str(error) or
-        "unknown device" in str(error) or
+    if ("unknown device" in str(error) or
         "Unable to find the relevant tensor remote_handle" in str(error)):
       # TODO(b/159961667): Fix "Unable to find the relevant tensor
       # remote_handle" part.
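For context, a hedged sketch of what the narrowed classification means in practice: only errors.UnavailableError now counts as a task failure, and the message decides whether it came from a worker (recoverable) or a parameter server (surfaced to the user). The handler below is illustrative, not the coordinator's actual retry machinery; _is_worker_failure and _is_ps_failure are the predicates from the hunk above.

def handle_task_error(error, redispatch_fn):
  """Illustrative dispatcher over the predicates defined above."""
  if _is_worker_failure(error):
    # Worker preemption or restart is recoverable: re-dispatch the closure.
    return redispatch_fn()
  if _is_ps_failure(error):
    # A parameter server failure loses program state; surface it to the user.
    raise error
  # Anything else (e.g. an InvalidArgumentError from user code) is a real error.
  raise error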


@@ -234,20 +234,17 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     try:
       self.thread_coord.join([run_thread])
-    except (errors.UnavailableError, errors.InvalidArgumentError) as e:
+    except errors.UnavailableError as e:
       logging.info("Got exception %r, error message is %s", e, e)
-      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))
+      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
       self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
-      if isinstance(e, errors.UnavailableError):
-        self.assertTrue("failed to connect to all addresses" in str(e) or
-                        "Socket closed" in str(e) or
-                        "Connection reset by peer" in str(e) or
-                        "Transport closed" in str(e))
-      if isinstance(e, errors.InvalidArgumentError):
-        self.assertIn("Unable to find a context_id", str(e))
+      self.assertTrue("failed to connect to all addresses" in str(e) or
+                      "Unable to find a context_id" in str(e) or
+                      "Socket closed" in str(e) or
+                      "Connection reset by peer" in str(e) or
+                      "Transport closed" in str(e))

   def testWorkerPreemptionErrorTypeWithPythonFunction(self):
@@ -271,20 +268,17 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     try:
       self.thread_coord.join([run_thread])
-    except (errors.UnavailableError, errors.InvalidArgumentError) as e:
+    except errors.UnavailableError as e:
       logging.info("Got exception %r, error message is %s", e, e)
-      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))
+      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
-      if isinstance(e, errors.UnavailableError):
-        self.assertTrue("failed to connect to all addresses" in str(e) or
-                        "Socket closed" in str(e) or
-                        "Connection reset by peer" in str(e) or
-                        "Transport closed" in str(e))
-      if isinstance(e, errors.InvalidArgumentError):
-        self.assertIn("Unable to find a context_id", str(e))
+      self.assertTrue("failed to connect to all addresses" in str(e) or
+                      "Unable to find a context_id" in str(e) or
+                      "Socket closed" in str(e) or
+                      "Connection reset by peer" in str(e) or
+                      "Transport closed" in str(e))

   def testPSPreemptionErrorType(self):
@@ -309,25 +303,23 @@ class FaultToleranceTest(test.TestCase):  # pylint: disable=missing-docstring
     run_thread = threading.Thread(target=run_fn)
     run_thread.start()
     time.sleep(1)  # Let it run a couple steps.
-    self._restart(5, "ps")
+    # Use a short restart delay to cover the case that RPC channel is reused
+    self._restart(1, "ps")

     try:
       self.thread_coord.join([run_thread])
-    except (errors.UnavailableError, errors.InvalidArgumentError,
-            errors.AbortedError) as e:
+    except (errors.UnavailableError, errors.AbortedError) as e:
       logging.info("Got exception %r, error message is %s", e, e)
-      self.assertIn(_RPC_ERROR_FROM_PS, str(e))
+      self.assertIn(_RPC_ERROR_FROM_PS, str(e))  # pylint: disable=g-assert-in-except
       if isinstance(e, errors.UnavailableError):
         self.assertTrue("failed to connect to all addresses" in str(e) or
                         "Unable to find a context_id" in str(e) or
                         "Socket closed" in str(e) or
                         "Connection reset by peer" in str(e) or
                         "Transport closed" in str(e))
-      if isinstance(e, errors.InvalidArgumentError):
-        self.assertIn("Unable to find a context_id", str(e))
       if isinstance(e, errors.AbortedError):
         self.assertIn("RecvTensor expects a different device incarnation",
                       str(e))


@@ -320,6 +320,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
       t.start()

     for _ in range(num_calls):
+      @def_function.function
       def worker_fn(i):
         return math_ops.matmul(i, i)
@@ -389,10 +390,10 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
     t1_results = [None] * num_calls
     t2_results = [None] * num_calls
     threads = []
-    threads.append(threading.Thread(target=thread_fn,
-                                    args=(self.device_t1, t1_results)))
-    threads.append(threading.Thread(target=thread_fn,
-                                    args=(self.device_t2, t2_results)))
+    threads.append(
+        threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
+    threads.append(
+        threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
     threads.append(threading.Thread(target=update_server_def_fn))
     for t in threads:
       t.start()
@@ -535,6 +536,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
         with ops.device(self.device_t2):
           add = mul + i
         return add - i

+    worker_fn.get_concrete_function(x1)

     num_calls = 10
@@ -551,13 +553,13 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
       with self._coord.stop_on_exception():
         for i in range(num_calls):
           context.update_server_def(
-              server_def=(self.server_def_s1_s2_s3
-                          if i % 2 == 0 else self.server_def_s1_s2))
+              server_def=(self.server_def_s1_s2_s3 if i %
+                          2 == 0 else self.server_def_s1_s2))

     results = [None] * num_calls
     threads = []
-    threads.append(threading.Thread(target=thread_fn,
-                                    args=(self.device_t1, results)))
+    threads.append(
+        threading.Thread(target=thread_fn, args=(self.device_t1, results)))
     threads.append(threading.Thread(target=update_server_def_fn))
     for t in threads:
       t.start()
@@ -630,9 +632,8 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
     self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
     self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))

-    with self.assertRaisesRegex(
-        errors.InvalidArgumentError,
-        "Client for target /job:remote_device/replica:0/task:10 not found."):
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Unable to find worker interface"):
       context.check_alive("/job:remote_device/replica:0/task:10")