First check-in for V1 of Networking C API.
PiperOrigin-RevId: 245293662
This commit is contained in:
parent
17db87c632
commit
606c13f5a6
122
tensorflow/c/experimental/BUILD
Normal file
122
tensorflow/c/experimental/BUILD
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Description:
|
||||||
|
# Experimental C APIs for TensorFlow.
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"tf_copts",
|
||||||
|
"tf_cuda_library",
|
||||||
|
)
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "rendezvous_internal",
|
||||||
|
srcs = [
|
||||||
|
"rendezvous.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"rendezvous.h",
|
||||||
|
"rendezvous_internal.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//tensorflow/c:__subpackages__"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:c_api_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||||
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "rendezvous",
|
||||||
|
hdrs = [
|
||||||
|
"rendezvous.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":rendezvous_internal",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "network_internal",
|
||||||
|
srcs = [
|
||||||
|
"network.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"network.h",
|
||||||
|
"network_internal.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//tensorflow/c:__subpackages__"],
|
||||||
|
deps = [
|
||||||
|
":rendezvous_internal",
|
||||||
|
"//tensorflow/c:c_api_internal",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "network",
|
||||||
|
hdrs = [
|
||||||
|
"network.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":network_internal",
|
||||||
|
":rendezvous",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "network_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["network_test.cc"],
|
||||||
|
tags = ["noasan"],
|
||||||
|
# We must ensure that the dependencies can be dynamically linked since
|
||||||
|
# the shared library must be able to use core:framework.
|
||||||
|
# linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
|
deps = [
|
||||||
|
":network",
|
||||||
|
":network_internal",
|
||||||
|
":rendezvous",
|
||||||
|
":rendezvous_internal",
|
||||||
|
"//tensorflow/c:c_api",
|
||||||
|
"//tensorflow/c:env",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||||
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
|
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||||
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
|
"//tensorflow/core/distributed_runtime:worker_session",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
|
"@com_google_absl//absl/time",
|
||||||
|
],
|
||||||
|
)
|
166
tensorflow/c/experimental/network.cc
Normal file
166
tensorflow/c/experimental/network.cc
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/network.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/experimental/network_internal.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous_internal.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
using tensorflow::ServerFactory;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
/* static */ Status CGrpcServer::Create(
|
||||||
|
const ServerDef& server_def,
|
||||||
|
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||||
|
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*delete_function)(void*),
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder,
|
||||||
|
std::unique_ptr<ServerInterface>* out_server) {
|
||||||
|
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
|
||||||
|
join_function, delete_function);
|
||||||
|
|
||||||
|
GrpcServerOptions options;
|
||||||
|
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
|
||||||
|
return new CRendezvousMgr(env, rendezvous_builder);
|
||||||
|
};
|
||||||
|
TF_RETURN_IF_ERROR(grpc_server->Init(options));
|
||||||
|
TF_Status* tf_status = TF_NewStatus();
|
||||||
|
grpc_server->SetContext(init_function(
|
||||||
|
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
|
||||||
|
TF_RETURN_IF_ERROR(tf_status->status);
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
|
||||||
|
out_server->reset(grpc_server);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CGrpcServer::Start() {
|
||||||
|
Status status = GrpcServer::Start();
|
||||||
|
TF_Status* tf_status = TF_NewStatus();
|
||||||
|
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
|
||||||
|
tf_status);
|
||||||
|
status.Update(tf_status->status);
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CGrpcServer::Stop() {
|
||||||
|
Status status = GrpcServer::Stop();
|
||||||
|
TF_Status* tf_status = TF_NewStatus();
|
||||||
|
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
|
||||||
|
tf_status);
|
||||||
|
status.Update(tf_status->status);
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CGrpcServer::Join() {
|
||||||
|
Status status = GrpcServer::Join();
|
||||||
|
TF_Status* tf_status = TF_NewStatus();
|
||||||
|
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
|
||||||
|
tf_status);
|
||||||
|
status.Update(tf_status->status);
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Factory that creates CGrpcServer instances.
|
||||||
|
class CServerFactory : public ServerFactory {
|
||||||
|
public:
|
||||||
|
CServerFactory(bool (*accept_function)(const char*),
|
||||||
|
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||||
|
void (*start_function)(const TF_GrpcServer*, void*,
|
||||||
|
TF_Status*),
|
||||||
|
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*delete_function)(void*),
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder)
|
||||||
|
: accept_function_(accept_function),
|
||||||
|
init_function_(init_function),
|
||||||
|
start_function_(start_function),
|
||||||
|
stop_function_(stop_function),
|
||||||
|
join_function_(join_function),
|
||||||
|
delete_function_(delete_function),
|
||||||
|
rendezvous_builder_(rendezvous_builder) {}
|
||||||
|
|
||||||
|
Status NewServer(const ServerDef& server_def,
|
||||||
|
std::unique_ptr<ServerInterface>* out_server) override {
|
||||||
|
TF_RETURN_IF_ERROR(CGrpcServer::Create(
|
||||||
|
server_def, init_function_, start_function_, stop_function_,
|
||||||
|
join_function_, delete_function_, rendezvous_builder_, out_server));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if and only if this factory can create a server
|
||||||
|
// based on the given `server_def`.
|
||||||
|
bool AcceptsOptions(const ServerDef& server_def) override {
|
||||||
|
return (*accept_function_)(server_def.protocol().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool (*accept_function_)(const char* protocol);
|
||||||
|
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
|
||||||
|
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||||
|
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||||
|
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||||
|
void (*delete_function_)(void*);
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder_;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
// Server factory representation to use in C API.
|
||||||
|
// Holds CServerFactory pointer.
|
||||||
|
struct TF_GrpcServerFactory {
|
||||||
|
::tensorflow::CServerFactory* factory;
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
|
||||||
|
bool (*accept_function)(const char*),
|
||||||
|
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||||
|
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*delete_function)(void*),
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder) {
|
||||||
|
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
|
||||||
|
server_factory->factory = new ::tensorflow::CServerFactory(
|
||||||
|
accept_function, init_function, start_function, stop_function,
|
||||||
|
join_function, delete_function, rendezvous_builder);
|
||||||
|
return server_factory;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
|
||||||
|
DCHECK_NE(server_factory, nullptr);
|
||||||
|
delete server_factory;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_RegisterGrpcServerFactory(const char* server_type,
|
||||||
|
TF_GrpcServerFactory* server_factory) {
|
||||||
|
ServerFactory::Register(server_type, server_factory->factory);
|
||||||
|
}
|
97
tensorflow/c/experimental/network.h
Normal file
97
tensorflow/c/experimental/network.h
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// C API for TensorFlow Networking.
|
||||||
|
// NOTE: This API is unstable and almost certainly will change in the near
|
||||||
|
// future.
|
||||||
|
//
|
||||||
|
// Users wishing to register a custom GrpcServer should call
|
||||||
|
// TF_NewServerFactory and then TF_RegisterGrpcServerFactory.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// ```c++
|
||||||
|
// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
|
||||||
|
// rendezvous_init_function,
|
||||||
|
// receive_from_remote_async_function,
|
||||||
|
// rendezvous_delete_function);
|
||||||
|
//
|
||||||
|
// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
|
||||||
|
// accept_function,
|
||||||
|
// init_function,
|
||||||
|
// start_function,
|
||||||
|
// stop_function,
|
||||||
|
// join_function,
|
||||||
|
// delete_function,
|
||||||
|
// rendezvous_builder);
|
||||||
|
// TF_RegisterGrpcServerFactory("customfactory", factory);
|
||||||
|
// ...
|
||||||
|
// TF_DeleteGrpcServerFactory(factory);
|
||||||
|
// ```
|
||||||
|
|
||||||
|
typedef struct TF_GrpcServerFactory TF_GrpcServerFactory;
|
||||||
|
typedef struct TF_GrpcServerOptions TF_GrpcServerOptions;
|
||||||
|
typedef struct TF_GrpcServer TF_GrpcServer;
|
||||||
|
typedef struct TF_ServerContext {
|
||||||
|
TF_GrpcServer* const server;
|
||||||
|
void* context;
|
||||||
|
} TF_ServerContext;
|
||||||
|
|
||||||
|
// Creates a new TF_GrpcServerFactory instance. Caller takes ownership
|
||||||
|
// of TF_GrpcServerFactory instance and should deallocate it by calling
|
||||||
|
// TF_GrpcDeleteServerFactory.
|
||||||
|
// accept_function should return true if this ServerFactory can create
|
||||||
|
// server instances for the given protocol name (for e.g. grpc+verbs).
|
||||||
|
// GRPC servers created by this factory will call provided
|
||||||
|
// init_function, start_function, stop_function, join_function and
|
||||||
|
// delete_function.
|
||||||
|
//
|
||||||
|
// Note that clean shutdown is currently not implemented for GrpcServer.
|
||||||
|
// So, stop_function will never be called now but may be in the future
|
||||||
|
// when stop mechanism is supported.
|
||||||
|
TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory(
|
||||||
|
bool (*accept_function)(const char*),
|
||||||
|
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||||
|
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*delete_function)(void*),
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder);
|
||||||
|
|
||||||
|
// Deletes TF_GrpcServerFactory instances.
|
||||||
|
// Note that this function only deletes TF_GrpcServerFactory wrapper.
|
||||||
|
// Actual underlying server factory would not be deleted and will
|
||||||
|
// remain registered.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory(
|
||||||
|
TF_GrpcServerFactory* server_factory);
|
||||||
|
|
||||||
|
// Registers provided server_factory for the given server_type.
|
||||||
|
// server_type must be unique to the server factory.
|
||||||
|
TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory(
|
||||||
|
const char* server_type, TF_GrpcServerFactory* server_factory);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} /* end extern "C" */
|
||||||
|
#endif
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
|
77
tensorflow/c/experimental/network_internal.h
Normal file
77
tensorflow/c/experimental/network_internal.h
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/experimental/network.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// GrpcServer implementation that forwards calls to callbacks.
|
||||||
|
class CGrpcServer : public GrpcServer {
|
||||||
|
protected:
|
||||||
|
CGrpcServer(const ServerDef& server_def,
|
||||||
|
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*delete_function)(void*))
|
||||||
|
: GrpcServer(server_def, ::tensorflow::Env::Default()),
|
||||||
|
start_function_(start_function),
|
||||||
|
stop_function_(stop_function),
|
||||||
|
join_function_(join_function),
|
||||||
|
delete_function_(delete_function),
|
||||||
|
context_(nullptr) {}
|
||||||
|
|
||||||
|
public:
|
||||||
|
static Status Create(
|
||||||
|
const ServerDef& server_def,
|
||||||
|
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
|
||||||
|
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
|
||||||
|
void (*delete_function)(void*),
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder,
|
||||||
|
std::unique_ptr<ServerInterface>* out_server);
|
||||||
|
|
||||||
|
Status Start() override;
|
||||||
|
Status Stop() override;
|
||||||
|
Status Join() override;
|
||||||
|
|
||||||
|
~CGrpcServer() override { delete_function_(context_); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetContext(void* context) { context_ = context; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||||
|
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||||
|
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
|
||||||
|
void (*delete_function_)(void*);
|
||||||
|
void* context_;
|
||||||
|
|
||||||
|
friend class NetworksTest;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
|
256
tensorflow/c/experimental/network_test.cc
Normal file
256
tensorflow/c/experimental/network_test.cc
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/network.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/synchronization/notification.h"
|
||||||
|
#include "absl/time/time.h"
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/experimental/network_internal.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous_internal.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_session.h"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
bool accept_functionA(const char* protocol_name) {
|
||||||
|
return strcmp(protocol_name, "grpc+A") == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool accept_functionB(const char* protocol_name) {
|
||||||
|
return strcmp(protocol_name, "grpc+B") == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SomeServerData {
|
||||||
|
bool server_started = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SomeRendezvousData {
|
||||||
|
int test = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
void* init_function(const TF_GrpcServer* server, TF_Status* status) {
|
||||||
|
SomeServerData* server_data = new SomeServerData();
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
return server_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
void start_function(const TF_GrpcServer* server, void* context,
|
||||||
|
TF_Status* status) {
|
||||||
|
auto* server_data = static_cast<SomeServerData*>(context);
|
||||||
|
server_data->server_started = true;
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
void stop_function(const TF_GrpcServer* server, void* context,
|
||||||
|
TF_Status* status) {
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
void join_function(const TF_GrpcServer* server, void* context,
|
||||||
|
TF_Status* status) {
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
void delete_function(void* context) {
|
||||||
|
auto* server_data = static_cast<SomeServerData*>(context);
|
||||||
|
delete server_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* rendezvous_init_function(void* server_context) {
|
||||||
|
return new SomeRendezvousData();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Deallocator(void* data, size_t, void* arg) {
|
||||||
|
tensorflow::cpu_allocator()->DeallocateRaw(data);
|
||||||
|
*reinterpret_cast<bool*>(arg) = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void receive_from_remote_async_function(TF_ParsedKey* key,
|
||||||
|
TF_RendezvousArgs* args,
|
||||||
|
TF_RendezvousDoneCallback* callback,
|
||||||
|
void* context) {
|
||||||
|
// Create dummy tensor
|
||||||
|
const int num_bytes = 6 * sizeof(float);
|
||||||
|
float* values =
|
||||||
|
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
|
||||||
|
EIGEN_MAX_ALIGN_BYTES, num_bytes));
|
||||||
|
int64_t dims[] = {2, 3};
|
||||||
|
bool deallocator_called = false;
|
||||||
|
auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
|
||||||
|
&Deallocator, &deallocator_called);
|
||||||
|
callback->tensor = tensor;
|
||||||
|
auto* tf_status = TF_NewStatus();
|
||||||
|
TF_SetStatus(tf_status, TF_OK, "");
|
||||||
|
callback->status = tf_status;
|
||||||
|
TF_RendezvousDone(callback);
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
TF_DeleteTensor(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
void rendezvous_delete_function(void* context) {
|
||||||
|
auto* rendezvous_data = static_cast<SomeRendezvousData*>(context);
|
||||||
|
delete rendezvous_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::ServerDef GetServerDef(const string& protocol,
|
||||||
|
const string& job_name, int num_tasks) {
|
||||||
|
tensorflow::ServerDef server_def;
|
||||||
|
server_def.set_protocol(protocol);
|
||||||
|
server_def.set_job_name(job_name);
|
||||||
|
server_def.set_task_index(0);
|
||||||
|
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||||
|
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||||
|
job_def->set_name(job_name);
|
||||||
|
for (int i = 0; i < num_tasks; i++) {
|
||||||
|
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||||
|
job_def->mutable_tasks()->insert(
|
||||||
|
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||||
|
}
|
||||||
|
return server_def;
|
||||||
|
}
|
||||||
|
|
||||||
|
class NetworksTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
~NetworksTest() override {}
|
||||||
|
|
||||||
|
SomeServerData* GetServerData(CGrpcServer* server) {
|
||||||
|
EXPECT_NE(server->context_, nullptr);
|
||||||
|
return static_cast<SomeServerData*>(server->context_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
|
||||||
|
const string& receiver, const string& name) {
|
||||||
|
Rendezvous::ParsedKey result;
|
||||||
|
CHECK(
|
||||||
|
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
|
||||||
|
name, FrameAndIter(0, 0)),
|
||||||
|
&result)
|
||||||
|
.ok());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def,
|
||||||
|
RemoteRendezvous* remote_rendezvous) {
|
||||||
|
int rendezvous_id = 0;
|
||||||
|
auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id);
|
||||||
|
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||||
|
session_name, *server_def, true));
|
||||||
|
|
||||||
|
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||||
|
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||||
|
session_name, &worker_session));
|
||||||
|
|
||||||
|
TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworksTest, TestStartServer) {
|
||||||
|
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
|
||||||
|
rendezvous_init_function, receive_from_remote_async_function,
|
||||||
|
rendezvous_delete_function);
|
||||||
|
|
||||||
|
TF_Status* tf_status = TF_NewStatus();
|
||||||
|
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
|
||||||
|
accept_functionA, init_function, start_function, stop_function,
|
||||||
|
join_function, delete_function, rendezvous_builder);
|
||||||
|
TF_RegisterGrpcServerFactory("testfactoryA", factory);
|
||||||
|
|
||||||
|
ServerDef server_def = GetServerDef("grpc+A", "localhost", 1);
|
||||||
|
std::unique_ptr<ServerInterface> server;
|
||||||
|
TF_EXPECT_OK(NewServer(server_def, &server));
|
||||||
|
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
|
||||||
|
auto* server_data = GetServerData(grpc_server);
|
||||||
|
ASSERT_FALSE(server_data->server_started);
|
||||||
|
|
||||||
|
TF_EXPECT_OK(server->Start());
|
||||||
|
ASSERT_TRUE(server_data->server_started);
|
||||||
|
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
TF_DeleteGrpcServerFactory(factory);
|
||||||
|
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
|
||||||
|
// TODO(annarev): find a clean way to shutdown server.
|
||||||
|
server.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworksTest, TestReceiveData) {
|
||||||
|
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
|
||||||
|
rendezvous_init_function, receive_from_remote_async_function,
|
||||||
|
rendezvous_delete_function);
|
||||||
|
|
||||||
|
TF_Status* tf_status = TF_NewStatus();
|
||||||
|
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
|
||||||
|
accept_functionB, init_function, start_function, stop_function,
|
||||||
|
join_function, delete_function, rendezvous_builder);
|
||||||
|
TF_RegisterGrpcServerFactory("testfactoryB", factory);
|
||||||
|
|
||||||
|
ServerDef server_def = GetServerDef("grpc+B", "localhost", 1);
|
||||||
|
std::unique_ptr<ServerInterface> server;
|
||||||
|
TF_EXPECT_OK(NewServer(server_def, &server));
|
||||||
|
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
|
||||||
|
|
||||||
|
TF_EXPECT_OK(server->Start());
|
||||||
|
auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr;
|
||||||
|
auto* remote_rendezvous = rendezvous_mgr->Find(0);
|
||||||
|
|
||||||
|
auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1,
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0", "test");
|
||||||
|
Rendezvous::Args args;
|
||||||
|
bool done_callback_called = false;
|
||||||
|
auto* done_callback_called_ptr = &done_callback_called;
|
||||||
|
absl::Notification notification;
|
||||||
|
auto* notification_ptr = ¬ification;
|
||||||
|
|
||||||
|
InitializeRendezvous(grpc_server, &server_def, remote_rendezvous);
|
||||||
|
remote_rendezvous->RecvAsync(
|
||||||
|
key, args,
|
||||||
|
[done_callback_called_ptr, notification_ptr](
|
||||||
|
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
|
||||||
|
const Tensor&, const bool) mutable {
|
||||||
|
*done_callback_called_ptr = true;
|
||||||
|
notification_ptr->Notify();
|
||||||
|
});
|
||||||
|
notification.WaitForNotificationWithTimeout(absl::Seconds(10));
|
||||||
|
ASSERT_EQ(done_callback_called, true);
|
||||||
|
|
||||||
|
TF_DeleteStatus(tf_status);
|
||||||
|
TF_DeleteGrpcServerFactory(factory);
|
||||||
|
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
|
||||||
|
// Server doesn't have a clean shutdown.
|
||||||
|
server.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
124
tensorflow/c/experimental/rendezvous.cc
Normal file
124
tensorflow/c/experimental/rendezvous.cc
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/experimental/rendezvous.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous_internal.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||||
|
void (*receive_from_remote_async_function)(
|
||||||
|
TF_ParsedKey*, TF_RendezvousArgs*,
|
||||||
|
TF_RendezvousDoneCallback*,
|
||||||
|
void* context),
|
||||||
|
void (*delete_function)(void* context),
|
||||||
|
void* server_context)
|
||||||
|
: BaseRemoteRendezvous(env, step_id),
|
||||||
|
receive_from_remote_async_function_(receive_from_remote_async_function),
|
||||||
|
delete_function_(delete_function),
|
||||||
|
context_(nullptr) {}
|
||||||
|
|
||||||
|
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
DoneCallback done) {
|
||||||
|
TF_ParsedKey key;
|
||||||
|
key.src_device = parsed.src_device.data();
|
||||||
|
key.src_device_len = parsed.src_device.size();
|
||||||
|
key.dst_device = parsed.dst_device.data();
|
||||||
|
key.dst_device_len = parsed.dst_device.size();
|
||||||
|
key.full_key = parsed.FullKey().data();
|
||||||
|
key.full_key_len = parsed.FullKey().size();
|
||||||
|
|
||||||
|
TF_DeviceContext* device_context = new TF_DeviceContext();
|
||||||
|
device_context->context = args.device_context;
|
||||||
|
|
||||||
|
TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes();
|
||||||
|
alloc_attrs->value = args.alloc_attrs.value;
|
||||||
|
alloc_attrs->scope_id = args.alloc_attrs.scope_id;
|
||||||
|
alloc_attrs->on_host = args.alloc_attrs.on_host();
|
||||||
|
alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible();
|
||||||
|
|
||||||
|
TF_RendezvousArgs* cargs = new TF_RendezvousArgs();
|
||||||
|
cargs->device_context = device_context;
|
||||||
|
cargs->alloc_attrs = alloc_attrs;
|
||||||
|
|
||||||
|
TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback();
|
||||||
|
done_callback->done_callback = done;
|
||||||
|
done_callback->recv_args = cargs;
|
||||||
|
|
||||||
|
receive_from_remote_async_function_(&key, cargs, done_callback, context_);
|
||||||
|
}
|
||||||
|
|
||||||
|
CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); }
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
|
||||||
|
void* (*init_function)(void* server_context),
|
||||||
|
void (*receive_from_remote_async_function)(TF_ParsedKey*,
|
||||||
|
TF_RendezvousArgs*,
|
||||||
|
TF_RendezvousDoneCallback*,
|
||||||
|
void* context),
|
||||||
|
void (*delete_function)(void* context)) {
|
||||||
|
TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder();
|
||||||
|
builder->init_function = init_function;
|
||||||
|
builder->delete_function = delete_function;
|
||||||
|
builder->receive_from_remote_async_function =
|
||||||
|
receive_from_remote_async_function;
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_DeleteRemoteRendezvousBuilder(
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder) {
|
||||||
|
DCHECK_NE(rendezvous_builder, nullptr);
|
||||||
|
delete rendezvous_builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CAPI_EXPORT extern void TF_RendezvousDone(
|
||||||
|
TF_RendezvousDoneCallback* callback) {
|
||||||
|
DCHECK_NE(callback, nullptr);
|
||||||
|
::tensorflow::Tensor tensor;
|
||||||
|
TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor));
|
||||||
|
::tensorflow::Rendezvous::Args recv_args;
|
||||||
|
recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value;
|
||||||
|
recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
|
||||||
|
recv_args.device_context = callback->recv_args->device_context->context;
|
||||||
|
::tensorflow::Rendezvous::Args sent_args;
|
||||||
|
|
||||||
|
callback->done_callback(callback->status->status, sent_args, recv_args,
|
||||||
|
tensor, callback->dead);
|
||||||
|
|
||||||
|
if (callback->recv_args) {
|
||||||
|
DCHECK_NE(callback->recv_args, nullptr);
|
||||||
|
DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
|
||||||
|
DCHECK_NE(callback->recv_args->device_context, nullptr);
|
||||||
|
delete callback->recv_args->alloc_attrs;
|
||||||
|
delete callback->recv_args->device_context;
|
||||||
|
delete callback->recv_args;
|
||||||
|
}
|
||||||
|
delete callback;
|
||||||
|
callback = nullptr;
|
||||||
|
}
|
67
tensorflow/c/experimental/rendezvous.h
Normal file
67
tensorflow/c/experimental/rendezvous.h
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// C API for Rendezvous.
|
||||||
|
// NOTE: This API is unstable and almost certainly will change in the near
|
||||||
|
// future.
|
||||||
|
//
|
||||||
|
// Custom rendezvous allows for custom implementations of Recv call.
|
||||||
|
//
|
||||||
|
// Users wishing to create custom rendezvous objects should call
|
||||||
|
// TF_NewRemoteRendezvousBuilder and pass returned TF_RemoteRendezvousBuilder
|
||||||
|
// to to TF_NewServerFactory.
|
||||||
|
|
||||||
|
typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
|
||||||
|
typedef struct TF_ParsedKey TF_ParsedKey;
|
||||||
|
typedef struct TF_RendezvousArgs TF_RendezvousArgs;
|
||||||
|
typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
|
||||||
|
|
||||||
|
// Creates a new TF_RemoteRendezvousBuilder instance.
|
||||||
|
// Rendezvous instances will forward calls to init_function,
|
||||||
|
// receive_from_remote_async_function and delete_function passed here.
|
||||||
|
//
|
||||||
|
// Note that receive_from_remote_async_function implementation must call
|
||||||
|
// TF_Done with the TF_DoneCallback passed as an argument.
|
||||||
|
TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
|
||||||
|
void* (*init_function)(void* server_context),
|
||||||
|
void (*receive_from_remote_async_function)(TF_ParsedKey*,
|
||||||
|
TF_RendezvousArgs*,
|
||||||
|
TF_RendezvousDoneCallback*,
|
||||||
|
void* context),
|
||||||
|
void (*delete_function)(void* context));
|
||||||
|
|
||||||
|
// Deletes TF_RemoteRendezvousBuilder instances.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
|
||||||
|
TF_RemoteRendezvousBuilder* rendezvous_builder);
|
||||||
|
|
||||||
|
// Calls TF_DoneCallback and destroys callback instance and
|
||||||
|
// TF_DoneCallback members except `tensor` and `status`. Caller is
|
||||||
|
// responsible for deleting `tensor` and `status` after TF_Done returns.
|
||||||
|
TF_CAPI_EXPORT extern void TF_RendezvousDone(
|
||||||
|
TF_RendezvousDoneCallback* callback);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} /* end extern "C" */
|
||||||
|
#endif
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
|
135
tensorflow/c/experimental/rendezvous_internal.h
Normal file
135
tensorflow/c/experimental/rendezvous_internal.h
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api.h"
|
||||||
|
#include "tensorflow/c/experimental/rendezvous.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||||
|
#include "tensorflow/core/framework/device_base.h"
|
||||||
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
|
struct TF_ParsedKey {
|
||||||
|
// char* members might not be null-terminated.
|
||||||
|
const char* src_device;
|
||||||
|
size_t src_device_len;
|
||||||
|
const char* dst_device;
|
||||||
|
size_t dst_device_len;
|
||||||
|
const char* full_key;
|
||||||
|
size_t full_key_len;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TF_AllocatorAttributes {
|
||||||
|
bool on_host;
|
||||||
|
bool nic_compatible;
|
||||||
|
// NOTE: The upper 8 bits of the value are reserved for
|
||||||
|
// device-specific uses. Implementors of a device can interpret these
|
||||||
|
// upper 8 bits in device-specific ways, and ops implemented for those
|
||||||
|
// devices are responsible for setting those 8 bits appropriately.
|
||||||
|
tensorflow::uint32 value = 0;
|
||||||
|
// EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
|
||||||
|
// a named special-purpose allocator on the same device.
|
||||||
|
tensorflow::int32 scope_id = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TF_DeviceContext {
|
||||||
|
::tensorflow::DeviceContext* context;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TF_RendezvousArgs {
|
||||||
|
const TF_DeviceContext* device_context;
|
||||||
|
const TF_AllocatorAttributes* alloc_attrs;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TF_RendezvousDoneCallback {
|
||||||
|
::tensorflow::Rendezvous::DoneCallback done_callback;
|
||||||
|
|
||||||
|
// TODO(annarev): figure out if we should also support sent_args.
|
||||||
|
const TF_RendezvousArgs* recv_args;
|
||||||
|
TF_Tensor* tensor = nullptr;
|
||||||
|
TF_Status* status;
|
||||||
|
bool dead;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TF_RemoteRendezvousBuilder {
|
||||||
|
void* (*init_function)(void* server_context);
|
||||||
|
void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*,
|
||||||
|
TF_RendezvousDoneCallback*,
|
||||||
|
void* context);
|
||||||
|
void (*delete_function)(void* context);
|
||||||
|
void* server_context;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class CRemoteRendezvous : public BaseRemoteRendezvous {
|
||||||
|
public:
|
||||||
|
CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
|
||||||
|
void (*receive_from_remote_async_function)(
|
||||||
|
TF_ParsedKey*, TF_RendezvousArgs*,
|
||||||
|
TF_RendezvousDoneCallback*, void* context),
|
||||||
|
void (*delete_function)(void* context),
|
||||||
|
void* server_context);
|
||||||
|
|
||||||
|
void SetContext(void* context) { context_ = context; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
DoneCallback done) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
~CRemoteRendezvous() override;
|
||||||
|
|
||||||
|
void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*,
|
||||||
|
TF_RendezvousDoneCallback*,
|
||||||
|
void* context);
|
||||||
|
void (*delete_function_)(void* context);
|
||||||
|
void* context_;
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous);
|
||||||
|
};
|
||||||
|
|
||||||
|
class CRendezvousMgr : public BaseRendezvousMgr {
|
||||||
|
public:
|
||||||
|
CRendezvousMgr(const WorkerEnv* env,
|
||||||
|
const TF_RemoteRendezvousBuilder* rendezvous_builder)
|
||||||
|
: BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
BaseRemoteRendezvous* Create(int64 step_id,
|
||||||
|
const WorkerEnv* worker_env) override {
|
||||||
|
auto* rendezvous = new CRemoteRendezvous(
|
||||||
|
worker_env, step_id,
|
||||||
|
rendezvous_builder_->receive_from_remote_async_function,
|
||||||
|
rendezvous_builder_->delete_function,
|
||||||
|
rendezvous_builder_->server_context);
|
||||||
|
|
||||||
|
rendezvous->SetContext(rendezvous_builder_->init_function(
|
||||||
|
rendezvous_builder_->server_context));
|
||||||
|
return rendezvous;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const TF_RemoteRendezvousBuilder* rendezvous_builder_;
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
|
Loading…
Reference in New Issue
Block a user