From efd6ea66ea2b7f2a43693b9ab70d70260d428b43 Mon Sep 17 00:00:00 2001
From: Haoyu Zhang
Date: Tue, 2 Feb 2021 11:31:06 -0800
Subject: [PATCH] Integrate multi-client TFRT initialization with
 TFE_EnableCollectiveOps C API.

PiperOrigin-RevId: 355213037
Change-Id: I006773d2cbb886065c1a015a42483971ea67348c
---
 tensorflow/c/c_api_experimental.cc            | 49 ++-----------------
 .../immediate_execution_distributed_manager.h |  6 +++
 .../eager/context_distributed_manager.cc      | 40 +++++++++++++++
 .../eager/context_distributed_manager.h       |  2 +
 4 files changed, 51 insertions(+), 46 deletions(-)

diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index e9734427bb0..417b13e1d51 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -496,51 +496,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
 }
 
-namespace {
-tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
-                                       TFE_Context* ctx) {
-  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
-  // server object (which currently CHECK-fails) and we miss the error, instead,
-  // we log the error, and then return to allow the user to see the error
-  // message.
-#define LOG_AND_RETURN_IF_ERROR(...)                    \
-  do {                                                  \
-    const ::tensorflow::Status _status = (__VA_ARGS__); \
-    if (TF_PREDICT_FALSE(!_status.ok())) {              \
-      LOG(ERROR) << _status.error_message();            \
-      return _status;                                   \
-    }                                                   \
-  } while (0);
-
-  // New server created for new server_def. Unused if updating server_def.
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  tensorflow::GrpcServer* grpc_server =
-      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
-  if (grpc_server == nullptr) {
-    std::unique_ptr<tensorflow::ServerInterface> new_server;
-    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
-    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
-    if (grpc_server == nullptr) {
-      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
-          "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
-    }
-    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
-
-    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
-        std::move(new_server), grpc_server->worker_env()->device_mgr,
-        grpc_server->worker_env()->collective_executor_mgr.get()));
-  } else {
-    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
-    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
-        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
-        grpc_server->worker_env()->collective_executor_mgr.get()));
-  }
-  return tensorflow::Status::OK();
-#undef LOG_AND_RETURN_IF_ERROR
-}
-}  // namespace
-
 // Set server_def on the context, possibly updating it.
 TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                    const void* proto,
@@ -552,7 +507,9 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
   }
-  status->status = EnableCollectiveOps(server_def, ctx);
+  status->status =
+      tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
+          server_def);
 }
 
 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h
index 65b3008146c..b43649a5966 100644
--- a/tensorflow/c/eager/immediate_execution_distributed_manager.h
+++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h
@@ -36,6 +36,12 @@ class ImmediateExecutionDistributedManager {
                                        bool reset_context,
                                        int keep_alive_secs) = 0;
 
+  // Set up a multi-client distributed execution environment. Must be called on
+  // all tasks in the cluster.
+  // This call internally coordinates with other tasks to initialize the eager
+  // context and TF server for multi-client execution.
+  virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
+
   // Check if the remote task is alive.
   virtual Status CheckRemoteAlive(const std::string& remote_task_name,
                                   bool* is_alive) = 0;
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
index aa6932a763a..079b7700258 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc
@@ -677,6 +677,46 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef(
                                keep_alive_secs);
 }
 
+Status EagerContextDistributedManager::EnableCollectiveOps(
+    const ServerDef& server_def) {
+  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
+  // server object (which currently CHECK-fails) and we miss the error, instead,
+  // we log the error, and then return to allow the user to see the error
+  // message.
+#define LOG_AND_RETURN_IF_ERROR(...)                    \
+  do {                                                  \
+    const ::tensorflow::Status _status = (__VA_ARGS__); \
+    if (TF_PREDICT_FALSE(!_status.ok())) {              \
+      LOG(ERROR) << _status.error_message();            \
+      return _status;                                   \
+    }                                                   \
+  } while (0);
+
+  tensorflow::GrpcServer* grpc_server =
+      dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
+  if (grpc_server == nullptr) {
+    std::unique_ptr<tensorflow::ServerInterface> new_server;
+    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
+    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
+    if (grpc_server == nullptr) {
+      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
+          "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
+    }
+    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+
+    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
+        std::move(new_server), grpc_server->worker_env()->device_mgr,
+        grpc_server->worker_env()->collective_executor_mgr.get()));
+  } else {
+    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
+    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
+        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
+        grpc_server->worker_env()->collective_executor_mgr.get()));
+  }
+#undef LOG_AND_RETURN_IF_ERROR
+  return Status::OK();
+}
+
 Status EagerContextDistributedManager::CheckRemoteAlive(
     const std::string& remote_task_name, bool* is_alive) {
   *is_alive = false;
diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/tensorflow/core/common_runtime/eager/context_distributed_manager.h
index 58a304bb1e9..73121043788 100644
--- a/tensorflow/core/common_runtime/eager/context_distributed_manager.h
+++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.h
@@ -35,6 +35,8 @@ class EagerContextDistributedManager
   Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
                               int keep_alive_secs) override;
 
+  Status EnableCollectiveOps(const ServerDef& server_def) override;
+
   Status CheckRemoteAlive(const std::string& remote_task_name,
                           bool* is_alive) override;
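Reviewer note (not part of the patch): a minimal, untested sketch of how a client could drive the new code path through the C API once this lands. The single-task "worker" job, the localhost address and port, the include paths, and the bare-bones error handling are illustrative assumptions for this sketch; the patch itself only guarantees that TFE_EnableCollectiveOps now forwards to the context's distributed manager.

#include <cstdio>
#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

int main() {
  // Describe a single-task "worker" cluster; the job name, task index, and
  // address below are placeholders, not values mandated by the patch.
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name("worker");
  server_def.set_task_index(0);
  tensorflow::JobDef* job = server_def.mutable_cluster()->add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "localhost:12345";

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Hand the serialized ServerDef to the C API. With this patch the call is
  // routed to ImmediateExecutionDistributedManager::EnableCollectiveOps
  // instead of the removed static helper in c_api_experimental.cc.
  std::string proto = server_def.SerializeAsString();
  TFE_EnableCollectiveOps(ctx, proto.data(), proto.size(), status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "EnableCollectiveOps failed: %s\n", TF_Message(status));
  }

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}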