STT-tensorflow/tensorflow/c/experimental/network.cc

167 lines
6.4 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <memory>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
using tensorflow::ServerFactory;
namespace tensorflow {
/* static */ Status CGrpcServer::Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server) {
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
join_function, delete_function);
GrpcServerOptions options;
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
return new CRendezvousMgr(env, rendezvous_builder);
};
TF_RETURN_IF_ERROR(grpc_server->Init(options));
TF_Status* tf_status = TF_NewStatus();
grpc_server->SetContext(init_function(
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
TF_RETURN_IF_ERROR(tf_status->status);
TF_DeleteStatus(tf_status);
out_server->reset(grpc_server);
return Status::OK();
}
Status CGrpcServer::Start() {
Status status = GrpcServer::Start();
TF_Status* tf_status = TF_NewStatus();
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Stop() {
Status status = GrpcServer::Stop();
TF_Status* tf_status = TF_NewStatus();
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Join() {
Status status = GrpcServer::Join();
TF_Status* tf_status = TF_NewStatus();
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
namespace {
// Factory that creates CGrpcServer instances.
class CServerFactory : public ServerFactory {
public:
CServerFactory(bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*,
TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder)
: accept_function_(accept_function),
init_function_(init_function),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def, const Options& options,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,
join_function_, delete_function_, rendezvous_builder_, out_server));
return Status::OK();
}
// Returns true if and only if this factory can create a server
// based on the given `server_def`.
bool AcceptsOptions(const ServerDef& server_def) override {
return (*accept_function_)(server_def.protocol().c_str());
}
private:
bool (*accept_function_)(const char* protocol);
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
TF_RemoteRendezvousBuilder* rendezvous_builder_;
};
} // namespace
} // namespace tensorflow
// Server factory representation to use in C API.
// Holds CServerFactory pointer.
struct TF_GrpcServerFactory {
::tensorflow::CServerFactory* factory;
};
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder) {
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
server_factory->factory = new ::tensorflow::CServerFactory(
accept_function, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
return server_factory;
}
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
DCHECK_NE(server_factory, nullptr);
delete server_factory;
}
void TF_RegisterGrpcServerFactory(const char* server_type,
TF_GrpcServerFactory* server_factory) {
ServerFactory::Register(server_type, server_factory->factory);
}