Fix some API linking issues within TPU 1VM (some API calls could only be statically linked)

PiperOrigin-RevId: 322942590
Change-Id: I763f8649fc38a555d0d2a6408a634314ab041b7a
This commit is contained in:
Frank Chen 2020-07-23 23:16:23 -07:00 committed by TensorFlower Gardener
parent 29d635bccc
commit 56853a42e4
7 changed files with 23 additions and 10 deletions

View File

@ -132,6 +132,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuStatus_New);
TFTPU_SET_FN(executor_fn, TpuStatus_Create);
TFTPU_SET_FN(executor_fn, TpuStatus_Set);
TFTPU_SET_FN(executor_fn, TpuStatus_Free);
TFTPU_SET_FN(executor_fn, TpuStatus_Message);
TFTPU_SET_FN(executor_fn, TpuStatus_Code);
@ -174,6 +175,9 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuCoreLocation_Index);
TFTPU_SET_FN(executor_fn, TpuCoreLocation_Id);
TFTPU_SET_FN(executor_fn, TpuCompiler_New);
TFTPU_SET_FN(executor_fn, TpuCompiler_Free);
TFTPU_SET_FN(executor_fn, TpuCompiler_RunHloPasses);
TFTPU_SET_FN(executor_fn, TpuCompiler_RunBackend);
TFTPU_SET_FN(executor_fn, TpuCompiler_Compile);

View File

@ -167,8 +167,8 @@ XLA_HloModuleConfig HloModuleConfigToC(const xla::HloModuleConfig& config) {
class TpuCompiler : public Compiler {
public:
TpuCompiler() { compiler_ = TpuCompiler_New(); }
~TpuCompiler() override {}
TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
stream_executor::Platform::Id PlatformId() const override {
return tensorflow::TpuPlatform::kId;

View File

@ -42,10 +42,10 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/container:inlined_vector",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/c_api_defn.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
@ -91,8 +92,9 @@ SE_DeviceMemoryAllocator ToC(
->Allocate(device_ordinal, size, retry_on_failure, memory_space);
if (!allocation.ok()) {
auto status = allocation.status();
TpuStatus_Set(se_status, status.code(), status.error_message().data(),
status.error_message().size());
tensorflow::tpu::ExecutorApiFn()->TpuStatus_SetFn(
se_status, status.code(), status.error_message().data(),
status.error_message().size());
} else {
auto& scoped_memory = allocation.ValueOrDie();
memory->wrapped = ApiConverter::ToC(scoped_memory.Release());
@ -105,8 +107,9 @@ SE_DeviceMemoryAllocator ToC(
auto status = reinterpret_cast<stream_executor::DeviceMemoryAllocator*>(ctx)
->Deallocate(device_ordinal, ApiConverter::FromC(*base));
if (!status.ok()) {
TpuStatus_Set(se_status, status.code(), status.error_message().data(),
status.error_message().size());
tensorflow::tpu::ExecutorApiFn()->TpuStatus_SetFn(
se_status, status.code(), status.error_message().data(),
status.error_message().size());
}
};
return se_allocator;

View File

@ -304,6 +304,7 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_New);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Set);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Message);
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code);
@ -346,6 +347,9 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Index);
TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Id);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunHloPasses);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunBackend);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);

View File

@ -62,7 +62,8 @@ Status TpuNodeContext::CloseTpuHost() {
/* static */
Status TpuNodeContext::Initialize(int device_ordinal) {
StatusHelper status;
TpuNodeContext_Initialize(device_ordinal, status.c_status);
tpu::NodeContextApiFn()->TpuNodeContext_InitializeFn(device_ordinal,
status.c_status);
return status.status();
}

View File

@ -84,10 +84,11 @@ struct TransferFromDeviceState {
std::function<void(Status)> done;
void TransferFinished(SE_Status* status) {
if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
tpu::ExecutorApiFn()->TpuStatus_OkFn(status_helper.c_status)) {
status_helper.c_status = status;
} else {
TpuStatus_Free(status);
tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);
}
if (--remaining_transfers == 0) {