Integrate multi-client TFRT initialization with TFE_EnableCollectiveOps C API.
PiperOrigin-RevId: 355213037 Change-Id: I006773d2cbb886065c1a015a42483971ea67348c
This commit is contained in:
parent
89c59719cf
commit
efd6ea66ea
@ -496,51 +496,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
namespace {
|
||||
tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx) {
|
||||
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
|
||||
// server object (which currently CHECK-fails) and we miss the error, instead,
|
||||
// we log the error, and then return to allow the user to see the error
|
||||
// message.
|
||||
#define LOG_AND_RETURN_IF_ERROR(...) \
|
||||
do { \
|
||||
const ::tensorflow::Status _status = (__VA_ARGS__); \
|
||||
if (TF_PREDICT_FALSE(!_status.ok())) { \
|
||||
LOG(ERROR) << _status.error_message(); \
|
||||
return _status; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
if (grpc_server == nullptr) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
|
||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
|
||||
}
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
||||
grpc_server->worker_env()->collective_executor_mgr.get()));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
||||
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
||||
grpc_server->worker_env()->collective_executor_mgr.get()));
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
||||
const void* proto,
|
||||
@ -552,7 +507,9 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
status->status = EnableCollectiveOps(server_def, ctx);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
|
||||
server_def);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||
|
@ -36,6 +36,12 @@ class ImmediateExecutionDistributedManager {
|
||||
bool reset_context,
|
||||
int keep_alive_secs) = 0;
|
||||
|
||||
// Set up a multi-client distributed execution environment. Must be called on
|
||||
// all tasks in the cluster.
|
||||
// This call internally coordinates with other tasks to initialize the eager
|
||||
// context and TF server for multi-client execution.
|
||||
virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
|
||||
|
||||
// Check if the remote task is alive.
|
||||
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||
bool* is_alive) = 0;
|
||||
|
@ -677,6 +677,46 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef(
|
||||
keep_alive_secs);
|
||||
}
|
||||
|
||||
// Enables collective ops on the eager context owned by this manager
// (`context_`), for the cluster described by `server_def`.
//
// If the context does not already hold a tensorflow::GrpcServer, a new server
// is created from `server_def`, started, and passed to the context. Otherwise
// the existing server is updated with `server_def`. In both cases the server's
// device manager and collective executor manager are registered with the
// context through StoreCollectiveOpsServer.
//
// Returns the first non-OK status encountered (after logging it), or OK.
Status EagerContextDistributedManager::EnableCollectiveOps(
    const ServerDef& server_def) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we miss the error;
  // instead, we log the error and then return, so that the user can see the
  // error message.
  //
  // The definition intentionally does NOT end with a semicolon: the call site
  // supplies it, so each use expands to exactly one statement and stays safe
  // inside an unbraced if/else.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  // Non-null only when the context already has a server that is a GrpcServer.
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
  if (grpc_server == nullptr) {
    // No usable server yet: build one from `server_def` and start it.
    std::unique_ptr<tensorflow::ServerInterface> new_server;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    if (grpc_server == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
          "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
    }
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

    // `grpc_server` aliases `new_server`, which is moved into the context
    // here.
    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  } else {
    // Reuse the existing server: update its definition and re-register its
    // device manager and collective executor manager with the context.
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  }
#undef LOG_AND_RETURN_IF_ERROR
  return Status::OK();
}
|
||||
|
||||
Status EagerContextDistributedManager::CheckRemoteAlive(
|
||||
const std::string& remote_task_name, bool* is_alive) {
|
||||
*is_alive = false;
|
||||
|
@ -35,6 +35,8 @@ class EagerContextDistributedManager
|
||||
Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
|
||||
int keep_alive_secs) override;
|
||||
|
||||
Status EnableCollectiveOps(const ServerDef& server_def) override;
|
||||
|
||||
Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||
bool* is_alive) override;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user