Add a rudimentary library loader implementation for some TPU-related functions

PiperOrigin-RevId: 314863001
Change-Id: Iafe056ab3fcf592cd28873e6fd740121f17d1a91
This commit is contained in:
Frank Chen 2020-06-04 21:45:35 -07:00 committed by TensorFlower Gardener
parent 9244dd50fc
commit 75c40f6bff
7 changed files with 202 additions and 23 deletions

View File

@ -107,7 +107,7 @@ Status LoadLibrary(const char* library_filename, void** result,
if (env->GetSymbolFromLibrary(library.handle, "TfTpu_Initialize",
&unused_symbol)
.ok()) {
TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTPULibrary(library.handle));
TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTpuLibrary(library.handle));
}
#endif // IS_MOBILE_PLATFORM

View File

@ -1,5 +1,10 @@
# Description: Utilities for TPU Operations
load(
"//tensorflow:tensorflow.bzl",
"if_windows",
)
package(
default_visibility = [
"//tensorflow/core/tpu:__subpackages__",
@ -8,6 +13,13 @@ package(
licenses = ["notice"], # Apache 2.0
)
# Header-only target exposing the raw TPU C API declarations (libtftpu.h).
# Publicly visible so out-of-tree TPU library implementations can depend on it.
cc_library(
name = "libtftpu_header",
hdrs = ["libtftpu.h"],
visibility = ["//visibility:public"],
deps = [],
)
cc_library(
name = "tpu_embedding_optimization_parameters_utils",
srcs = ["tpu_embedding_optimization_parameters_utils.cc"],
@ -88,14 +100,23 @@ cc_library(
name = "tpu_config_c_api",
hdrs = ["tpu_config_c_api.h"],
deps = [
":libtftpu_header",
"//tensorflow/c:tf_status",
],
)
cc_library(
name = "tpu_library_loader",
srcs = ["tpu_library_loader.cc"],
srcs = if_windows(
["tpu_library_loader_windows.cc"],
otherwise = ["tpu_library_loader.cc"],
),
hdrs = ["tpu_library_loader.h"],
visibility = ["//tensorflow:__subpackages__"],
deps = ["//tensorflow/core/platform:status"],
visibility = ["//visibility:public"],
deps = [
":libtftpu_header",
":tpu_config_c_api",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
],
)

View File

@ -0,0 +1,52 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_LIBTFTPU_H_
#define TENSORFLOW_CORE_TPU_LIBTFTPU_H_
// Declares a struct member `FnNameFn` whose type is a pointer to the function
// `FnName`, used to build function-pointer tables mirroring the C API.
// Unfortunately we have to add an Fn suffix because we cannot have the same
// name for both a function and an element within a struct in the global
// namespace in gcc. This restriction doesn't exist in clang.
#define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn;
// Export annotation for the TPU C API:
// - under SWIG, expands to nothing (SWIG cannot parse the attribute syntax);
// - on Windows, dllexport when building the library (TF_COMPILE_LIBRARY set)
//   and dllimport when consuming it;
// - elsewhere, default ELF visibility so the symbol is visible to dlsym.
#ifdef SWIG
#define TFTPU_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TFTPU_CAPI_EXPORT __declspec(dllexport)
#else
#define TFTPU_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TFTPU_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
// C entry point whose presence marks a shared object as a TPU library; the
// framework's LoadLibrary probes for this symbol before initializing.
#ifdef __cplusplus
extern "C" {
#endif
TFTPU_CAPI_EXPORT void TfTpu_Initialize();
#ifdef __cplusplus
}
#endif
// Function-pointer table mirroring the base C API above; populated at runtime
// by the TPU library loader.
struct TfTpu_BaseFn {
TFTPU_ADD_FN_IN_STRUCT(TfTpu_Initialize);
};
#endif // TENSORFLOW_CORE_TPU_LIBTFTPU_H_

View File

@ -20,40 +20,53 @@ limitations under the License.
#include <cstdint>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/tpu/libtftpu.h"
typedef struct TpuSerializedProto TpuSerializedProto;
extern "C" {
bool TPUHostInitialized();
TFTPU_CAPI_EXPORT bool TPUHostInitialized();
void ConfigureDistributedTpuOp_DoWork(const size_t num_cores_per_host_size,
const int32_t* num_cores_per_host,
size_t* host_config_output_size,
char** host_config_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
size_t* host_config_output_size, char** host_config_output,
TF_Status* status);
void WaitForDistributedTpuOp_DoWork(
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
const size_t num_hosts, const size_t num_cores_per_host,
const int32_t** host_ordinal_to_global_core_id_map,
size_t* tpu_topology_output_size, char** tpu_topology_output,
TF_Status* status);
void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
void InitializeHostForDistributedTpuOp_DoWork(
TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
const size_t tpu_host_config_size, const char* tpu_host_config,
const bool enable_whole_mesh_compilations, size_t* core_id_output_size,
int32_t** core_id_output, TF_Status* status);
void SetGlobalTPUArrayOp_DoWork(const size_t tpu_topology_size,
const char* tpu_topology, TF_Status* status);
TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
const size_t tpu_topology_size, const char* tpu_topology,
TF_Status* status);
void DisconnectDistributedTpuChipsOp_DoWork(int32_t* number_of_chips_output,
TF_Status* status);
TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
int32_t* number_of_chips_output, TF_Status* status);
void TpuConfigurationApi_FreeCharArray(char* output);
void TpuConfigurationApi_FreeInt32Array(int32_t* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
}
// Function-pointer table for the TPU configuration C API; each member is
// resolved via dlsym by SetTpuConfigStructFns in tpu_library_loader.cc.
struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(TPUHostInitialized);
TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
};
#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

View File

@ -15,14 +15,63 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_library_loader.h"
#include <dlfcn.h>
// Looks up symbol `FnName` via dlsym on `library_handle` (captured from the
// enclosing scope) and stores it into Struct->FnNameFn.
// NOTE(review): a missing symbol silently leaves the pointer null — no error
// is surfaced here; confirm callers can tolerate null entries.
#define TFTPU_SET_FN(Struct, FnName) \
Struct->FnName##Fn = \
reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName));
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
// Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
// visible methods.
namespace tensorflow {
namespace tpu {
Status InitializeTPULibrary(void* library) {
// TODO(frankchn): dlsym the loaded library and populate a struct with the
// relevant C APIs necessary for TPUs.
// Populates the TfTpu_BaseFn table returned by InitializeApiFn() with symbols
// resolved from `library_handle`. Always returns OK (TFTPU_SET_FN does not
// report dlsym failures).
Status SetTpuInitializeStructFns(void* library_handle) {
auto* base_fn = InitializeApiFn();
TFTPU_SET_FN(base_fn, TfTpu_Initialize);
return Status::OK();
}
// Populates the TfTpu_ConfigApiFn table returned by ConfigApiFn() with the
// TPU configuration C API symbols resolved from `library_handle`. Always
// returns OK (TFTPU_SET_FN does not report dlsym failures).
Status SetTpuConfigStructFns(void* library_handle) {
auto* config_fn = ConfigApiFn();
TFTPU_SET_FN(config_fn, TPUHostInitialized);
TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
return Status::OK();
}
// Accessor for the process-wide TfTpu_BaseFn table. The function-local static
// is lazily constructed exactly once, so all callers share one table.
TfTpu_BaseFn* InitializeApiFn() {
  static TfTpu_BaseFn base_api_table;
  return &base_api_table;
}
// Accessor for the process-wide TfTpu_ConfigApiFn table; same lazily
// constructed function-local-static pattern as InitializeApiFn().
TfTpu_ConfigApiFn* ConfigApiFn() {
  static TfTpu_ConfigApiFn config_api_table;
  return &config_api_table;
}
// Initializes the base and configuration function-pointer tables from
// `library_handle`. A null handle means "use symbols already present in the
// current process image" (dlopen(nullptr)).
//
// Returns OK on success; NotFound if the process image cannot be opened.
// Note: individual missing symbols are not detected here (TFTPU_SET_FN leaves
// null pointers silently).
Status InitializeTpuLibrary(void* library_handle) {
  if (library_handle == nullptr) {
    library_handle = dlopen(nullptr, RTLD_LAZY);
    if (library_handle == nullptr) {
      // dlopen(nullptr) should only fail in pathological cases, but a null
      // handle would make every subsequent dlsym lookup ill-defined.
      return errors::NotFound("Unable to dlopen the current process image: ",
                              dlerror());
    }
  }
  TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
  TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
  return Status::OK();
}

View File

@ -17,13 +17,21 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
// LINT.IfChange
namespace tensorflow {
namespace tpu {
Status InitializeTPULibrary(void* library);
Status InitializeTpuLibrary(void* library_handle);
TfTpu_BaseFn* InitializeApiFn();
TfTpu_ConfigApiFn* ConfigApiFn();
} // namespace tpu
} // namespace tensorflow
// LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc)
#endif // TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/tpu_library_loader.h"
// Reminder: Update tpu_library_loader.cc if you are adding new publicly
// visible methods.
namespace tensorflow {
namespace tpu {
// Windows stubs: dynamic loading of the TPU library is not supported here, so
// the table accessors return nullptr rather than a populated struct.
TfTpu_BaseFn* InitializeApiFn() { return nullptr; }
TfTpu_ConfigApiFn* ConfigApiFn() { return nullptr; }
// Always fails on Windows; mirrors the POSIX implementation's signature in
// tpu_library_loader.cc.
Status InitializeTpuLibrary(void* library_handle) {
return errors::Unimplemented(
"Loading TPU library is not supported on Windows.");
}
} // namespace tpu
} // namespace tensorflow