First check-in for V1 of Networking C API.

PiperOrigin-RevId: 245293662
This commit is contained in:
Anna R 2019-04-25 13:13:54 -07:00 committed by TensorFlower Gardener
parent 17db87c632
commit 606c13f5a6
8 changed files with 1044 additions and 0 deletions

View File

@ -0,0 +1,122 @@
# Description:
# Experimental C APIs for TensorFlow.
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
tf_cuda_library(
name = "rendezvous_internal",
srcs = [
"rendezvous.cc",
],
hdrs = [
"rendezvous.h",
"rendezvous_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "rendezvous",
hdrs = [
"rendezvous.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api",
],
)
tf_cuda_library(
name = "network_internal",
srcs = [
"network.cc",
],
hdrs = [
"network.h",
"network_internal.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
":rendezvous_internal",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
tf_cuda_library(
name = "network",
hdrs = [
"network.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":network_internal",
":rendezvous",
"//tensorflow/c:c_api",
],
)
# -----------------------------------------------------------------------------
# Tests
tf_cuda_cc_test(
name = "network_test",
size = "medium",
srcs = ["network_test.cc"],
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":network",
":network_internal",
":rendezvous",
":rendezvous_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:env",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)

View File

@ -0,0 +1,166 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <memory>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
using tensorflow::ServerFactory;
namespace tensorflow {
/* static */ Status CGrpcServer::Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server) {
auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function,
join_function, delete_function);
GrpcServerOptions options;
options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) {
return new CRendezvousMgr(env, rendezvous_builder);
};
TF_RETURN_IF_ERROR(grpc_server->Init(options));
TF_Status* tf_status = TF_NewStatus();
grpc_server->SetContext(init_function(
reinterpret_cast<const TF_GrpcServer*>(grpc_server), tf_status));
TF_RETURN_IF_ERROR(tf_status->status);
TF_DeleteStatus(tf_status);
out_server->reset(grpc_server);
return Status::OK();
}
Status CGrpcServer::Start() {
Status status = GrpcServer::Start();
TF_Status* tf_status = TF_NewStatus();
(*start_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Stop() {
Status status = GrpcServer::Stop();
TF_Status* tf_status = TF_NewStatus();
(*stop_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
Status CGrpcServer::Join() {
Status status = GrpcServer::Join();
TF_Status* tf_status = TF_NewStatus();
(*join_function_)(reinterpret_cast<const TF_GrpcServer*>(this), context_,
tf_status);
status.Update(tf_status->status);
TF_DeleteStatus(tf_status);
return status;
}
namespace {
// Factory that creates CGrpcServer instances.
class CServerFactory : public ServerFactory {
public:
CServerFactory(bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*,
TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder)
: accept_function_(accept_function),
init_function_(init_function),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,
join_function_, delete_function_, rendezvous_builder_, out_server));
return Status::OK();
}
// Returns true if and only if this factory can create a server
// based on the given `server_def`.
bool AcceptsOptions(const ServerDef& server_def) override {
return (*accept_function_)(server_def.protocol().c_str());
}
private:
bool (*accept_function_)(const char* protocol);
void* (*init_function_)(const TF_GrpcServer*, TF_Status*);
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
TF_RemoteRendezvousBuilder* rendezvous_builder_;
};
} // namespace
} // namespace tensorflow
// Server factory representation to use in C API.
// Holds CServerFactory pointer.
struct TF_GrpcServerFactory {
::tensorflow::CServerFactory* factory;
};
TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder) {
TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory;
server_factory->factory = new ::tensorflow::CServerFactory(
accept_function, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
return server_factory;
}
void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) {
DCHECK_NE(server_factory, nullptr);
delete server_factory;
}
void TF_RegisterGrpcServerFactory(const char* server_type,
TF_GrpcServerFactory* server_factory) {
ServerFactory::Register(server_type, server_factory->factory);
}

View File

@ -0,0 +1,97 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for TensorFlow Networking.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Users wishing to register a custom GrpcServer should call
// TF_NewServerFactory and then TF_RegisterGrpcServerFactory.
//
// Example:
// ```c++
// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
// rendezvous_init_function,
// receive_from_remote_async_function,
// rendezvous_delete_function);
//
// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
// accept_function,
// init_function,
// start_function,
// stop_function,
// join_function,
// delete_function,
// rendezvous_builder);
// TF_RegisterGrpcServerFactory("customfactory", factory);
// ...
// TF_DeleteGrpcServerFactory(factory);
// ```
typedef struct TF_GrpcServerFactory TF_GrpcServerFactory;
typedef struct TF_GrpcServerOptions TF_GrpcServerOptions;
typedef struct TF_GrpcServer TF_GrpcServer;
typedef struct TF_ServerContext {
TF_GrpcServer* const server;
void* context;
} TF_ServerContext;
// Creates a new TF_GrpcServerFactory instance. Caller takes ownership
// of TF_GrpcServerFactory instance and should deallocate it by calling
// TF_GrpcDeleteServerFactory.
// accept_function should return true if this ServerFactory can create
// server instances for the given protocol name (for e.g. grpc+verbs).
// GRPC servers created by this factory will call provided
// init_function, start_function, stop_function, join_function and
// delete_function.
//
// Note that clean shutdown is currently not implemented for GrpcServer.
// So, stop_function will never be called now but may be in the future
// when stop mechanism is supported.
TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory(
bool (*accept_function)(const char*),
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Deletes TF_GrpcServerFactory instances.
// Note that this function only deletes TF_GrpcServerFactory wrapper.
// Actual underlying server factory would not be deleted and will
// remain registered.
TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory(
TF_GrpcServerFactory* server_factory);
// Registers provided server_factory for the given server_type.
// server_type must be unique to the server factory.
TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory(
const char* server_type, TF_GrpcServerFactory* server_factory);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
// GrpcServer implementation that forwards calls to callbacks.
class CGrpcServer : public GrpcServer {
protected:
CGrpcServer(const ServerDef& server_def,
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*))
: GrpcServer(server_def, ::tensorflow::Env::Default()),
start_function_(start_function),
stop_function_(stop_function),
join_function_(join_function),
delete_function_(delete_function),
context_(nullptr) {}
public:
static Status Create(
const ServerDef& server_def,
void* (*init_function)(const TF_GrpcServer*, TF_Status*),
void (*start_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*join_function)(const TF_GrpcServer*, void*, TF_Status*),
void (*delete_function)(void*),
TF_RemoteRendezvousBuilder* rendezvous_builder,
std::unique_ptr<ServerInterface>* out_server);
Status Start() override;
Status Stop() override;
Status Join() override;
~CGrpcServer() override { delete_function_(context_); }
protected:
void SetContext(void* context) { context_ = context; }
private:
void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*);
void (*delete_function_)(void*);
void* context_;
friend class NetworksTest;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_

View File

@ -0,0 +1,256 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/network.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/network_internal.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
bool accept_functionA(const char* protocol_name) {
return strcmp(protocol_name, "grpc+A") == 0;
}
bool accept_functionB(const char* protocol_name) {
return strcmp(protocol_name, "grpc+B") == 0;
}
struct SomeServerData {
bool server_started = false;
};
struct SomeRendezvousData {
int test = 0;
};
void* init_function(const TF_GrpcServer* server, TF_Status* status) {
SomeServerData* server_data = new SomeServerData();
TF_SetStatus(status, TF_OK, "");
return server_data;
}
void start_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
auto* server_data = static_cast<SomeServerData*>(context);
server_data->server_started = true;
TF_SetStatus(status, TF_OK, "");
}
void stop_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void join_function(const TF_GrpcServer* server, void* context,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
void delete_function(void* context) {
auto* server_data = static_cast<SomeServerData*>(context);
delete server_data;
}
void* rendezvous_init_function(void* server_context) {
return new SomeRendezvousData();
}
void Deallocator(void* data, size_t, void* arg) {
tensorflow::cpu_allocator()->DeallocateRaw(data);
*reinterpret_cast<bool*>(arg) = true;
}
void receive_from_remote_async_function(TF_ParsedKey* key,
TF_RendezvousArgs* args,
TF_RendezvousDoneCallback* callback,
void* context) {
// Create dummy tensor
const int num_bytes = 6 * sizeof(float);
float* values =
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
EIGEN_MAX_ALIGN_BYTES, num_bytes));
int64_t dims[] = {2, 3};
bool deallocator_called = false;
auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
&Deallocator, &deallocator_called);
callback->tensor = tensor;
auto* tf_status = TF_NewStatus();
TF_SetStatus(tf_status, TF_OK, "");
callback->status = tf_status;
TF_RendezvousDone(callback);
TF_DeleteStatus(tf_status);
TF_DeleteTensor(tensor);
}
void rendezvous_delete_function(void* context) {
auto* rendezvous_data = static_cast<SomeRendezvousData*>(context);
delete rendezvous_data;
}
tensorflow::ServerDef GetServerDef(const string& protocol,
const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol(protocol);
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
class NetworksTest : public ::testing::Test {
public:
~NetworksTest() override {}
SomeServerData* GetServerData(CGrpcServer* server) {
EXPECT_NE(server->context_, nullptr);
return static_cast<SomeServerData*>(server->context_);
}
};
Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
const string& receiver, const string& name) {
Rendezvous::ParsedKey result;
CHECK(
Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
name, FrameAndIter(0, 0)),
&result)
.ok());
return result;
}
void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def,
RemoteRendezvous* remote_rendezvous) {
int rendezvous_id = 0;
auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id);
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, *server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get()));
}
TEST_F(NetworksTest, TestStartServer) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionA, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryA", factory);
ServerDef server_def = GetServerDef("grpc+A", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
auto* server_data = GetServerData(grpc_server);
ASSERT_FALSE(server_data->server_started);
TF_EXPECT_OK(server->Start());
ASSERT_TRUE(server_data->server_started);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// TODO(annarev): find a clean way to shutdown server.
server.release();
}
TEST_F(NetworksTest, TestReceiveData) {
auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder(
rendezvous_init_function, receive_from_remote_async_function,
rendezvous_delete_function);
TF_Status* tf_status = TF_NewStatus();
TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory(
accept_functionB, init_function, start_function, stop_function,
join_function, delete_function, rendezvous_builder);
TF_RegisterGrpcServerFactory("testfactoryB", factory);
ServerDef server_def = GetServerDef("grpc+B", "localhost", 1);
std::unique_ptr<ServerInterface> server;
TF_EXPECT_OK(NewServer(server_def, &server));
auto* grpc_server = static_cast<CGrpcServer*>(server.get());
TF_EXPECT_OK(server->Start());
auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr;
auto* remote_rendezvous = rendezvous_mgr->Find(0);
auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1,
"/job:localhost/replica:0/task:0/device:CPU:0", "test");
Rendezvous::Args args;
bool done_callback_called = false;
auto* done_callback_called_ptr = &done_callback_called;
absl::Notification notification;
auto* notification_ptr = &notification;
InitializeRendezvous(grpc_server, &server_def, remote_rendezvous);
remote_rendezvous->RecvAsync(
key, args,
[done_callback_called_ptr, notification_ptr](
const Status&, const Rendezvous::Args&, const Rendezvous::Args&,
const Tensor&, const bool) mutable {
*done_callback_called_ptr = true;
notification_ptr->Notify();
});
notification.WaitForNotificationWithTimeout(absl::Seconds(10));
ASSERT_EQ(done_callback_called, true);
TF_DeleteStatus(tf_status);
TF_DeleteGrpcServerFactory(factory);
TF_DeleteRemoteRendezvousBuilder(rendezvous_builder);
// Server doesn't have a clean shutdown.
server.release();
}
} // namespace tensorflow

View File

@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/rendezvous.h"
#include <functional>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/rendezvous_internal.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context),
void* server_context)
: BaseRemoteRendezvous(env, step_id),
receive_from_remote_async_function_(receive_from_remote_async_function),
delete_function_(delete_function),
context_(nullptr) {}
void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) {
TF_ParsedKey key;
key.src_device = parsed.src_device.data();
key.src_device_len = parsed.src_device.size();
key.dst_device = parsed.dst_device.data();
key.dst_device_len = parsed.dst_device.size();
key.full_key = parsed.FullKey().data();
key.full_key_len = parsed.FullKey().size();
TF_DeviceContext* device_context = new TF_DeviceContext();
device_context->context = args.device_context;
TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes();
alloc_attrs->value = args.alloc_attrs.value;
alloc_attrs->scope_id = args.alloc_attrs.scope_id;
alloc_attrs->on_host = args.alloc_attrs.on_host();
alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible();
TF_RendezvousArgs* cargs = new TF_RendezvousArgs();
cargs->device_context = device_context;
cargs->alloc_attrs = alloc_attrs;
TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback();
done_callback->done_callback = done;
done_callback->recv_args = cargs;
receive_from_remote_async_function_(&key, cargs, done_callback, context_);
}
CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); }
} // namespace tensorflow
TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context)) {
TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder();
builder->init_function = init_function;
builder->delete_function = delete_function;
builder->receive_from_remote_async_function =
receive_from_remote_async_function;
return builder;
}
void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder) {
DCHECK_NE(rendezvous_builder, nullptr);
delete rendezvous_builder;
}
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback) {
DCHECK_NE(callback, nullptr);
::tensorflow::Tensor tensor;
TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor));
::tensorflow::Rendezvous::Args recv_args;
recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value;
recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id;
recv_args.device_context = callback->recv_args->device_context->context;
::tensorflow::Rendezvous::Args sent_args;
callback->done_callback(callback->status->status, sent_args, recv_args,
tensor, callback->dead);
if (callback->recv_args) {
DCHECK_NE(callback->recv_args, nullptr);
DCHECK_NE(callback->recv_args->alloc_attrs, nullptr);
DCHECK_NE(callback->recv_args->device_context, nullptr);
delete callback->recv_args->alloc_attrs;
delete callback->recv_args->device_context;
delete callback->recv_args;
}
delete callback;
callback = nullptr;
}

View File

@ -0,0 +1,67 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for Rendezvous.
// NOTE: This API is unstable and almost certainly will change in the near
// future.
//
// Custom rendezvous allows for custom implementations of Recv call.
//
// Users wishing to create custom rendezvous objects should call
// TF_NewRemoteRendezvousBuilder and pass returned TF_RemoteRendezvousBuilder
// to to TF_NewServerFactory.
typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder;
typedef struct TF_ParsedKey TF_ParsedKey;
typedef struct TF_RendezvousArgs TF_RendezvousArgs;
typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback;
// Creates a new TF_RemoteRendezvousBuilder instance.
// Rendezvous instances will forward calls to init_function,
// receive_from_remote_async_function and delete_function passed here.
//
// Note that receive_from_remote_async_function implementation must call
// TF_Done with the TF_DoneCallback passed as an argument.
TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder(
void* (*init_function)(void* server_context),
void (*receive_from_remote_async_function)(TF_ParsedKey*,
TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context),
void (*delete_function)(void* context));
// Deletes TF_RemoteRendezvousBuilder instances.
TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder(
TF_RemoteRendezvousBuilder* rendezvous_builder);
// Calls TF_DoneCallback and destroys callback instance and
// TF_DoneCallback members except `tensor` and `status`. Caller is
// responsible for deleting `tensor` and `status` after TF_Done returns.
TF_CAPI_EXPORT extern void TF_RendezvousDone(
TF_RendezvousDoneCallback* callback);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_
#include <stddef.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/rendezvous.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/platform/macros.h"
struct TF_ParsedKey {
// char* members might not be null-terminated.
const char* src_device;
size_t src_device_len;
const char* dst_device;
size_t dst_device_len;
const char* full_key;
size_t full_key_len;
};
struct TF_AllocatorAttributes {
bool on_host;
bool nic_compatible;
// NOTE: The upper 8 bits of the value are reserved for
// device-specific uses. Implementors of a device can interpret these
// upper 8 bits in device-specific ways, and ops implemented for those
// devices are responsible for setting those 8 bits appropriately.
tensorflow::uint32 value = 0;
// EXPERIMENTAL: If this is greater than zero, then allocation is delegated to
// a named special-purpose allocator on the same device.
tensorflow::int32 scope_id = 0;
};
struct TF_DeviceContext {
::tensorflow::DeviceContext* context;
};
struct TF_RendezvousArgs {
const TF_DeviceContext* device_context;
const TF_AllocatorAttributes* alloc_attrs;
};
struct TF_RendezvousDoneCallback {
::tensorflow::Rendezvous::DoneCallback done_callback;
// TODO(annarev): figure out if we should also support sent_args.
const TF_RendezvousArgs* recv_args;
TF_Tensor* tensor = nullptr;
TF_Status* status;
bool dead;
};
struct TF_RemoteRendezvousBuilder {
void* (*init_function)(void* server_context);
void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function)(void* context);
void* server_context;
};
namespace tensorflow {
class CRemoteRendezvous : public BaseRemoteRendezvous {
public:
CRemoteRendezvous(const WorkerEnv* env, int64 step_id,
void (*receive_from_remote_async_function)(
TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*, void* context),
void (*delete_function)(void* context),
void* server_context);
void SetContext(void* context) { context_ = context; }
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
private:
~CRemoteRendezvous() override;
void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*,
TF_RendezvousDoneCallback*,
void* context);
void (*delete_function_)(void* context);
void* context_;
TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous);
};
class CRendezvousMgr : public BaseRendezvousMgr {
public:
CRendezvousMgr(const WorkerEnv* env,
const TF_RemoteRendezvousBuilder* rendezvous_builder)
: BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {}
protected:
BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env) override {
auto* rendezvous = new CRemoteRendezvous(
worker_env, step_id,
rendezvous_builder_->receive_from_remote_async_function,
rendezvous_builder_->delete_function,
rendezvous_builder_->server_context);
rendezvous->SetContext(rendezvous_builder_->init_function(
rendezvous_builder_->server_context));
return rendezvous;
}
private:
const TF_RemoteRendezvousBuilder* rendezvous_builder_;
TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr);
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_