diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 8703eebd359..57be9497dac 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -342,7 +342,7 @@ tf_cuda_cc_test(
 
 tf_cuda_cc_test(
     name = "c_api_remote_test",
-    size = "small",
+    size = "medium",
     srcs = [
        "c_api_remote_test.cc",
     ],
@@ -364,6 +364,8 @@ tf_cuda_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "//tensorflow/core/platform:env",
+        "@com_google_absl//absl/debugging:leak_check",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index be85a239378..4ef178eb30c 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -500,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       grpc_server->master_env()->worker_cache->GetEagerClientCache(
           &remote_eager_workers));
 
+  // For a cluster update, use a status group to aggregate statuses from
+  // * adding and removing remote devices
+  // * creating remote contexts on newly added workers
+  // * updating remote contexts on existing workers
+  // * updating the master context
+  // Note that we should not return immediately on errors in the middle of
+  // these updates to prevent the cluster from having inconsistent context
+  // views.
+  //
+  // Unused if `reset_context` is true.
+  tensorflow::StatusGroup sg;
+
   // When updating an existing context, populate the following lists with:
   // * added_workers: set(remote_workers) - set(curr_remote_workers)
   // * removed_workers: set(curr_remote_workers) - set(remote_workers)
@@ -535,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                              &added_workers, &removed_workers,
                              &existing_workers);
-    LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
+    sg.Update(GetReplacedFromExistingWorkers(
         &existing_workers, context_id, context->GetContextViewId(), server_def,
         remote_eager_workers.get(), &replaced_workers));
     if (VLOG_IS_ON(1)) {
@@ -559,11 +570,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
             existing_workers.end());
       }
     }
-    LOG_AND_RETURN_IF_ERROR(
-        RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
-    LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
-        added_workers, grpc_server->master_env()->worker_cache,
-        remote_device_mgr));
+    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
+    sg.Update(AddRemoteDevicesToMgr(added_workers,
+                                    grpc_server->master_env()->worker_cache,
+                                    remote_device_mgr));
   }
 
   std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
@@ -584,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   }
 
   // Initialize remote eager workers.
-  // TODO(b/138847548) Create remote eager contexts in async mode by default.
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
@@ -596,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     // existing workers to also have the updated context_view_id, so
     // we must set their context_view_id to the existing master's
     // context_view_id + 1.
-    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
+    sg.Update(CreateRemoteContexts(
         ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
         server_def, remote_eager_workers.get(), context->Executor().Async(),
         context->LazyCopyFunctionRemoteInputs(), base_request));
@@ -606,10 +615,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
           VLOG(1) << "Updating cluster with existing worker " << w;
         }
       }
-      LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-          ctx, existing_workers, added_workers, removed_workers, context_id,
-          context_view_id + 1, server_def, remote_eager_workers.get(),
-          base_request));
+      sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
+                                     removed_workers, context_id,
+                                     context_view_id + 1, server_def,
+                                     remote_eager_workers.get(), base_request));
     }
   }
 
@@ -645,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     // GrpcServer cannot be destroyed after it is started.
     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
   } else {
-    LOG_AND_RETURN_IF_ERROR(
-        grpc_server->worker_env()->session_mgr->UpdateSession(
-            session_name, server_def, base_request.cluster_device_attributes(),
-            /*isolate_session_state=*/true));
-    LOG_AND_RETURN_IF_ERROR(
-        context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers),
-                                    added_workers, removed_workers));
+    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
+        session_name, server_def, base_request.cluster_device_attributes(),
+        /*isolate_session_state=*/true));
+    sg.Update(context->UpdateRemoteMaster(context_id,
+                                          std::move(remote_eager_workers),
+                                          added_workers, removed_workers));
+    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
   }
 
 #undef LOG_AND_RETURN_IF_ERROR
diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc
index 7c6836af69b..ad3c4da75aa 100644
--- a/tensorflow/c/eager/c_api_remote_test.cc
+++ b/tensorflow/c/eager/c_api_remote_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/debugging/leak_check.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -21,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/cluster.pb.h"
@@ -527,4 +529,124 @@ TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
   TestRemoteExecuteChangeServerDef(true);
 }
 
+void TestRemoteExecuteUpdateServerDef(bool async) {
+  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
+  absl::LeakCheckDisabler disabler;
+
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char local_device_name[] =
+      "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+                             status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDef) {
+  TestRemoteExecuteUpdateServerDef(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
+  TestRemoteExecuteUpdateServerDef(true);
+}
+
+void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
+  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
+  absl::LeakCheckDisabler disabler;
+  // Fail fast on GetStatus requests so we can get errors instead of timeouts
+  // when updating the cluster with a non-existent worker.
+  tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
+
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char local_device_name[] =
+      "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  // Adding a non-existent remote worker to cluster def. This should cause the
+  // UpdateServerDef call to fail.
+  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
+  tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
+  int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->insert(
+      {2, tensorflow::strings::StrCat("localhost:", port)});
+  string serialized_update = server_def.SerializeAsString();
+  TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
+                             serialized_update.size(), status);
+  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  // Even after the previously failed cluster update, another update and op
+  // execution should work fine as long as the provided server_def is valid.
+  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+                             status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+  tensorflow::unsetenv("GRPC_FAIL_FAST");
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
+  TestRemoteExecuteUpdateServerDefWithFailures(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
+  TestRemoteExecuteUpdateServerDefWithFailures(true);
+}
+
 }  // namespace
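
Reviewer note: the c_api.cc change hinges on the tensorflow::StatusGroup pattern, where each cluster-update step records its status via sg.Update() instead of returning early, and one aggregated status is surfaced at the end via sg.as_summary_status(). The minimal sketch below illustrates that pattern in isolation; the step functions AddDevices, CreateContexts, UpdateMaster, and UpdateClusterSketch are hypothetical placeholders invented for illustration and are not part of this change, while StatusGroup::Update and StatusGroup::as_summary_status are the same APIs used in the patch above.

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace {

// Hypothetical update steps; each reports failure through its return Status
// rather than aborting the whole update.
tensorflow::Status AddDevices() { return tensorflow::Status::OK(); }
tensorflow::Status CreateContexts() {
  return tensorflow::errors::Unavailable("example: worker unreachable");
}
tensorflow::Status UpdateMaster() { return tensorflow::Status::OK(); }

tensorflow::Status UpdateClusterSketch() {
  tensorflow::StatusGroup sg;
  // Run every step even if an earlier one failed, so all participants end up
  // with a consistent view of the cluster.
  sg.Update(AddDevices());
  sg.Update(CreateContexts());
  sg.Update(UpdateMaster());
  // Surface a single aggregated status to the caller, mirroring
  // LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()) in the patch above.
  return sg.as_summary_status();
}

}  // namespace

Under these assumptions, UpdateClusterSketch() would return the Unavailable error from CreateContexts(), but only after AddDevices() and UpdateMaster() have both run, which is the behavior the UpdateServerDefWithFailures tests above rely on: a failed update leaves the context usable for a subsequent valid update.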