diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 08b3c73ed02..c44d0ee6873 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -97,6 +97,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime:worker_interface", "//tensorflow/core:gpu_runtime", ] + internal_tfrt_deps(), alwayslink = 1, diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5f388bfe0cd..3418bccf050 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -70,6 +70,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" #endif // !IS_MOBILE_PLATFORM #include "tensorflow/core/framework/node_def_util.h" @@ -855,41 +856,42 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - // TODO(yuefengz): support partially specified `worker_name`. 
- tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client; - status->status = context->GetClient(worker_name, &eager_client); - if (!status->status.ok()) { + tensorflow::GrpcServer* grpc_server = + dynamic_cast<tensorflow::GrpcServer*>(context->GetServer()); + if (grpc_server == nullptr) { + status->status = + tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer."); + return false; + } + tensorflow::WorkerInterface* wi = + grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name); + if (wi == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Unable to find worker interface corresponding to task ", worker_name); return false; } - // Send a rpc request to the worker to check aliveness. - tensorflow::eager::KeepAliveRequest request; - request.set_context_id(context->GetContextId()); - tensorflow::eager::KeepAliveResponse response; - - tensorflow::Status keep_alive_status; + tensorflow::GetStatusRequest request; + tensorflow::GetStatusResponse response; + tensorflow::Status remote_status; tensorflow::Notification done; - eager_client->KeepAliveAsync( - &request, &response, - [&keep_alive_status, &done](const tensorflow::Status& s) { - keep_alive_status = s; - done.Notify(); - }); + wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true, + [&remote_status, &done](const tensorflow::Status& s) { + remote_status = s; + done.Notify(); + }); done.WaitForNotification(); + // We set OK status so the call does not raise any exceptions. Instead, caller + // uses the return value to tell if the remote worker is alive. status->status = tensorflow::Status::OK(); - // If `context_id` doesn't exist on the remote worker, an InvalidArgument - // error will return. But this still indicates that the remote worker is - // alive. 
- if (keep_alive_status.ok() || - keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) { + if (remote_status.ok()) { return true; - } else { - LOG(INFO) << "Remote worker " << worker_name - << " is not alive: " << keep_alive_status.error_message(); - return false; } + LOG(INFO) << "Remote worker " << worker_name + << " is not alive: " << remote_status.error_message(); + return false; #endif // !IS_MOBILE_PLATFORM } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 2138ecdfe95..ff44642c68e 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -752,7 +752,7 @@ tensorflow::Status EagerServiceImpl::GetServerContext( auto iter = contexts_.find(context_id); if (iter == contexts_.end()) { *server_context = nullptr; - return errors::InvalidArgument(strings::Printf( + return errors::Unavailable(strings::Printf( "Unable to find a context_id matching the specified one " "(%llu). Perhaps the worker was restarted, or the context was GC'd?", static_cast<unsigned long long>(context_id))); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 9d35ddf08f7..4a97be5c0c4 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -1248,7 +1248,7 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) { // Unable to handle the request since there is no eager context. 
Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, &remote_enqueue_response); - EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); + EXPECT_EQ(error::UNAVAILABLE, status.code()); EXPECT_TRUE(absl::StrContains( status.error_message(), "Unable to find a context_id matching the specified one")); @@ -1285,7 +1285,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) { Status status = eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); - EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + EXPECT_EQ(status.code(), error::UNAVAILABLE); EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", status.error_message()); diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator.py b/tensorflow/python/distribute/coordinator/cluster_coordinator.py index 86dcad0d382..a9ef50614cb 100644 --- a/tensorflow/python/distribute/coordinator/cluster_coordinator.py +++ b/tensorflow/python/distribute/coordinator/cluster_coordinator.py @@ -1343,10 +1343,8 @@ def _extract_failed_ps_instances(err_msg): def _is_ps_failure(error): """Whether the error is considered a parameter server failure.""" - if (_RPC_ERROR_FROM_PS in str(error) or - (isinstance(error, errors.InvalidArgumentError) and - "/job:ps" in str(error))): - return True + return (isinstance(error, errors.UnavailableError) and + _RPC_ERROR_FROM_PS in str(error)) def _is_worker_failure(error): @@ -1366,8 +1364,7 @@ def _is_worker_failure(error): # failure. In that case, gRPC allows channel (which is different from a # connection) to be reused for a replaced server listening to same address. if isinstance(error, errors.InvalidArgumentError): - if ("Unable to find a context_id" in str(error) or - "unknown device" in str(error) or + if ("unknown device" in str(error) or "Unable to find the relevant tensor remote_handle" in str(error)): # TODO(b/159961667): Fix "Unable to find the relevant tensor # remote_handle" part. 
diff --git a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py index 9bba4d74927..cc075d09c3d 100644 --- a/tensorflow/python/distribute/coordinator/fault_tolerance_test.py +++ b/tensorflow/python/distribute/coordinator/fault_tolerance_test.py @@ -234,20 +234,17 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring try: self.thread_coord.join([run_thread]) - except (errors.UnavailableError, errors.InvalidArgumentError) as e: + except errors.UnavailableError as e: logging.info("Got exception %r, error message is %s", e, e) - self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) + self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) - if isinstance(e, errors.UnavailableError): - self.assertTrue("failed to connect to all addresses" in str(e) or - "Socket closed" in str(e) or - "Connection reset by peer" in str(e) or - "Transport closed" in str(e)) - - if isinstance(e, errors.InvalidArgumentError): - self.assertIn("Unable to find a context_id", str(e)) + self.assertTrue("failed to connect to all addresses" in str(e) or + "Unable to find a context_id" in str(e) or + "Socket closed" in str(e) or + "Connection reset by peer" in str(e) or + "Transport closed" in str(e)) def testWorkerPreemptionErrorTypeWithPythonFunction(self): @@ -271,20 +268,17 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring try: self.thread_coord.join([run_thread]) - except (errors.UnavailableError, errors.InvalidArgumentError) as e: + except errors.UnavailableError as e: logging.info("Got exception %r, error message is %s", e, e) - self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) + self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) - if isinstance(e, errors.UnavailableError): - self.assertTrue("failed to connect to all 
addresses" in str(e) or - "Socket closed" in str(e) or - "Connection reset by peer" in str(e) or - "Transport closed" in str(e)) - - if isinstance(e, errors.InvalidArgumentError): - self.assertIn("Unable to find a context_id", str(e)) + self.assertTrue("failed to connect to all addresses" in str(e) or + "Unable to find a context_id" in str(e) or + "Socket closed" in str(e) or + "Connection reset by peer" in str(e) or + "Transport closed" in str(e)) def testPSPreemptionErrorType(self): @@ -309,25 +303,23 @@ class FaultToleranceTest(test.TestCase): # pylint: disable=missing-docstring run_thread = threading.Thread(target=run_fn) run_thread.start() time.sleep(1) # Let it run a couple steps. - self._restart(5, "ps") + + # Use a short restart delay to cover the case that RPC channel is reused + self._restart(1, "ps") try: self.thread_coord.join([run_thread]) - except (errors.UnavailableError, errors.InvalidArgumentError, - errors.AbortedError) as e: + except (errors.UnavailableError, errors.AbortedError) as e: logging.info("Got exception %r, error message is %s", e, e) - - self.assertIn(_RPC_ERROR_FROM_PS, str(e)) + self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except if isinstance(e, errors.UnavailableError): self.assertTrue("failed to connect to all addresses" in str(e) or + "Unable to find a context_id" in str(e) or "Socket closed" in str(e) or "Connection reset by peer" in str(e) or "Transport closed" in str(e)) - if isinstance(e, errors.InvalidArgumentError): - self.assertIn("Unable to find a context_id", str(e)) - if isinstance(e, errors.AbortedError): self.assertIn("RecvTensor expects a different device incarnation", str(e)) diff --git a/tensorflow/python/eager/remote_cluster_test.py b/tensorflow/python/eager/remote_cluster_test.py index 84dbb11361a..e533ab8577d 100644 --- a/tensorflow/python/eager/remote_cluster_test.py +++ b/tensorflow/python/eager/remote_cluster_test.py @@ -320,6 +320,7 @@ class DynamicClusterTest(test.TestCase, 
parameterized.TestCase): t.start() for _ in range(num_calls): + @def_function.function def worker_fn(i): return math_ops.matmul(i, i) @@ -389,10 +390,10 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): t1_results = [None] * num_calls t2_results = [None] * num_calls threads = [] - threads.append(threading.Thread(target=thread_fn, - args=(self.device_t1, t1_results))) - threads.append(threading.Thread(target=thread_fn, - args=(self.device_t2, t2_results))) + threads.append( + threading.Thread(target=thread_fn, args=(self.device_t1, t1_results))) + threads.append( + threading.Thread(target=thread_fn, args=(self.device_t2, t2_results))) threads.append(threading.Thread(target=update_server_def_fn)) for t in threads: t.start() @@ -535,6 +536,7 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): with ops.device(self.device_t2): add = mul + i return add - i + worker_fn.get_concrete_function(x1) num_calls = 10 @@ -551,13 +553,13 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): with self._coord.stop_on_exception(): for i in range(num_calls): context.update_server_def( - server_def=(self.server_def_s1_s2_s3 - if i % 2 == 0 else self.server_def_s1_s2)) + server_def=(self.server_def_s1_s2_s3 if i % + 2 == 0 else self.server_def_s1_s2)) results = [None] * num_calls threads = [] - threads.append(threading.Thread(target=thread_fn, - args=(self.device_t1, results))) + threads.append( + threading.Thread(target=thread_fn, args=(self.device_t1, results))) threads.append(threading.Thread(target=update_server_def_fn)) for t in threads: t.start() @@ -630,9 +632,8 @@ class DynamicClusterTest(test.TestCase, parameterized.TestCase): self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:0")) self.assertTrue(context.check_alive("/job:remote_device/replica:0/task:1")) - with self.assertRaisesRegex( - errors.InvalidArgumentError, - "Client for target /job:remote_device/replica:0/task:10 not found."): + with 
self.assertRaisesRegex(errors.InvalidArgumentError, + "Unable to find worker interface"): context.check_alive("/job:remote_device/replica:0/task:10")