Integrate multi-client TFRT initialization with TFE_EnableCollectiveOps C API.

PiperOrigin-RevId: 355213037
Change-Id: I006773d2cbb886065c1a015a42483971ea67348c
This commit is contained in:
Haoyu Zhang 2021-02-02 11:31:06 -08:00 committed by TensorFlower Gardener
parent 89c59719cf
commit efd6ea66ea
4 changed files with 51 additions and 46 deletions

View File

@@ -496,51 +496,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
namespace {
// Enables collective ops on `ctx` for the cluster described by `server_def`.
// If the eager context does not yet own a server, a new GrpcServer is created,
// started, and handed to the context; otherwise the existing server is updated
// in place. Returns the first error encountered (after logging it).
tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
                                       TFE_Context* ctx) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we miss the error,
  // instead, we log the error, and then return to allow the user to see the
  // error message.
  // NOTE: no semicolon after `while (0)` — the do-while(0) idiom lets call
  // sites supply their own `;` and compose safely with unbraced if/else.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  // New server created for new server_def. Unused if updating server_def.
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
  if (grpc_server == nullptr) {
    // No server yet: create, start, and transfer ownership to the context.
    std::unique_ptr<tensorflow::ServerInterface> new_server;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    if (grpc_server == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
          "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
    }
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  } else {
    // A server already exists: update its cluster definition in place and
    // re-register its device manager / collective executor with the context.
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  }
  return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
}
}  // namespace
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
const void* proto,
@@ -552,7 +507,9 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
status->status = EnableCollectiveOps(server_def, ctx);
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps(
server_def);
}
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,

View File

@@ -36,6 +36,12 @@ class ImmediateExecutionDistributedManager {
bool reset_context,
int keep_alive_secs) = 0;
// Set up a multi-client distributed execution environment. Must be called on
// all tasks in the cluster.
// This call internally coordinates with other tasks to initialize the eager
// context and TF server for multi-client execution.
virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
// Check if the remote task is alive.
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
bool* is_alive) = 0;

View File

@@ -677,6 +677,46 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef(
keep_alive_secs);
}
// Enables collective ops for the cluster described by `server_def`.
// If the managed eager context does not yet own a server, a new GrpcServer is
// created, started, and handed to the context; otherwise the existing server
// is updated in place. Returns the first error encountered (after logging it).
Status EagerContextDistributedManager::EnableCollectiveOps(
    const ServerDef& server_def) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we miss the error,
  // instead, we log the error, and then return to allow the user to see the
  // error message.
  // NOTE: no semicolon after `while (0)` — the do-while(0) idiom lets call
  // sites supply their own `;` and compose safely with unbraced if/else.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
  if (grpc_server == nullptr) {
    // No server yet: create, start, and transfer ownership to the context.
    std::unique_ptr<tensorflow::ServerInterface> new_server;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    if (grpc_server == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
          "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
    }
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  } else {
    // A server already exists: update its cluster definition in place and
    // re-register its device manager / collective executor with the context.
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  }
#undef LOG_AND_RETURN_IF_ERROR
  return Status::OK();
}
Status EagerContextDistributedManager::CheckRemoteAlive(
const std::string& remote_task_name, bool* is_alive) {
*is_alive = false;

View File

@@ -35,6 +35,8 @@ class EagerContextDistributedManager
Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context,
int keep_alive_secs) override;
Status EnableCollectiveOps(const ServerDef& server_def) override;
Status CheckRemoteAlive(const std::string& remote_task_name,
bool* is_alive) override;