Add a rudimentary library loader implementation for some TPU-related functions
PiperOrigin-RevId: 314863001 Change-Id: Iafe056ab3fcf592cd28873e6fd740121f17d1a91
This commit is contained in:
parent
9244dd50fc
commit
75c40f6bff
@ -107,7 +107,7 @@ Status LoadLibrary(const char* library_filename, void** result,
|
||||
if (env->GetSymbolFromLibrary(library.handle, "TfTpu_Initialize",
|
||||
&unused_symbol)
|
||||
.ok()) {
|
||||
TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTPULibrary(library.handle));
|
||||
TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTpuLibrary(library.handle));
|
||||
}
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
|
@ -1,5 +1,10 @@
|
||||
# Description: Utilities for TPU Operations
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_windows",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
@ -8,6 +13,13 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Header-only target: exposes libtftpu.h on its own, with no dependencies and
# public visibility.
cc_library(
|
||||
name = "libtftpu_header",
|
||||
hdrs = ["libtftpu.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_embedding_optimization_parameters_utils",
|
||||
srcs = ["tpu_embedding_optimization_parameters_utils.cc"],
|
||||
@ -88,14 +100,23 @@ cc_library(
|
||||
name = "tpu_config_c_api",
|
||||
hdrs = ["tpu_config_c_api.h"],
|
||||
deps = [
|
||||
":libtftpu_header",
|
||||
"//tensorflow/c:tf_status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_library_loader",
|
||||
srcs = ["tpu_library_loader.cc"],
|
||||
srcs = if_windows(
|
||||
["tpu_library_loader_windows.cc"],
|
||||
otherwise = ["tpu_library_loader.cc"],
|
||||
),
|
||||
hdrs = ["tpu_library_loader.h"],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = ["//tensorflow/core/platform:status"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":libtftpu_header",
|
||||
":tpu_config_c_api",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
52
tensorflow/core/tpu/libtftpu.h
Normal file
52
tensorflow/core/tpu/libtftpu.h
Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_TPU_LIBTFTPU_H_
|
||||
#define TENSORFLOW_CORE_TPU_LIBTFTPU_H_
|
||||
|
||||
// Declares a struct member named `<FnName>Fn` that holds a pointer to a
// function with the same signature as `FnName`. `FnName` must already be
// declared where the macro is expanded, because its type is taken with
// decltype.
//
// Unfortunately we have to add an Fn suffix because we cannot have the same
// name for both a function and an element within a struct in the global
// namespace in gcc. This restriction doesn't exist in clang.
//
// The expansion deliberately has no trailing semicolon: every use site
// supplies its own (`TFTPU_ADD_FN_IN_STRUCT(Foo);`), so a semicolon here would
// leave a stray empty declaration inside the struct.
#define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn
|
||||
|
||||
// TFTPU_CAPI_EXPORT is the decoration placed on exported C API declarations:
//   - under SWIG it expands to nothing;
//   - on Windows it is dllexport when TF_COMPILE_LIBRARY is defined and
//     dllimport otherwise;
//   - everywhere else it requests default symbol visibility.
#ifdef SWIG
|
||||
#define TFTPU_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TFTPU_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TFTPU_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TFTPU_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Declared with C linkage (when compiled as C++) and the export decoration.
// NOTE(review): what this call actually initializes is not visible from this
// header — confirm against the library that defines it.
TFTPU_CAPI_EXPORT void TfTpu_Initialize();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
// Function-pointer table for the base API declared above. Each entry is
// produced by TFTPU_ADD_FN_IN_STRUCT, so the member for `TfTpu_Initialize` is
// named `TfTpu_InitializeFn`.
struct TfTpu_BaseFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TfTpu_Initialize);
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_LIBTFTPU_H_
|
@ -20,40 +20,53 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/tpu/libtftpu.h"
|
||||
|
||||
typedef struct TpuSerializedProto TpuSerializedProto;
|
||||
|
||||
extern "C" {
|
||||
|
||||
bool TPUHostInitialized();
|
||||
TFTPU_CAPI_EXPORT bool TPUHostInitialized();
|
||||
|
||||
void ConfigureDistributedTpuOp_DoWork(const size_t num_cores_per_host_size,
|
||||
const int32_t* num_cores_per_host,
|
||||
size_t* host_config_output_size,
|
||||
char** host_config_output,
|
||||
TF_Status* status);
|
||||
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
|
||||
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
|
||||
size_t* host_config_output_size, char** host_config_output,
|
||||
TF_Status* status);
|
||||
|
||||
void WaitForDistributedTpuOp_DoWork(
|
||||
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
|
||||
const size_t num_hosts, const size_t num_cores_per_host,
|
||||
const int32_t** host_ordinal_to_global_core_id_map,
|
||||
size_t* tpu_topology_output_size, char** tpu_topology_output,
|
||||
TF_Status* status);
|
||||
|
||||
void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
|
||||
TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
|
||||
|
||||
void InitializeHostForDistributedTpuOp_DoWork(
|
||||
TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
|
||||
const size_t tpu_host_config_size, const char* tpu_host_config,
|
||||
const bool enable_whole_mesh_compilations, size_t* core_id_output_size,
|
||||
int32_t** core_id_output, TF_Status* status);
|
||||
|
||||
void SetGlobalTPUArrayOp_DoWork(const size_t tpu_topology_size,
|
||||
const char* tpu_topology, TF_Status* status);
|
||||
TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
|
||||
const size_t tpu_topology_size, const char* tpu_topology,
|
||||
TF_Status* status);
|
||||
|
||||
void DisconnectDistributedTpuChipsOp_DoWork(int32_t* number_of_chips_output,
|
||||
TF_Status* status);
|
||||
TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
|
||||
int32_t* number_of_chips_output, TF_Status* status);
|
||||
|
||||
void TpuConfigurationApi_FreeCharArray(char* output);
|
||||
void TpuConfigurationApi_FreeInt32Array(int32_t* output);
|
||||
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
|
||||
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
|
||||
}
|
||||
|
||||
// Function-pointer table for the TPU configuration C API declared in this
// header. Each entry is produced by TFTPU_ADD_FN_IN_STRUCT, so every member
// is named after its function with an `Fn` suffix (for example
// `TPUHostInitializedFn`).
struct TfTpu_ConfigApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TPUHostInitialized);
|
||||
TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
|
||||
TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
|
||||
TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork);
|
||||
TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
|
||||
TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
|
||||
TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
|
||||
|
@ -15,14 +15,63 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
// Looks up the symbol `FnName` with dlsym and stores it in
// `Struct->FnName##Fn`.
//
// Requirements at the expansion site:
//   - a `void* library_handle` variable is in scope;
//   - the enclosing function returns tensorflow::Status, because a failed
//     lookup makes the macro return an Unimplemented error from it.
//
// Without the null check a missing symbol would leave a null function pointer
// in the table while the loader still reported success, deferring the failure
// to a crash at the first call through that pointer.
//
// The body is wrapped in do/while(0) so the macro behaves as one statement and
// the semicolon written at the use site terminates it.
#define TFTPU_SET_FN(Struct, FnName)                            \
  do {                                                          \
    (Struct)->FnName##Fn = reinterpret_cast<decltype(FnName)*>( \
        dlsym(library_handle, #FnName));                        \
    if ((Struct)->FnName##Fn == nullptr) {                      \
      return ::tensorflow::errors::Unimplemented(               \
          #FnName " not available in this library.");           \
    }                                                           \
  } while (0)
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
// Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly
|
||||
// visible methods.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
Status InitializeTPULibrary(void* library) {
|
||||
// TODO(frankchn): dlsym the loaded library and populate a struct with the
|
||||
// relevant C APIs necessary for TPUs.
|
||||
// Resolves the base API entry point (`TfTpu_Initialize`) in `library_handle`
// through TFTPU_SET_FN and records it in the table returned by
// InitializeApiFn(). The parameter must keep the name `library_handle`
// because the macro refers to it.
Status SetTpuInitializeStructFns(void* library_handle) {
  auto* const base_api = InitializeApiFn();
  TFTPU_SET_FN(base_api, TfTpu_Initialize);
  return Status::OK();
}
|
||||
|
||||
// Resolves every TPU configuration API function in `library_handle` through
// TFTPU_SET_FN and records each one in the table returned by ConfigApiFn().
// The parameter must keep the name `library_handle` because the macro refers
// to it.
Status SetTpuConfigStructFns(void* library_handle) {
  auto* const config_api = ConfigApiFn();

  // Host / distributed-system lifecycle.
  TFTPU_SET_FN(config_api, TPUHostInitialized);
  TFTPU_SET_FN(config_api, ConfigureDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_api, WaitForDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_api, ShutdownDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_api, InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_SET_FN(config_api, SetGlobalTPUArrayOp_DoWork);
  TFTPU_SET_FN(config_api, DisconnectDistributedTpuChipsOp_DoWork);

  // Deallocators for buffers handed back by the calls above.
  TFTPU_SET_FN(config_api, TpuConfigurationApi_FreeCharArray);
  TFTPU_SET_FN(config_api, TpuConfigurationApi_FreeInt32Array);

  return Status::OK();
}
|
||||
|
||||
// Returns the base API function table. The table is a function-local static,
// so every call yields a pointer to the same object.
TfTpu_BaseFn* InitializeApiFn() {
  static TfTpu_BaseFn table;
  return &table;
}
|
||||
|
||||
// Returns the configuration API function table. The table is a function-local
// static, so every call yields a pointer to the same object.
TfTpu_ConfigApiFn* ConfigApiFn() {
  static TfTpu_ConfigApiFn table;
  return &table;
}
|
||||
|
||||
// Populates both function tables (base API and configuration API) from
// `library_handle`.
//
// When `library_handle` is null, the handle obtained from
// dlopen(nullptr, RTLD_LAZY) is used for the lookups instead.
//
// Returns the first non-OK status produced while filling the tables, or OK
// when both succeed.
Status InitializeTpuLibrary(void* library_handle) {
  void* const lookup_handle = (library_handle != nullptr)
                                  ? library_handle
                                  : dlopen(nullptr, RTLD_LAZY);

  TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(lookup_handle));
  return SetTpuConfigStructFns(lookup_handle);
}
|
||||
|
||||
|
@ -17,13 +17,21 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
|
||||
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/tpu/libtftpu.h"
|
||||
#include "tensorflow/core/tpu/tpu_config_c_api.h"
|
||||
|
||||
// LINT.IfChange
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
Status InitializeTPULibrary(void* library);
|
||||
// Fills the function tables returned by InitializeApiFn() and ConfigApiFn()
// with symbols looked up in `library_handle`. In the dlfcn-based
// implementation a null handle is replaced by dlopen(nullptr, RTLD_LAZY); the
// Windows implementation always returns an Unimplemented error.
Status InitializeTpuLibrary(void* library_handle);
|
||||
|
||||
// Returns the base API function table. The dlfcn-based implementation returns
// a pointer to a static table; the Windows implementation returns nullptr.
TfTpu_BaseFn* InitializeApiFn();
|
||||
|
||||
// Returns the configuration API function table. The dlfcn-based
// implementation returns a pointer to a static table; the Windows
// implementation returns nullptr.
TfTpu_ConfigApiFn* ConfigApiFn();
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
// LINT.ThenChange(//tensorflow/core/tpu/tpu_library_loader_windows.cc)
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_
|
||||
|
36
tensorflow/core/tpu/tpu_library_loader_windows.cc
Normal file
36
tensorflow/core/tpu/tpu_library_loader_windows.cc
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/tpu/tpu_library_loader.h"
|
||||
|
||||
// Reminder: Update tpu_library_loader.cc if you are adding new publicly
|
||||
// visible methods.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Windows stub: no base API function table exists in this build, so the
// result is always a null pointer.
TfTpu_BaseFn* InitializeApiFn() {
  return nullptr;
}
|
||||
|
||||
// Windows stub: no configuration API function table exists in this build, so
// the result is always a null pointer.
TfTpu_ConfigApiFn* ConfigApiFn() {
  return nullptr;
}
|
||||
|
||||
// Windows stub: `library_handle` is ignored and the call always fails with an
// Unimplemented status.
Status InitializeTpuLibrary(void* library_handle) {
  Status unsupported =
      errors::Unimplemented("Loading TPU library is not supported on Windows.");
  return unsupported;
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user