Add a rudimentary library loader implementation for some TPU-related functions

PiperOrigin-RevId: 314496921
Change-Id: I9c499f37525dee7b2cf138089b052e020766d54f
Authored by Frank Chen on 2020-06-03 02:52:48 -07:00; committed by TensorFlower Gardener
parent 10c005554d
commit 2cbb43e1fc
6 changed files with 152 additions and 22 deletions

tensorflow/core/framework/load_library.cc

@@ -107,7 +107,7 @@ Status LoadLibrary(const char* library_filename, void** result,
   if (env->GetSymbolFromLibrary(library.handle, "TfTpu_Initialize",
                                 &unused_symbol)
           .ok()) {
-    TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTPULibrary(library.handle));
+    TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTpuLibrary(library.handle));
   }
 #endif  // IS_MOBILE_PLATFORM

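For context, the pattern in this hunk is: probe a freshly loaded library for the TfTpu_Initialize symbol, and only if it is present hand the handle to the TPU loader. A minimal standalone sketch of that probe-then-initialize idea using raw dlfcn follows; the plugin filename is a hypothetical assumption, and TensorFlow itself goes through Env::GetSymbolFromLibrary rather than calling dlsym directly.

// Hedged sketch, not TensorFlow's Env API: probe a dlopen()'d library
// for the TPU entry point before running TPU initialization.
#include <dlfcn.h>
#include <cstdio>

int main() {
  void* handle = dlopen("libtpu_plugin.so", RTLD_NOW);  // hypothetical name
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  if (dlsym(handle, "TfTpu_Initialize") != nullptr) {
    // Symbol present: this is where LoadLibrary would call
    // tensorflow::tpu::InitializeTpuLibrary(handle).
  }
  return 0;
}
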
tensorflow/core/tpu/BUILD

@@ -8,6 +8,13 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "libtftpu_header",
+    hdrs = ["libtftpu.h"],
+    visibility = ["//visibility:public"],
+    deps = [],
+)
+
 cc_library(
     name = "tpu_embedding_optimization_parameters_utils",
     srcs = ["tpu_embedding_optimization_parameters_utils.cc"],
@@ -88,6 +95,7 @@ cc_library(
     name = "tpu_config_c_api",
     hdrs = ["tpu_config_c_api.h"],
     deps = [
+        ":libtftpu_header",
         "//tensorflow/c:tf_status",
     ],
 )
@@ -96,6 +104,11 @@ cc_library(
     name = "tpu_library_loader",
     srcs = ["tpu_library_loader.cc"],
     hdrs = ["tpu_library_loader.h"],
-    visibility = ["//tensorflow:__subpackages__"],
-    deps = ["//tensorflow/core/platform:status"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":libtftpu_header",
+        ":tpu_config_c_api",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
+    ],
 )

tensorflow/core/tpu/libtftpu.h

@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_LIBTFTPU_H_
+#define TENSORFLOW_CORE_TPU_LIBTFTPU_H_
+
+// Unfortunately we have to add an Fn suffix because we cannot have the same
+// name for both a function and an element within a struct in the global
+// namespace in gcc. This restriction doesn't exist in clang.
+#define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn;
+
+#ifdef SWIG
+#define TFTPU_CAPI_EXPORT
+#else
+#if defined(_WIN32)
+#ifdef TF_COMPILE_LIBRARY
+#define TFTPU_CAPI_EXPORT __declspec(dllexport)
+#else
+#define TFTPU_CAPI_EXPORT __declspec(dllimport)
+#endif  // TF_COMPILE_LIBRARY
+#else
+#define TFTPU_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // _WIN32
+#endif  // SWIG
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+TFTPU_CAPI_EXPORT void TfTpu_Initialize();
+
+#ifdef __cplusplus
+}
+#endif
+
+struct TfTpu_BaseFn {
+  TFTPU_ADD_FN_IN_STRUCT(TfTpu_Initialize);
+};
+
+#endif  // TENSORFLOW_CORE_TPU_LIBTFTPU_H_

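The TFTPU_ADD_FN_IN_STRUCT trick is the core of this header: decltype(FnName) names the declared function's type, so the macro declares a pointer member with an Fn-suffixed name. As an illustration, the struct at the bottom of the file expands to roughly the following (the expansion is written out by hand here and is not part of the change):

// Hand-expanded equivalent of TFTPU_ADD_FN_IN_STRUCT(TfTpu_Initialize):
// decltype(TfTpu_Initialize) is void(), so the member is a plain
// function pointer named TfTpu_InitializeFn.
struct TfTpu_BaseFn {
  void (*TfTpu_InitializeFn)();
};
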
tensorflow/core/tpu/tpu_config_c_api.h

@@ -20,40 +20,53 @@ limitations under the License.
 #include <cstdint>
 
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/tpu/libtftpu.h"
 
 typedef struct TpuSerializedProto TpuSerializedProto;
 
 extern "C" {
-bool TPUHostInitialized();
+TFTPU_CAPI_EXPORT bool TPUHostInitialized();
 
-void ConfigureDistributedTpuOp_DoWork(const size_t num_cores_per_host_size,
-                                      const int32_t* num_cores_per_host,
-                                      size_t* host_config_output_size,
-                                      char** host_config_output,
-                                      TF_Status* status);
+TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
+    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
+    size_t* host_config_output_size, char** host_config_output,
+    TF_Status* status);
 
-void WaitForDistributedTpuOp_DoWork(
+TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
     const int32_t** host_ordinal_to_global_core_id_map,
     size_t* tpu_topology_output_size, char** tpu_topology_output,
     TF_Status* status);
 
-void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
+TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
 
-void InitializeHostForDistributedTpuOp_DoWork(
+TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
     const size_t tpu_host_config_size, const char* tpu_host_config,
     const bool enable_whole_mesh_compilations, size_t* core_id_output_size,
     int32_t** core_id_output, TF_Status* status);
 
-void SetGlobalTPUArrayOp_DoWork(const size_t tpu_topology_size,
-                                const char* tpu_topology, TF_Status* status);
+TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
+    const size_t tpu_topology_size, const char* tpu_topology,
+    TF_Status* status);
 
-void DisconnectDistributedTpuChipsOp_DoWork(int32_t* number_of_chips_output,
-                                            TF_Status* status);
+TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
+    int32_t* number_of_chips_output, TF_Status* status);
 
-void TpuConfigurationApi_FreeCharArray(char* output);
-void TpuConfigurationApi_FreeInt32Array(int32_t* output);
+TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
+TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
 }
 
+struct TfTpu_ConfigApiFn {
+  TFTPU_ADD_FN_IN_STRUCT(TPUHostInitialized);
+  TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
+  TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
+  TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork);
+  TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
+  TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
+  TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
+};
+
 #endif  // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

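Once the loader has populated TfTpu_ConfigApiFn, callers reach the config C API through the struct's Fn members rather than by linking the symbols directly. A minimal usage sketch under that assumption follows; the helper name TpuHostReady is hypothetical and not part of this change.

#include "tensorflow/core/tpu/tpu_library_loader.h"

// Hypothetical helper: returns false when the symbol was not resolved
// (the Fn member is still nullptr) or when the host reports itself
// uninitialized.
bool TpuHostReady() {
  auto* config_fn = tensorflow::tpu::ConfigApiFn();
  return config_fn->TPUHostInitializedFn != nullptr &&
         config_fn->TPUHostInitializedFn();
}
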
tensorflow/core/tpu/tpu_library_loader.cc

@@ -15,14 +15,60 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/tpu/tpu_library_loader.h"
 
+#include <dlfcn.h>
+
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/status.h"
 
+#define TFTPU_SET_FN(Struct, FnName) \
+  Struct->FnName##Fn =               \
+      reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));
+
 namespace tensorflow {
 namespace tpu {
 
-Status InitializeTPULibrary(void* library) {
-  // TODO(frankchn): dlsym the loaded library and populate a struct with the
-  // relevant C APIs necessary for TPUs.
+TfTpu_BaseFn* InitializeApiFn() {
+  static TfTpu_BaseFn base_fn;
+  return &base_fn;
+}
+
+TfTpu_ConfigApiFn* ConfigApiFn() {
+  static TfTpu_ConfigApiFn config_api_fn;
+  return &config_api_fn;
+}
+
+Status SetTpuInitializeStructFns(void* library_handle) {
+  auto* base_fn = InitializeApiFn();
+
+  TFTPU_SET_FN(base_fn, TfTpu_Initialize);
+
+  return Status::OK();
+}
+
+Status SetTpuConfigStructFns(void* library_handle) {
+  auto* config_fn = ConfigApiFn();
+
+  TFTPU_SET_FN(config_fn, TPUHostInitialized);
+  TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
+  TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
+  TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
+
+  return Status::OK();
+}
+
+Status InitializeTpuLibrary(void* library_handle) {
+  if (library_handle == nullptr) {
+    library_handle = dlopen(nullptr, RTLD_LAZY);
+  }
+
+  TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
+  TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
+
   return Status::OK();
 }

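Note that TFTPU_SET_FN simply stores whatever dlsym returns, so an unresolved symbol leaves the struct member as nullptr with no error reported yet, consistent with the "rudimentary" label in the commit message. A hedged sketch of how a caller might drive this loader follows; the helper name and the library path argument are illustrative assumptions.

#include <dlfcn.h>

#include "tensorflow/core/tpu/tpu_library_loader.h"

// Hypothetical driver: hand an explicit TPU shared library to the
// loader, or pass nullptr so InitializeTpuLibrary falls back to
// dlopen(nullptr, RTLD_LAZY), i.e. symbols already in the process.
tensorflow::Status LoadTpuPlugin(const char* path) {
  void* handle =
      path != nullptr ? dlopen(path, RTLD_NOW | RTLD_LOCAL) : nullptr;
  return tensorflow::tpu::InitializeTpuLibrary(handle);
}
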
tensorflow/core/tpu/tpu_library_loader.h

@@ -17,11 +17,17 @@ limitations under the License.
 #define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
 
 #include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/libtftpu.h"
+#include "tensorflow/core/tpu/tpu_config_c_api.h"
 
 namespace tensorflow {
 namespace tpu {
 
-Status InitializeTPULibrary(void* library);
+Status InitializeTpuLibrary(void* library_handle);
+
+TfTpu_BaseFn* InitializeApiFn();
+
+TfTpu_ConfigApiFn* ConfigApiFn();
 
 }  // namespace tpu
 }  // namespace tensorflow