diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index cd5beec6296..5022ad6228d 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -156,6 +156,7 @@ cc_library( deps = [ ":libtftpu_header", ":tpu_api", + ":tpu_api_dlsym_set_fn", ":tpu_compilation_device", ":tpu_config_c_api", ":tpu_executor_init_fns", @@ -174,10 +175,17 @@ cc_library( ], ) +cc_library( + name = "tpu_api_dlsym_set_fn", + hdrs = ["tpu_api_dlsym_set_fn.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "tpu_library_init_fns", hdrs = ["tpu_library_init_fns.inc"], visibility = ["//visibility:public"], + deps = [":tpu_executor_init_fns"], ) cc_library( diff --git a/tensorflow/core/tpu/libtftpu.h b/tensorflow/core/tpu/libtftpu.h index a4405df8205..9171af87061 100644 --- a/tensorflow/core/tpu/libtftpu.h +++ b/tensorflow/core/tpu/libtftpu.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #ifndef TENSORFLOW_CORE_TPU_LIBTFTPU_H_ #define TENSORFLOW_CORE_TPU_LIBTFTPU_H_ @@ -39,7 +41,7 @@ limitations under the License. extern "C" { #endif -TFTPU_CAPI_EXPORT void TfTpu_Initialize(); +TFTPU_CAPI_EXPORT void TfTpu_Initialize(bool init_library); #ifdef __cplusplus } diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc index 4dc09770c38..e4d723305a9 100644 --- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc +++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h" #if !defined(PLATFORM_GOOGLE) #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/core/tpu/tpu_node_device.h" @@ -27,13 +28,6 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_platform.h" #endif -#define TFTPU_SET_FN(Struct, FnName) \ - Struct->FnName##Fn = \ - reinterpret_cast(dlsym(library_handle, #FnName)); \ - if (!(Struct->FnName##Fn)) { \ - LOG(FATAL) << #FnName " not available in this library."; \ - return errors::Unimplemented(#FnName " not available in this library."); \ - } // Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly // visible methods. @@ -55,10 +49,10 @@ Status InitializeTpuLibrary(void* library_handle) { // loaded. We do not want to register a TPU platform in XLA without the // supporting library providing the necessary APIs. if (s.ok()) { - void (*initialize_fn)(); + void (*initialize_fn)(bool init_library); initialize_fn = reinterpret_cast( dlsym(library_handle, "TfTpu_Initialize")); - (*initialize_fn)(); + (*initialize_fn)(/*init_library=*/true); RegisterTpuPlatform(); RegisterTpuSystemDevice(); diff --git a/tensorflow/core/tpu/tpu_api_dlsym_set_fn.h b/tensorflow/core/tpu/tpu_api_dlsym_set_fn.h new file mode 100644 index 00000000000..a1e13550d96 --- /dev/null +++ b/tensorflow/core/tpu/tpu_api_dlsym_set_fn.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_API_DLSYM_SET_FN_H_ +#define TENSORFLOW_CORE_TPU_TPU_API_DLSYM_SET_FN_H_ + +#define TFTPU_SET_FN(Struct, FnName) \ + Struct->FnName##Fn = \ + reinterpret_cast(dlsym(library_handle, #FnName)); \ + if (!(Struct->FnName##Fn)) { \ + LOG(FATAL) << #FnName " not available in this library."; \ + return errors::Unimplemented(#FnName " not available in this library."); \ + } + +#endif // TENSORFLOW_CORE_TPU_TPU_API_DLSYM_SET_FN_H_