Use Unavailable error for non-existing server context since this indicates the remote server has restarted.
PiperOrigin-RevId: 338367526 Change-Id: If9a9a09fd09a55a3f4734518a90905bf0764b669
This commit is contained in:
parent
d6831dff73
commit
6d4f1d5c09
@ -97,6 +97,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
] + internal_tfrt_deps(),
|
||||
alwayslink = 1,
|
||||
|
@ -70,6 +70,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
// TODO(yuefengz): support partially specified `worker_name`.
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
status->status = context->GetClient(worker_name, &eager_client);
|
||||
if (!status->status.ok()) {
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
status->status =
|
||||
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
|
||||
return false;
|
||||
}
|
||||
tensorflow::WorkerInterface* wi =
|
||||
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
|
||||
if (wi == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unable to find worker interface corresponding to task ", worker_name);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Send a rpc request to the worker to check aliveness.
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context->GetContextId());
|
||||
tensorflow::eager::KeepAliveResponse response;
|
||||
|
||||
tensorflow::Status keep_alive_status;
|
||||
tensorflow::GetStatusRequest request;
|
||||
tensorflow::GetStatusResponse response;
|
||||
tensorflow::Status remote_status;
|
||||
tensorflow::Notification done;
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &response,
|
||||
[&keep_alive_status, &done](const tensorflow::Status& s) {
|
||||
keep_alive_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
|
||||
[&remote_status, &done](const tensorflow::Status& s) {
|
||||
remote_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
|
||||
// We set OK status so the call does not raise any exceptions. Instead, caller
|
||||
// users the return value to tell if the remote worker is alive.
|
||||
status->status = tensorflow::Status::OK();
|
||||
|
||||
// If `context_id` doesn't exist on the remote worker, an InvalidArgument
|
||||
// error will return. But this still indicates that the remote worker is
|
||||
// alive.
|
||||
if (keep_alive_status.ok() ||
|
||||
keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
||||
if (remote_status.ok()) {
|
||||
return true;
|
||||
} else {
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << keep_alive_status.error_message();
|
||||
return false;
|
||||
}
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << remote_status.error_message();
|
||||
return false;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
|
@ -752,7 +752,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext(
|
||||
auto iter = contexts_.find(context_id);
|
||||
if (iter == contexts_.end()) {
|
||||
*server_context = nullptr;
|
||||
return errors::InvalidArgument(strings::Printf(
|
||||
return errors::Unavailable(strings::Printf(
|
||||
"Unable to find a context_id matching the specified one "
|
||||
"(%llu). Perhaps the worker was restarted, or the context was GC'd?",
|
||||
static_cast<unsigned long long>(context_id)));
|
||||
|
@ -1248,7 +1248,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
|
||||
// Unable to handle the request since there is no eager context.
|
||||
Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
|
||||
&remote_enqueue_response);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
|
||||
EXPECT_EQ(error::UNAVAILABLE, status.code());
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
status.error_message(),
|
||||
"Unable to find a context_id matching the specified one"));
|
||||
@ -1285,7 +1285,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
|
||||
Status status =
|
||||
eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
|
||||
|
||||
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(status.code(), error::UNAVAILABLE);
|
||||
EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
|
||||
status.error_message());
|
||||
|
||||
|
@ -1343,10 +1343,8 @@ def _extract_failed_ps_instances(err_msg):
|
||||
|
||||
def _is_ps_failure(error):
|
||||
"""Whether the error is considered a parameter server failure."""
|
||||
if (_RPC_ERROR_FROM_PS in str(error) or
|
||||
(isinstance(error, errors.InvalidArgumentError) and
|
||||
"/job:ps" in str(error))):
|
||||
return True
|
||||
return (isinstance(error, errors.UnavailableError) and
|
||||
_RPC_ERROR_FROM_PS in str(error))
|
||||
|
||||
|
||||
def _is_worker_failure(error):
|
||||
@ -1366,8 +1364,7 @@ def _is_worker_failure(error):
|
||||
# failure. In that case, gRPC allows channel (which is different from a
|
||||
# connection) to be reused for a replaced server listening to same address.
|
||||
if isinstance(error, errors.InvalidArgumentError):
|
||||
if ("Unable to find a context_id" in str(error) or
|
||||
"unknown device" in str(error) or
|
||||
if ("unknown device" in str(error) or
|
||||
"Unable to find the relevant tensor remote_handle" in str(error)):
|
||||
# TODO(b/159961667): Fix "Unable to find the relevant tensor
|
||||
# remote_handle" part.
|
||||
|
@ -234,20 +234,17 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
|
||||
try:
|
||||
self.thread_coord.join([run_thread])
|
||||
except (errors.UnavailableError, errors.InvalidArgumentError) as e:
|
||||
except errors.UnavailableError as e:
|
||||
logging.info("Got exception %r, error message is %s", e, e)
|
||||
|
||||
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))
|
||||
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except
|
||||
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
|
||||
|
||||
if isinstance(e, errors.UnavailableError):
|
||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||
"Socket closed" in str(e) or
|
||||
"Connection reset by peer" in str(e) or
|
||||
"Transport closed" in str(e))
|
||||
|
||||
if isinstance(e, errors.InvalidArgumentError):
|
||||
self.assertIn("Unable to find a context_id", str(e))
|
||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||
"Unable to find a context_id" in str(e) or
|
||||
"Socket closed" in str(e) or
|
||||
"Connection reset by peer" in str(e) or
|
||||
"Transport closed" in str(e))
|
||||
|
||||
def testWorkerPreemptionErrorTypeWithPythonFunction(self):
|
||||
|
||||
@ -271,20 +268,17 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
|
||||
try:
|
||||
self.thread_coord.join([run_thread])
|
||||
except (errors.UnavailableError, errors.InvalidArgumentError) as e:
|
||||
except errors.UnavailableError as e:
|
||||
logging.info("Got exception %r, error message is %s", e, e)
|
||||
|
||||
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))
|
||||
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except
|
||||
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
|
||||
|
||||
if isinstance(e, errors.UnavailableError):
|
||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||
"Socket closed" in str(e) or
|
||||
"Connection reset by peer" in str(e) or
|
||||
"Transport closed" in str(e))
|
||||
|
||||
if isinstance(e, errors.InvalidArgumentError):
|
||||
self.assertIn("Unable to find a context_id", str(e))
|
||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||
"Unable to find a context_id" in str(e) or
|
||||
"Socket closed" in str(e) or
|
||||
"Connection reset by peer" in str(e) or
|
||||
"Transport closed" in str(e))
|
||||
|
||||
def testPSPreemptionErrorType(self):
|
||||
|
||||
@ -309,25 +303,23 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
||||
run_thread = threading.Thread(target=run_fn)
|
||||
run_thread.start()
|
||||
time.sleep(1) # Let it run a couple steps.
|
||||
self._restart(5, "ps")
|
||||
|
||||
# Use a short restart delay to cover the case that RPC channel is reused
|
||||
self._restart(1, "ps")
|
||||
|
||||
try:
|
||||
self.thread_coord.join([run_thread])
|
||||
except (errors.UnavailableError, errors.InvalidArgumentError,
|
||||
errors.AbortedError) as e:
|
||||
except (errors.UnavailableError, errors.AbortedError) as e:
|
||||
logging.info("Got exception %r, error message is %s", e, e)
|
||||
|
||||
self.assertIn(_RPC_ERROR_FROM_PS, str(e))
|
||||
self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except
|
||||
|
||||
if isinstance(e, errors.UnavailableError):
|
||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||
"Unable to find a context_id" in str(e) or
|
||||
"Socket closed" in str(e) or
|
||||
"Connection reset by peer" in str(e) or
|
||||
"Transport closed" in str(e))
|
||||
|
||||
if isinstance(e, errors.InvalidArgumentError):
|
||||
self.assertIn("Unable to find a context_id", str(e))
|
||||
|
||||
if isinstance(e, errors.AbortedError):
|
||||
self.assertIn("RecvTensor expects a different device incarnation",
|
||||
str(e))
|
||||
|
@ -320,6 +320,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
||||
t.start()
|
||||
|
||||
for _ in range(num_calls):
|
||||
|
||||
@def_function.function
|
||||
def worker_fn(i):
|
||||
return math_ops.matmul(i, i)
|
||||
@ -389,10 +390,10 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
||||
t1_results = [None] * num_calls
|
||||
t2_results = [None] * num_calls
|
||||
threads = []
|
||||
threads.append(threading.Thread(target=thread_fn,
|
||||
args=(self.device_t1, t1_results)))
|
||||
threads.append(threading.Thread(target=thread_fn,
|
||||
args=(self.device_t2, t2_results)))
|
||||
threads.append(
|
||||
threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
|
||||
threads.append(
|
||||
threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
|
||||
threads.append(threading.Thread(target=update_server_def_fn))
|
||||
for t in threads:
|
||||
t.start()
|
||||
@ -535,6 +536,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.device(self.device_t2):
|
||||
add = mul + i
|
||||
return add - i
|
||||
|
||||
worker_fn.get_concrete_function(x1)
|
||||
|
||||
num_calls = 10
|
||||
@ -551,13 +553,13 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
||||
with self._coord.stop_on_exception():
|
||||
for i in range(num_calls):
|
||||
context.update_server_def(
|
||||
server_def=(self.server_def_s1_s2_s3
|
||||
if i % 2 == 0 else self.server_def_s1_s2))
|
||||
server_def=(self.server_def_s1_s2_s3 if i %
|
||||
2 == 0 else self.server_def_s1_s2))
|
||||
|
||||
results = [None] * num_calls
|
||||
threads = []
|
||||
threads.append(threading.Thread(target=thread_fn,
|
||||
args=(self.device_t1, results)))
|
||||
threads.append(
|
||||
threading.Thread(target=thread_fn, args=(self.device_t1, results)))
|
||||
threads.append(threading.Thread(target=update_server_def_fn))
|
||||
for t in threads:
|
||||
t.start()
|
||||
@ -630,9 +632,8 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
|
||||
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Client for target /job:remote_device/replica:0/task:10 not found."):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Unable to find worker interface"):
|
||||
context.check_alive("/job:remote_device/replica:0/task:10")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user