Integrate multi-client TFRT initialization with TFE_EnableCollectiveOps C API.
PiperOrigin-RevId: 355213037
Change-Id: I006773d2cbb886065c1a015a42483971ea67348c
Parent: 89c59719cf
Commit: efd6ea66ea
@@ -496,51 +496,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
   return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
 }
 
-namespace {
-tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
-                                       TFE_Context* ctx) {
-  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
-  // server object (which currently CHECK-fails) and we miss the error, instead,
-  // we log the error, and then return to allow the user to see the error
-  // message.
-#define LOG_AND_RETURN_IF_ERROR(...)                    \
-  do {                                                  \
-    const ::tensorflow::Status _status = (__VA_ARGS__); \
-    if (TF_PREDICT_FALSE(!_status.ok())) {              \
-      LOG(ERROR) << _status.error_message();            \
-      return _status;                                   \
-    }                                                   \
-  } while (0);
-
-  // New server created for new server_def. Unused if updating server_def.
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  tensorflow::GrpcServer* grpc_server =
-      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
-  if (grpc_server == nullptr) {
-    std::unique_ptr<tensorflow::ServerInterface> new_server;
-    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
-    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
-    if (grpc_server == nullptr) {
-      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
-          "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
-    }
-    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
-
-    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
-        std::move(new_server), grpc_server->worker_env()->device_mgr,
-        grpc_server->worker_env()->collective_executor_mgr.get()));
-  } else {
-    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
-    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
-        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
-        grpc_server->worker_env()->collective_executor_mgr.get()));
-  }
-  return tensorflow::Status::OK();
-#undef LOG_AND_RETURN_IF_ERROR
-}
-}  // namespace
-
 // Set server_def on the context, possibly updating it.
 TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                    const void* proto,
@@ -552,7 +507,9 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
         "Invalid tensorflow.ServerDef protocol buffer");
     return;
   }
-  status->status = EnableCollectiveOps(server_def, ctx);
+  status->status =
+      tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
+          server_def);
 }
 
 TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
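
The hunks above reroute TFE_EnableCollectiveOps through the context's distributed manager instead of a file-local helper. For orientation, here is a minimal sketch of how a client might drive this entry point; it is not part of the commit, and the header paths, two-worker cluster, addresses, and helper name are illustrative assumptions. Every task in the cluster builds the same ServerDef, sets its own task_index, and calls TFE_EnableCollectiveOps with the serialized proto.

// Sketch only: assumes a two-worker cluster on localhost; addresses and the
// helper name are illustrative, not part of this change.
#include <cstdio>
#include <string>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

// Runs on every task; `task_index` differs per process, the cluster is shared.
void EnableCollectivesForTask(TFE_Context* ctx, int task_index) {
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name("worker");
  server_def.set_task_index(task_index);
  tensorflow::JobDef* job = server_def.mutable_cluster()->add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "localhost:2222";
  (*job->mutable_tasks())[1] = "localhost:2223";

  const std::string proto = server_def.SerializeAsString();
  TF_Status* status = TF_NewStatus();
  // After this change, this call lands in
  // ImmediateExecutionDistributedManager::EnableCollectiveOps.
  TFE_EnableCollectiveOps(ctx, proto.data(), proto.size(), status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "EnableCollectiveOps failed: %s\n", TF_Message(status));
  }
  TF_DeleteStatus(status);
}
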
@@ -36,6 +36,12 @@ class ImmediateExecutionDistributedManager {
                                       bool reset_context,
                                       int keep_alive_secs) = 0;
 
+  // Set up a multi-client distributed execution environment. Must be called on
+  // all tasks in the cluster.
+  // This call internally coordinates with other tasks to initialize the eager
+  // context and TF server for multi-client execution.
+  virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
+
   // Check if the remote task is alive.
   virtual Status CheckRemoteAlive(const std::string& remote_task_name,
                                   bool* is_alive) = 0;
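
The pure virtual added above is the hook the TFE_EnableCollectiveOps hunk delegates to. A minimal sketch of the same delegation from other C++ code follows; the header paths are assumed and the free function is illustrative, not part of this change.

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

// Illustrative helper: forwards to whichever runtime-specific manager the
// context carries (the eager implementation added in this change, or a
// TFRT-based one).
tensorflow::Status EnableCollectivesOn(
    tensorflow::ImmediateExecutionContext* ctx,
    const tensorflow::ServerDef& server_def) {
  return ctx->GetDistributedManager()->EnableCollectiveOps(server_def);
}
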
@@ -677,6 +677,46 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef(
                                     keep_alive_secs);
 }
 
+Status EagerContextDistributedManager::EnableCollectiveOps(
+    const ServerDef& server_def) {
+  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
+  // server object (which currently CHECK-fails) and we miss the error, instead,
+  // we log the error, and then return to allow the user to see the error
+  // message.
+#define LOG_AND_RETURN_IF_ERROR(...)                    \
+  do {                                                  \
+    const ::tensorflow::Status _status = (__VA_ARGS__); \
+    if (TF_PREDICT_FALSE(!_status.ok())) {              \
+      LOG(ERROR) << _status.error_message();            \
+      return _status;                                   \
+    }                                                   \
+  } while (0);
+
+  tensorflow::GrpcServer* grpc_server =
+      dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
+  if (grpc_server == nullptr) {
+    std::unique_ptr<tensorflow::ServerInterface> new_server;
+    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
+    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
+    if (grpc_server == nullptr) {
+      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
+          "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
+    }
+    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+
+    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
+        std::move(new_server), grpc_server->worker_env()->device_mgr,
+        grpc_server->worker_env()->collective_executor_mgr.get()));
+  } else {
+    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
+    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
+        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
+        grpc_server->worker_env()->collective_executor_mgr.get()));
+  }
+#undef LOG_AND_RETURN_IF_ERROR
+  return Status::OK();
+}
+
 Status EagerContextDistributedManager::CheckRemoteAlive(
     const std::string& remote_task_name, bool* is_alive) {
   *is_alive = false;
@@ -35,6 +35,8 @@ class EagerContextDistributedManager
   Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
                               int keep_alive_secs) override;
 
+  Status EnableCollectiveOps(const ServerDef& server_def) override;
+
   Status CheckRemoteAlive(const std::string& remote_task_name,
                           bool* is_alive) override;
 