Fix some API linking issues within TPU 1VM (some API calls could only be statically linked)
PiperOrigin-RevId: 322942590 Change-Id: I763f8649fc38a555d0d2a6408a634314ab041b7a
This commit is contained in:
parent
29d635bccc
commit
56853a42e4
@ -132,6 +132,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuStatus_New);
|
||||
TFTPU_SET_FN(executor_fn, TpuStatus_Create);
|
||||
TFTPU_SET_FN(executor_fn, TpuStatus_Set);
|
||||
TFTPU_SET_FN(executor_fn, TpuStatus_Free);
|
||||
TFTPU_SET_FN(executor_fn, TpuStatus_Message);
|
||||
TFTPU_SET_FN(executor_fn, TpuStatus_Code);
|
||||
@ -174,6 +175,9 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuCoreLocation_Index);
|
||||
TFTPU_SET_FN(executor_fn, TpuCoreLocation_Id);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_New);
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_Free);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_RunHloPasses);
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_RunBackend);
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_Compile);
|
||||
|
@ -167,8 +167,8 @@ XLA_HloModuleConfig HloModuleConfigToC(const xla::HloModuleConfig& config) {
|
||||
|
||||
class TpuCompiler : public Compiler {
|
||||
public:
|
||||
TpuCompiler() { compiler_ = TpuCompiler_New(); }
|
||||
~TpuCompiler() override {}
|
||||
TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
|
||||
~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
|
||||
|
||||
stream_executor::Platform::Id PlatformId() const override {
|
||||
return tensorflow::TpuPlatform::kId;
|
||||
|
@ -42,10 +42,10 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/core/tpu:tpu_api",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
|
||||
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/c_api_defn.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
|
||||
@ -91,8 +92,9 @@ SE_DeviceMemoryAllocator ToC(
|
||||
->Allocate(device_ordinal, size, retry_on_failure, memory_space);
|
||||
if (!allocation.ok()) {
|
||||
auto status = allocation.status();
|
||||
TpuStatus_Set(se_status, status.code(), status.error_message().data(),
|
||||
status.error_message().size());
|
||||
tensorflow::tpu::ExecutorApiFn()->TpuStatus_SetFn(
|
||||
se_status, status.code(), status.error_message().data(),
|
||||
status.error_message().size());
|
||||
} else {
|
||||
auto& scoped_memory = allocation.ValueOrDie();
|
||||
memory->wrapped = ApiConverter::ToC(scoped_memory.Release());
|
||||
@ -105,8 +107,9 @@ SE_DeviceMemoryAllocator ToC(
|
||||
auto status = reinterpret_cast<stream_executor::DeviceMemoryAllocator*>(ctx)
|
||||
->Deallocate(device_ordinal, ApiConverter::FromC(*base));
|
||||
if (!status.ok()) {
|
||||
TpuStatus_Set(se_status, status.code(), status.error_message().data(),
|
||||
status.error_message().size());
|
||||
tensorflow::tpu::ExecutorApiFn()->TpuStatus_SetFn(
|
||||
se_status, status.code(), status.error_message().data(),
|
||||
status.error_message().size());
|
||||
}
|
||||
};
|
||||
return se_allocator;
|
||||
|
@ -304,6 +304,7 @@ struct TfTpu_ExecutorApiFn {
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_New);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Create);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Set);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Free);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Message);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code);
|
||||
@ -346,6 +347,9 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Index);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Id);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunHloPasses);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_RunBackend);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);
|
||||
|
@ -62,7 +62,8 @@ Status TpuNodeContext::CloseTpuHost() {
|
||||
/* static */
|
||||
Status TpuNodeContext::Initialize(int device_ordinal) {
|
||||
StatusHelper status;
|
||||
TpuNodeContext_Initialize(device_ordinal, status.c_status);
|
||||
tpu::NodeContextApiFn()->TpuNodeContext_InitializeFn(device_ordinal,
|
||||
status.c_status);
|
||||
return status.status();
|
||||
}
|
||||
|
||||
|
@ -84,10 +84,11 @@ struct TransferFromDeviceState {
|
||||
std::function<void(Status)> done;
|
||||
|
||||
void TransferFinished(SE_Status* status) {
|
||||
if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
|
||||
if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
|
||||
tpu::ExecutorApiFn()->TpuStatus_OkFn(status_helper.c_status)) {
|
||||
status_helper.c_status = status;
|
||||
} else {
|
||||
TpuStatus_Free(status);
|
||||
tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);
|
||||
}
|
||||
|
||||
if (--remaining_transfers == 0) {
|
||||
|
Loading…
Reference in New Issue
Block a user