Use Unavailable error for non-existing server context since this indicates the remote server has restarted.
PiperOrigin-RevId: 338367526 Change-Id: If9a9a09fd09a55a3f4734518a90905bf0764b669
This commit is contained in:
parent
d6831dff73
commit
6d4f1d5c09
@ -97,6 +97,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core/distributed_runtime:remote_device",
|
"//tensorflow/core/distributed_runtime:remote_device",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
|
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||||
"//tensorflow/core:gpu_runtime",
|
"//tensorflow/core:gpu_runtime",
|
||||||
] + internal_tfrt_deps(),
|
] + internal_tfrt_deps(),
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -70,6 +70,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
#else // !defined(IS_MOBILE_PLATFORM)
|
#else // !defined(IS_MOBILE_PLATFORM)
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::EagerContext* context =
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||||
// TODO(yuefengz): support partially specified `worker_name`.
|
tensorflow::GrpcServer* grpc_server =
|
||||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||||
status->status = context->GetClient(worker_name, &eager_client);
|
if (grpc_server == nullptr) {
|
||||||
if (!status->status.ok()) {
|
status->status =
|
||||||
|
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tensorflow::WorkerInterface* wi =
|
||||||
|
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
|
||||||
|
if (wi == nullptr) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Unable to find worker interface corresponding to task ", worker_name);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a rpc request to the worker to check aliveness.
|
tensorflow::GetStatusRequest request;
|
||||||
tensorflow::eager::KeepAliveRequest request;
|
tensorflow::GetStatusResponse response;
|
||||||
request.set_context_id(context->GetContextId());
|
tensorflow::Status remote_status;
|
||||||
tensorflow::eager::KeepAliveResponse response;
|
|
||||||
|
|
||||||
tensorflow::Status keep_alive_status;
|
|
||||||
tensorflow::Notification done;
|
tensorflow::Notification done;
|
||||||
eager_client->KeepAliveAsync(
|
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
|
||||||
&request, &response,
|
[&remote_status, &done](const tensorflow::Status& s) {
|
||||||
[&keep_alive_status, &done](const tensorflow::Status& s) {
|
remote_status = s;
|
||||||
keep_alive_status = s;
|
|
||||||
done.Notify();
|
done.Notify();
|
||||||
});
|
});
|
||||||
done.WaitForNotification();
|
done.WaitForNotification();
|
||||||
|
|
||||||
|
// We set OK status so the call does not raise any exceptions. Instead, caller
|
||||||
|
// users the return value to tell if the remote worker is alive.
|
||||||
status->status = tensorflow::Status::OK();
|
status->status = tensorflow::Status::OK();
|
||||||
|
|
||||||
// If `context_id` doesn't exist on the remote worker, an InvalidArgument
|
if (remote_status.ok()) {
|
||||||
// error will return. But this still indicates that the remote worker is
|
|
||||||
// alive.
|
|
||||||
if (keep_alive_status.ok() ||
|
|
||||||
keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
|
|
||||||
return true;
|
return true;
|
||||||
} else {
|
|
||||||
LOG(INFO) << "Remote worker " << worker_name
|
|
||||||
<< " is not alive: " << keep_alive_status.error_message();
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
LOG(INFO) << "Remote worker " << worker_name
|
||||||
|
<< " is not alive: " << remote_status.error_message();
|
||||||
|
return false;
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -752,7 +752,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext(
|
|||||||
auto iter = contexts_.find(context_id);
|
auto iter = contexts_.find(context_id);
|
||||||
if (iter == contexts_.end()) {
|
if (iter == contexts_.end()) {
|
||||||
*server_context = nullptr;
|
*server_context = nullptr;
|
||||||
return errors::InvalidArgument(strings::Printf(
|
return errors::Unavailable(strings::Printf(
|
||||||
"Unable to find a context_id matching the specified one "
|
"Unable to find a context_id matching the specified one "
|
||||||
"(%llu). Perhaps the worker was restarted, or the context was GC'd?",
|
"(%llu). Perhaps the worker was restarted, or the context was GC'd?",
|
||||||
static_cast<unsigned long long>(context_id)));
|
static_cast<unsigned long long>(context_id)));
|
||||||
|
@ -1248,7 +1248,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) {
|
|||||||
// Unable to handle the request since there is no eager context.
|
// Unable to handle the request since there is no eager context.
|
||||||
Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
|
Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request,
|
||||||
&remote_enqueue_response);
|
&remote_enqueue_response);
|
||||||
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
|
EXPECT_EQ(error::UNAVAILABLE, status.code());
|
||||||
EXPECT_TRUE(absl::StrContains(
|
EXPECT_TRUE(absl::StrContains(
|
||||||
status.error_message(),
|
status.error_message(),
|
||||||
"Unable to find a context_id matching the specified one"));
|
"Unable to find a context_id matching the specified one"));
|
||||||
@ -1285,7 +1285,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
|
|||||||
Status status =
|
Status status =
|
||||||
eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
|
eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
|
||||||
|
|
||||||
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
|
EXPECT_EQ(status.code(), error::UNAVAILABLE);
|
||||||
EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
|
EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
|
||||||
status.error_message());
|
status.error_message());
|
||||||
|
|
||||||
|
@ -1343,10 +1343,8 @@ def _extract_failed_ps_instances(err_msg):
|
|||||||
|
|
||||||
def _is_ps_failure(error):
|
def _is_ps_failure(error):
|
||||||
"""Whether the error is considered a parameter server failure."""
|
"""Whether the error is considered a parameter server failure."""
|
||||||
if (_RPC_ERROR_FROM_PS in str(error) or
|
return (isinstance(error, errors.UnavailableError) and
|
||||||
(isinstance(error, errors.InvalidArgumentError) and
|
_RPC_ERROR_FROM_PS in str(error))
|
||||||
"/job:ps" in str(error))):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _is_worker_failure(error):
|
def _is_worker_failure(error):
|
||||||
@ -1366,8 +1364,7 @@ def _is_worker_failure(error):
|
|||||||
# failure. In that case, gRPC allows channel (which is different from a
|
# failure. In that case, gRPC allows channel (which is different from a
|
||||||
# connection) to be reused for a replaced server listening to same address.
|
# connection) to be reused for a replaced server listening to same address.
|
||||||
if isinstance(error, errors.InvalidArgumentError):
|
if isinstance(error, errors.InvalidArgumentError):
|
||||||
if ("Unable to find a context_id" in str(error) or
|
if ("unknown device" in str(error) or
|
||||||
"unknown device" in str(error) or
|
|
||||||
"Unable to find the relevant tensor remote_handle" in str(error)):
|
"Unable to find the relevant tensor remote_handle" in str(error)):
|
||||||
# TODO(b/159961667): Fix "Unable to find the relevant tensor
|
# TODO(b/159961667): Fix "Unable to find the relevant tensor
|
||||||
# remote_handle" part.
|
# remote_handle" part.
|
||||||
|
@ -234,21 +234,18 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.thread_coord.join([run_thread])
|
self.thread_coord.join([run_thread])
|
||||||
except (errors.UnavailableError, errors.InvalidArgumentError) as e:
|
except errors.UnavailableError as e:
|
||||||
logging.info("Got exception %r, error message is %s", e, e)
|
logging.info("Got exception %r, error message is %s", e, e)
|
||||||
|
|
||||||
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))
|
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except
|
||||||
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
|
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
|
||||||
|
|
||||||
if isinstance(e, errors.UnavailableError):
|
|
||||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||||
|
"Unable to find a context_id" in str(e) or
|
||||||
"Socket closed" in str(e) or
|
"Socket closed" in str(e) or
|
||||||
"Connection reset by peer" in str(e) or
|
"Connection reset by peer" in str(e) or
|
||||||
"Transport closed" in str(e))
|
"Transport closed" in str(e))
|
||||||
|
|
||||||
if isinstance(e, errors.InvalidArgumentError):
|
|
||||||
self.assertIn("Unable to find a context_id", str(e))
|
|
||||||
|
|
||||||
def testWorkerPreemptionErrorTypeWithPythonFunction(self):
|
def testWorkerPreemptionErrorTypeWithPythonFunction(self):
|
||||||
|
|
||||||
def worker_train_fn():
|
def worker_train_fn():
|
||||||
@ -271,21 +268,18 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.thread_coord.join([run_thread])
|
self.thread_coord.join([run_thread])
|
||||||
except (errors.UnavailableError, errors.InvalidArgumentError) as e:
|
except errors.UnavailableError as e:
|
||||||
logging.info("Got exception %r, error message is %s", e, e)
|
logging.info("Got exception %r, error message is %s", e, e)
|
||||||
|
|
||||||
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))
|
self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except
|
||||||
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
|
self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))
|
||||||
|
|
||||||
if isinstance(e, errors.UnavailableError):
|
|
||||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||||
|
"Unable to find a context_id" in str(e) or
|
||||||
"Socket closed" in str(e) or
|
"Socket closed" in str(e) or
|
||||||
"Connection reset by peer" in str(e) or
|
"Connection reset by peer" in str(e) or
|
||||||
"Transport closed" in str(e))
|
"Transport closed" in str(e))
|
||||||
|
|
||||||
if isinstance(e, errors.InvalidArgumentError):
|
|
||||||
self.assertIn("Unable to find a context_id", str(e))
|
|
||||||
|
|
||||||
def testPSPreemptionErrorType(self):
|
def testPSPreemptionErrorType(self):
|
||||||
|
|
||||||
with ops.device("/job:ps/replica:0/task:0"):
|
with ops.device("/job:ps/replica:0/task:0"):
|
||||||
@ -309,25 +303,23 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring
|
|||||||
run_thread = threading.Thread(target=run_fn)
|
run_thread = threading.Thread(target=run_fn)
|
||||||
run_thread.start()
|
run_thread.start()
|
||||||
time.sleep(1) # Let it run a couple steps.
|
time.sleep(1) # Let it run a couple steps.
|
||||||
self._restart(5, "ps")
|
|
||||||
|
# Use a short restart delay to cover the case that RPC channel is reused
|
||||||
|
self._restart(1, "ps")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.thread_coord.join([run_thread])
|
self.thread_coord.join([run_thread])
|
||||||
except (errors.UnavailableError, errors.InvalidArgumentError,
|
except (errors.UnavailableError, errors.AbortedError) as e:
|
||||||
errors.AbortedError) as e:
|
|
||||||
logging.info("Got exception %r, error message is %s", e, e)
|
logging.info("Got exception %r, error message is %s", e, e)
|
||||||
|
self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except
|
||||||
self.assertIn(_RPC_ERROR_FROM_PS, str(e))
|
|
||||||
|
|
||||||
if isinstance(e, errors.UnavailableError):
|
if isinstance(e, errors.UnavailableError):
|
||||||
self.assertTrue("failed to connect to all addresses" in str(e) or
|
self.assertTrue("failed to connect to all addresses" in str(e) or
|
||||||
|
"Unable to find a context_id" in str(e) or
|
||||||
"Socket closed" in str(e) or
|
"Socket closed" in str(e) or
|
||||||
"Connection reset by peer" in str(e) or
|
"Connection reset by peer" in str(e) or
|
||||||
"Transport closed" in str(e))
|
"Transport closed" in str(e))
|
||||||
|
|
||||||
if isinstance(e, errors.InvalidArgumentError):
|
|
||||||
self.assertIn("Unable to find a context_id", str(e))
|
|
||||||
|
|
||||||
if isinstance(e, errors.AbortedError):
|
if isinstance(e, errors.AbortedError):
|
||||||
self.assertIn("RecvTensor expects a different device incarnation",
|
self.assertIn("RecvTensor expects a different device incarnation",
|
||||||
str(e))
|
str(e))
|
||||||
|
@ -320,6 +320,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
|||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
for _ in range(num_calls):
|
for _ in range(num_calls):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def worker_fn(i):
|
def worker_fn(i):
|
||||||
return math_ops.matmul(i, i)
|
return math_ops.matmul(i, i)
|
||||||
@ -389,10 +390,10 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
|||||||
t1_results = [None] * num_calls
|
t1_results = [None] * num_calls
|
||||||
t2_results = [None] * num_calls
|
t2_results = [None] * num_calls
|
||||||
threads = []
|
threads = []
|
||||||
threads.append(threading.Thread(target=thread_fn,
|
threads.append(
|
||||||
args=(self.device_t1, t1_results)))
|
threading.Thread(target=thread_fn, args=(self.device_t1, t1_results)))
|
||||||
threads.append(threading.Thread(target=thread_fn,
|
threads.append(
|
||||||
args=(self.device_t2, t2_results)))
|
threading.Thread(target=thread_fn, args=(self.device_t2, t2_results)))
|
||||||
threads.append(threading.Thread(target=update_server_def_fn))
|
threads.append(threading.Thread(target=update_server_def_fn))
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.start()
|
t.start()
|
||||||
@ -535,6 +536,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
|||||||
with ops.device(self.device_t2):
|
with ops.device(self.device_t2):
|
||||||
add = mul + i
|
add = mul + i
|
||||||
return add - i
|
return add - i
|
||||||
|
|
||||||
worker_fn.get_concrete_function(x1)
|
worker_fn.get_concrete_function(x1)
|
||||||
|
|
||||||
num_calls = 10
|
num_calls = 10
|
||||||
@ -551,13 +553,13 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self._coord.stop_on_exception():
|
with self._coord.stop_on_exception():
|
||||||
for i in range(num_calls):
|
for i in range(num_calls):
|
||||||
context.update_server_def(
|
context.update_server_def(
|
||||||
server_def=(self.server_def_s1_s2_s3
|
server_def=(self.server_def_s1_s2_s3 if i %
|
||||||
if i % 2 == 0 else self.server_def_s1_s2))
|
2 == 0 else self.server_def_s1_s2))
|
||||||
|
|
||||||
results = [None] * num_calls
|
results = [None] * num_calls
|
||||||
threads = []
|
threads = []
|
||||||
threads.append(threading.Thread(target=thread_fn,
|
threads.append(
|
||||||
args=(self.device_t1, results)))
|
threading.Thread(target=thread_fn, args=(self.device_t1, results)))
|
||||||
threads.append(threading.Thread(target=update_server_def_fn))
|
threads.append(threading.Thread(target=update_server_def_fn))
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.start()
|
t.start()
|
||||||
@ -630,9 +632,8 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
|
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0"))
|
||||||
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))
|
self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1"))
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
errors.InvalidArgumentError,
|
"Unable to find worker interface"):
|
||||||
"Client for target /job:remote_device/replica:0/task:10 not found."):
|
|
||||||
context.check_alive("/job:remote_device/replica:0/task:10")
|
context.check_alive("/job:remote_device/replica:0/task:10")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user