Integrate multi-client TFRT initialization with TFE_EnableCollectiveOps C API.
PiperOrigin-RevId: 355213037 Change-Id: I006773d2cbb886065c1a015a42483971ea67348c
This commit (efd6ea66ea) is contained in this branch; its parent commit is 89c59719cf.
@ -496,51 +496,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
|
|||||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
|
||||||
TFE_Context* ctx) {
|
|
||||||
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
|
|
||||||
// server object (which currently CHECK-fails) and we miss the error, instead,
|
|
||||||
// we log the error, and then return to allow the user to see the error
|
|
||||||
// message.
|
|
||||||
#define LOG_AND_RETURN_IF_ERROR(...) \
|
|
||||||
do { \
|
|
||||||
const ::tensorflow::Status _status = (__VA_ARGS__); \
|
|
||||||
if (TF_PREDICT_FALSE(!_status.ok())) { \
|
|
||||||
LOG(ERROR) << _status.error_message(); \
|
|
||||||
return _status; \
|
|
||||||
} \
|
|
||||||
} while (0);
|
|
||||||
|
|
||||||
// New server created for new server_def. Unused if updating server_def.
|
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
tensorflow::GrpcServer* grpc_server =
|
|
||||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
|
||||||
if (grpc_server == nullptr) {
|
|
||||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
|
||||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
|
||||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
|
||||||
if (grpc_server == nullptr) {
|
|
||||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
|
|
||||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
|
|
||||||
}
|
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
|
||||||
|
|
||||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
|
||||||
std::move(new_server), grpc_server->worker_env()->device_mgr,
|
|
||||||
grpc_server->worker_env()->collective_executor_mgr.get()));
|
|
||||||
} else {
|
|
||||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
|
||||||
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
|
|
||||||
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
|
|
||||||
grpc_server->worker_env()->collective_executor_mgr.get()));
|
|
||||||
}
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
#undef LOG_AND_RETURN_IF_ERROR
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Set server_def on the context, possibly updating it.
|
// Set server_def on the context, possibly updating it.
|
||||||
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
||||||
const void* proto,
|
const void* proto,
|
||||||
@ -552,7 +507,9 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
|
|||||||
"Invalid tensorflow.ServerDef protocol buffer");
|
"Invalid tensorflow.ServerDef protocol buffer");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
status->status = EnableCollectiveOps(server_def, ctx);
|
status->status =
|
||||||
|
tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
|
||||||
|
server_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
|
||||||
|
@ -36,6 +36,12 @@ class ImmediateExecutionDistributedManager {
|
|||||||
bool reset_context,
|
bool reset_context,
|
||||||
int keep_alive_secs) = 0;
|
int keep_alive_secs) = 0;
|
||||||
|
|
||||||
|
// Set up a multi-client distributed execution environment. Must be called on
|
||||||
|
// all tasks in the cluster.
|
||||||
|
// This call internally coordinates with other tasks to initialize the eager
|
||||||
|
// context and TF server for multi-client execution.
|
||||||
|
virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
|
||||||
|
|
||||||
// Check if the remote task is alive.
|
// Check if the remote task is alive.
|
||||||
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
|
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||||
bool* is_alive) = 0;
|
bool* is_alive) = 0;
|
||||||
|
@ -677,6 +677,46 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef(
|
|||||||
keep_alive_secs);
|
keep_alive_secs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets up collective ops for this context from `server_def`.
//
// Creates and starts a tensorflow::GrpcServer on the first call, handing
// ownership (plus the server's device manager and collective executor
// manager) to the owned eager context `context_`; on later calls the
// existing server is updated in place. Returns OK on success; the first
// failing status is logged and returned. Only tensorflow::GrpcServer is
// supported — other ServerInterface implementations produce an Internal
// error.
Status EagerContextDistributedManager::EnableCollectiveOps(
    const ServerDef& server_def) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we miss the error, instead,
  // we log the error, and then return to allow the user to see the error
  // message.
  // NOTE: no trailing semicolon after while(0) — callers supply their own,
  // which keeps the macro safe inside unbraced if/else.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
  if (grpc_server == nullptr) {
    // First call: build and start a fresh server, then transfer ownership of
    // it to the context alongside its device/collective managers.
    std::unique_ptr<tensorflow::ServerInterface> new_server;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    if (grpc_server == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
          "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
    }
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  } else {
    // Server already exists: update it in place with the new cluster spec.
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  }
#undef LOG_AND_RETURN_IF_ERROR
  return Status::OK();
}
|
||||||
|
|
||||||
Status EagerContextDistributedManager::CheckRemoteAlive(
|
Status EagerContextDistributedManager::CheckRemoteAlive(
|
||||||
const std::string& remote_task_name, bool* is_alive) {
|
const std::string& remote_task_name, bool* is_alive) {
|
||||||
*is_alive = false;
|
*is_alive = false;
|
||||||
|
@ -35,6 +35,8 @@ class EagerContextDistributedManager
|
|||||||
Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
|
Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
|
||||||
int keep_alive_secs) override;
|
int keep_alive_secs) override;
|
||||||
|
|
||||||
|
Status EnableCollectiveOps(const ServerDef& server_def) override;
|
||||||
|
|
||||||
Status CheckRemoteAlive(const std::string& remote_task_name,
|
Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||||
bool* is_alive) override;
|
bool* is_alive) override;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user