Change all SE_Status to the common TF_Status for better compatibility
PiperOrigin-RevId: 339599554 Change-Id: I739ca5fee92734783a1c50dc0f19bd063253cfe4
This commit is contained in:
parent
06abb2cbb6
commit
f162769b41
@ -89,7 +89,7 @@ typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
|
||||
// API respectively.
|
||||
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
|
||||
TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
|
||||
XLA_TpuProgram** tpu_programs[], size_t* count, SE_Status* status);
|
||||
XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
|
||||
|
||||
// Creates a new TPU mesh state object.
|
||||
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
|
||||
@ -107,7 +107,7 @@ TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
|
||||
size_t arguments_len, SE_DeviceMemoryBase* result,
|
||||
SE_DeviceMemoryBase* cross_program_prefetch_addr, int32_t rng_seed,
|
||||
XLA_DeviceAssignment* device_assignment, SE_Stream* stream,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
|
||||
XLA_Shape* host_shape, XLA_Shape* device_shape);
|
||||
@ -118,7 +118,7 @@ TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
|
||||
TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
|
||||
uint32_t* runtime_input_ptr, size_t runtime_input_size,
|
||||
int8_t* padded_data_ptr, size_t padded_data_size, XLA_Shape* runtime_shape,
|
||||
XLA_Shape* compile_time_shape, SE_Status* status);
|
||||
XLA_Shape* compile_time_shape, TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
|
||||
const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
|
||||
@ -180,7 +180,7 @@ TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);
|
||||
// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
|
||||
// destroyed, it is in an unusable state.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
// Gets TPU program size in bytes from the `tpu_program`.
|
||||
TFTPU_CAPI_EXPORT int64_t
|
||||
@ -193,17 +193,17 @@ TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
|
||||
// Gets TPU program executable info from the `tpu_program`.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
|
||||
const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
// Gets host transfer info proto.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
|
||||
const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
// Gets HLO metadata proto.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
|
||||
const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
// Gets the "may modify variables" boolean value.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
|
||||
@ -221,17 +221,17 @@ TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
|
||||
// Gets TPU executable proto from a `tpu_program`.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
|
||||
const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
// Gets compilation metadata proto from a `tpu_program`.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
|
||||
const XLA_TpuProgram* tpu_program,
|
||||
CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);
|
||||
CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status);
|
||||
|
||||
// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
|
||||
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
|
||||
TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
// Checks whether TPU compilation is enabled.
|
||||
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();
|
||||
@ -267,14 +267,14 @@ TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
|
||||
uint64_t fingerprint, const char* data, size_t size);
|
||||
|
||||
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
|
||||
|
||||
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
|
||||
void TpuNodeContext_StopChipHeartbeats(TF_Status* status);
|
||||
|
||||
void TpuNodeContext_CloseTpuHost(SE_Status* status);
|
||||
void TpuNodeContext_CloseTpuHost(TF_Status* status);
|
||||
|
||||
void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
|
||||
void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status);
|
||||
|
||||
struct TfTpu_OpsApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
|
||||
|
@ -92,7 +92,7 @@ SE_DeviceMemoryAllocator ToC(
|
||||
se_allocator.allocate = [](void* ctx, int device_ordinal, uint64_t size,
|
||||
bool retry_on_failure, int64_t memory_space,
|
||||
SE_ScopedDeviceMemory* memory,
|
||||
SE_Status* se_status) {
|
||||
TF_Status* se_status) {
|
||||
auto allocation =
|
||||
reinterpret_cast<stream_executor::DeviceMemoryAllocator*>(ctx)
|
||||
->Allocate(device_ordinal, size, retry_on_failure, memory_space);
|
||||
@ -109,7 +109,7 @@ SE_DeviceMemoryAllocator ToC(
|
||||
};
|
||||
|
||||
se_allocator.deallocate = [](void* ctx, SE_DeviceMemoryBase* base,
|
||||
int device_ordinal, SE_Status* se_status) {
|
||||
int device_ordinal, TF_Status* se_status) {
|
||||
auto status = reinterpret_cast<stream_executor::DeviceMemoryAllocator*>(ctx)
|
||||
->Deallocate(device_ordinal, ApiConverter::FromC(*base));
|
||||
if (!status.ok()) {
|
||||
|
@ -41,8 +41,6 @@ enum TpuVersionEnum {
|
||||
kTpuV4,
|
||||
};
|
||||
|
||||
typedef struct SE_Status SE_Status;
|
||||
|
||||
typedef struct SE_Platform SE_Platform;
|
||||
typedef struct SE_StreamExecutor SE_StreamExecutor;
|
||||
typedef struct SE_Stream SE_Stream;
|
||||
@ -59,7 +57,7 @@ typedef struct SE_PlatformId {
|
||||
} SE_PlatformId;
|
||||
typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
|
||||
typedef struct SE_DeviceOptions SE_DeviceOptions;
|
||||
typedef SE_Status* (*SE_StatusCallbackFn)(void*);
|
||||
typedef TF_Status* (*SE_StatusCallbackFn)(void*);
|
||||
|
||||
typedef struct SE_DeviceMemoryBase {
|
||||
void* opaque;
|
||||
@ -95,10 +93,10 @@ typedef struct SE_AllocatorStats {
|
||||
// direction and request memory via a callback.
|
||||
typedef void (*SE_AllocateFn)(void* ctx, int device_ordinal, uint64_t size,
|
||||
bool retry_on_failure, int64_t memory_space,
|
||||
SE_ScopedDeviceMemory* result, SE_Status* status);
|
||||
SE_ScopedDeviceMemory* result, TF_Status* status);
|
||||
|
||||
typedef void (*SE_DeallocateFn)(void* ctx, SE_DeviceMemoryBase* base,
|
||||
int device_ordinal, SE_Status* status);
|
||||
int device_ordinal, TF_Status* status);
|
||||
|
||||
typedef struct SE_DeviceMemoryAllocator {
|
||||
SE_Platform* platform;
|
||||
@ -299,7 +297,7 @@ typedef struct XLA_TransferManager XLA_TransferManager;
|
||||
typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;
|
||||
|
||||
typedef void (*XLA_CallbackFn)(void*);
|
||||
typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*);
|
||||
typedef void (*XLA_StatusCallbackFn)(void*, TF_Status*);
|
||||
|
||||
typedef struct SE_TpuTopology SE_TpuTopology;
|
||||
typedef struct SE_TpuTopology_Core SE_TpuTopology_Core;
|
||||
|
@ -29,7 +29,7 @@ class StatusHelper {
|
||||
tensorflow::tpu::ExecutorApiFn()->TpuStatus_FreeFn(c_status);
|
||||
}
|
||||
|
||||
static tensorflow::Status FromC(SE_Status* const c_status) {
|
||||
static tensorflow::Status FromC(TF_Status* const c_status) {
|
||||
if (tensorflow::tpu::ExecutorApiFn()->TpuStatus_OkFn(c_status)) {
|
||||
return tensorflow::Status::OK();
|
||||
} else {
|
||||
@ -46,7 +46,7 @@ class StatusHelper {
|
||||
|
||||
tensorflow::Status status() const { return FromC(c_status); }
|
||||
|
||||
SE_Status* const c_status; // NOLINT
|
||||
TF_Status* const c_status; // NOLINT
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
|
||||
|
@ -331,10 +331,10 @@ struct HostCallbackContext {
|
||||
std::function<Status()> callback;
|
||||
};
|
||||
|
||||
SE_Status* HostCallbackTrampoline(void* ctx) {
|
||||
TF_Status* HostCallbackTrampoline(void* ctx) {
|
||||
HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
|
||||
Status status = host_ctx->callback();
|
||||
SE_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
|
||||
TF_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
|
||||
status.code(), status.error_message().c_str());
|
||||
delete host_ctx;
|
||||
return c_status;
|
||||
|
@ -30,11 +30,11 @@ SE_Platform* TpuPlatform_New();
|
||||
void TpuPlatform_Free(SE_Platform* platform);
|
||||
void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size,
|
||||
const char** options_key,
|
||||
const char** options_value, SE_Status* status);
|
||||
const char** options_value, TF_Status* status);
|
||||
bool TpuPlatform_Initialized(SE_Platform* platform);
|
||||
SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
|
||||
SE_StreamExecutorConfig* config,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
|
||||
int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
|
||||
int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
|
||||
@ -43,7 +43,7 @@ SE_TpuTopology* TpuPlatform_GetTopologyPtr(SE_Platform* platform);
|
||||
SE_TpuTopology_Host* TpuPlatform_GetHostLocation(SE_Platform* platform);
|
||||
|
||||
void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
|
||||
SE_DeviceOptions* device_options, SE_Status* status);
|
||||
SE_DeviceOptions* device_options, TF_Status* status);
|
||||
void TpuExecutor_Free(SE_StreamExecutor* executor);
|
||||
|
||||
int TpuExecutor_PlatformDeviceCount(SE_StreamExecutor* executor);
|
||||
@ -63,20 +63,20 @@ void TpuExecutor_DeallocateStream(SE_StreamExecutor* executor,
|
||||
bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor,
|
||||
SE_Stream* dependent, SE_Stream* other);
|
||||
void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
SE_TpuTopology_Core* TpuExecutor_GetCoreLocation(SE_StreamExecutor* executor);
|
||||
|
||||
void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
int TpuExecutor_PollForEventStatus(SE_StreamExecutor* executor,
|
||||
SE_Event* event);
|
||||
void TpuExecutor_RecordEvent(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
SE_Event* event, SE_Status* status);
|
||||
SE_Event* event, TF_Status* status);
|
||||
void TpuExecutor_WaitForEvent(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
SE_Event* event, SE_Status* status);
|
||||
SE_Event* event, TF_Status* status);
|
||||
|
||||
bool TpuExecutor_AllocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
|
||||
void TpuExecutor_DeallocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
|
||||
@ -88,11 +88,11 @@ bool TpuExecutor_StopTimer(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
void TpuExecutor_SynchronousMemcpyToHost(SE_StreamExecutor* executor,
|
||||
void* host_dst,
|
||||
const SE_DeviceMemoryBase* device_src,
|
||||
uint64_t size, SE_Status* status);
|
||||
uint64_t size, TF_Status* status);
|
||||
void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor,
|
||||
SE_DeviceMemoryBase* device_dst,
|
||||
const void* host_src, uint64_t size,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
void* host_dst,
|
||||
const SE_DeviceMemoryBase* device_src,
|
||||
@ -104,21 +104,21 @@ bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
|
||||
void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor,
|
||||
int32_t infeed_queue_index, const uint8_t* data,
|
||||
int64_t size, SE_Status* status);
|
||||
int64_t size, TF_Status* status);
|
||||
void TpuExecutor_DequeueOutfeed(SE_StreamExecutor* executor,
|
||||
int32_t outfeed_queue_index, uint8_t* data,
|
||||
int64_t size, SE_Status* status);
|
||||
int64_t size, TF_Status* status);
|
||||
void TpuExecutor_WaitForInfeedReady(SE_StreamExecutor* executor,
|
||||
int32_t infeed_queue_index,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuExecutor_WaitForOutfeedReady(SE_StreamExecutor* executor,
|
||||
int32_t outfeed_queue_index,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
void TpuExecutor_BlockHostUntilDone(SE_StreamExecutor* executor,
|
||||
SE_Stream* stream, SE_Status* status);
|
||||
SE_Stream* stream, TF_Status* status);
|
||||
void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
|
||||
bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);
|
||||
|
||||
@ -130,15 +130,15 @@ bool TpuStream_IsSameSharedMemoryLocation(SE_Stream*, SE_Stream*);
|
||||
void TpuStream_EnqueueTransferHostToDevice(SE_Stream* stream,
|
||||
SE_DeviceMemoryBase device_dst,
|
||||
void* host_src, uint64_t size,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuStream_EnqueueTransferDeviceToHost(SE_Stream* stream,
|
||||
SE_DeviceMemoryBase device_src,
|
||||
void* host_dst, uint64_t size,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuStream_TpuEnqueueOnDeviceSendRecvLocal(SE_Stream* stream,
|
||||
SE_DeviceMemoryBase send_buffer,
|
||||
SE_DeviceMemoryBase recv_buffer,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
|
||||
void TpuEvent_Free(SE_Event*);
|
||||
@ -148,14 +148,14 @@ void TpuTimer_Free(SE_Timer*);
|
||||
int64_t TpuTimer_Nanoseconds(SE_Timer*);
|
||||
int64_t TpuTimer_Microseconds(SE_Timer*);
|
||||
|
||||
SE_Status* TpuStatus_New();
|
||||
SE_Status* TpuStatus_Create(int32_t code, const char* msg);
|
||||
void TpuStatus_Set(SE_Status* status, int32_t code, const char* msg,
|
||||
TF_Status* TpuStatus_New();
|
||||
TF_Status* TpuStatus_Create(int32_t code, const char* msg);
|
||||
void TpuStatus_Set(TF_Status* status, int32_t code, const char* msg,
|
||||
int32_t len);
|
||||
void TpuStatus_Free(SE_Status* status);
|
||||
const char* TpuStatus_Message(SE_Status* status);
|
||||
int TpuStatus_Code(SE_Status* status);
|
||||
bool TpuStatus_Ok(SE_Status* status);
|
||||
void TpuStatus_Free(TF_Status* status);
|
||||
const char* TpuStatus_Message(TF_Status* status);
|
||||
int TpuStatus_Code(TF_Status* status);
|
||||
bool TpuStatus_Ok(TF_Status* status);
|
||||
|
||||
SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default();
|
||||
void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal);
|
||||
@ -165,7 +165,7 @@ SE_DeviceDescription* TpuDeviceDescription_New();
|
||||
void TpuDeviceDescription_Free(SE_DeviceDescription* description);
|
||||
void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor,
|
||||
SE_DeviceDescription* description,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags);
|
||||
void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options);
|
||||
@ -181,7 +181,7 @@ void TpuTransferManager_HostShapeToDeviceShape(XLA_TransferManager* manager,
|
||||
XLA_Shape* device_shape);
|
||||
void TpuTransferManager_TransferLiteralToDeviceAsync(
|
||||
XLA_TransferManager* manager, SE_Stream* stream, XLA_Literal* literal,
|
||||
XLA_ShapedBuffer* device_buffer, SE_Status* status);
|
||||
XLA_ShapedBuffer* device_buffer, TF_Status* status);
|
||||
void TpuTransferManager_TransferLiteralFromDevice(
|
||||
XLA_TransferManager* manager, SE_Stream* stream,
|
||||
XLA_ShapedBuffer* device_buffer, XLA_Literal* literal,
|
||||
@ -190,7 +190,7 @@ int64_t TpuTransferManager_GetByteSizeRequirement(XLA_TransferManager* manager,
|
||||
XLA_Shape* shape);
|
||||
void TpuTransferManager_ChooseCompactLayoutForShape(
|
||||
XLA_TransferManager* manager, XLA_Shape* host_shape, XLA_Shape* output,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
bool TpuTransferManager_CanShapedBufferBeAccessedNow(
|
||||
XLA_TransferManager* manager, SE_StreamExecutor* executor,
|
||||
XLA_ShapedBuffer* device_buffer);
|
||||
@ -200,32 +200,32 @@ bool TpuTransferManager_CanBufferBeAccessedNow(
|
||||
void TpuTransferManager_WriteSingleTupleIndexTable(
|
||||
XLA_TransferManager* manager, SE_Stream* stream,
|
||||
SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
|
||||
SE_DeviceMemoryBase* region, SE_Status* status);
|
||||
SE_DeviceMemoryBase* region, TF_Status* status);
|
||||
void TpuTransferManager_GetInfeedLayout(XLA_Shape* shape,
|
||||
XLA_Shape* infeed_shape);
|
||||
void TpuTransferManager_LinearizeToBuffers(
|
||||
XLA_TransferManager* manager, XLA_Literal* c_literal, char*** buffers_array,
|
||||
int64_t** buffers_size, int64_t* buffers_array_size, SE_Status* status);
|
||||
int64_t** buffers_size, int64_t* buffers_array_size, TF_Status* status);
|
||||
void TpuTransferManager_FreeBuffers(char** buffers_array, int64_t* buffers_size,
|
||||
int64_t buffers_array_size);
|
||||
void TpuTransferManager_TransferLiteralToInfeed(XLA_TransferManager* manager,
|
||||
SE_StreamExecutor* executor,
|
||||
XLA_Literal* c_literal,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuTransferManager_TransferBuffersToInfeed(XLA_TransferManager* manager,
|
||||
SE_StreamExecutor* executor,
|
||||
uint32_t** buffers_array,
|
||||
int64_t* buffers_size_in_uint32,
|
||||
int64_t buffers_array_size,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuTransferManager_TransferLiteralFromOutfeed(XLA_TransferManager* manager,
|
||||
SE_StreamExecutor* executor,
|
||||
XLA_Shape* shape,
|
||||
XLA_Literal* c_literal,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuTransferManager_ResetDevices(XLA_TransferManager* manager,
|
||||
SE_StreamExecutor** executors,
|
||||
int64_t num_executors, SE_Status* status);
|
||||
int64_t num_executors, TF_Status* status);
|
||||
|
||||
XLA_ComputationPlacer* TpuComputationPlacer_New();
|
||||
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
||||
@ -235,12 +235,12 @@ void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
||||
void TpuComputationPlacer_AssignDevices(XLA_ComputationPlacer* placer,
|
||||
int replica_count,
|
||||
int computation_count, int* assignment,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
void TpuComputationPlacer_AssignLocalDevices(SE_TpuTopology_Host* host,
|
||||
int replica_count,
|
||||
int computation_count,
|
||||
int* assignment,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
int TpuTopology_LogicalDevicesPerHost(SE_TpuTopology* tpu_topology,
|
||||
TpuCoreTypeEnum tpu_core_type);
|
||||
@ -287,27 +287,27 @@ TFTPU_CAPI_EXPORT void TpuCompiler_Free(Tpu_Compiler* compiler);
|
||||
TFTPU_CAPI_EXPORT void TpuCompiler_RunHloPasses(
|
||||
Tpu_Compiler* compiler, XLA_HloModule* se_hlo_module,
|
||||
SE_StreamExecutor* stream_executor, SE_DeviceMemoryAllocator* allocator,
|
||||
XLA_HloModule* result, SE_Status* status);
|
||||
XLA_HloModule* result, TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuCompiler_RunBackend(
|
||||
Tpu_Compiler* compiler, XLA_HloModule* se_hlo_module,
|
||||
SE_StreamExecutor* stream_executor, SE_DeviceMemoryAllocator* allocator,
|
||||
SE_Executable** result, SE_Status* status);
|
||||
SE_Executable** result, TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuCompiler_Compile(
|
||||
Tpu_Compiler* compiler, XLA_HloModuleGroup* se_hlo_module_group,
|
||||
SE_StreamExecutorList* stream_exec_lists, int num_lists,
|
||||
SE_DeviceMemoryAllocator* allocator, SE_Executable** executables,
|
||||
SE_Status* status);
|
||||
TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT int64_t TpuCompiler_ShapeSize(Tpu_Compiler* compiler,
|
||||
XLA_Shape* c_shape);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuExecutable_ExecuteAsyncOnStream(
|
||||
SE_Executable* executable, SE_ExecutableRunOptions* run_options,
|
||||
SE_Executable* executable, SE_ExecutableRunOptions* se_options,
|
||||
SE_ExecutionInput** se_arguments, int se_arguments_size,
|
||||
SE_HloExecutionProfile* hlo_execution_profile, SE_ExecutionOutput* output,
|
||||
SE_Status* status);
|
||||
SE_HloExecutionProfile* hlo_execution_profile,
|
||||
SE_ExecutionOutput* se_output, TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
|
||||
const char** fingerprint,
|
||||
@ -323,11 +323,11 @@ TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
||||
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
||||
TFTPU_CAPI_EXPORT void XlaShapeToTpuShapeRepresentation(
|
||||
XLA_Shape* serialized_xla_shape, int data_type, bool use_fast_memory,
|
||||
XLA_Shape* serialized_tpu_shape, SE_Status* status);
|
||||
XLA_Shape* serialized_tpu_shape, TF_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void XlaShapeToTpuPaddedShape(XLA_Shape* serialized_xla_shape,
|
||||
XLA_Shape* serialized_tpu_shape,
|
||||
SE_Status* status);
|
||||
XLA_Shape* padded_shape,
|
||||
TF_Status* status);
|
||||
|
||||
struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_New);
|
||||
|
@ -163,11 +163,11 @@ Status TpuTransferManager::ResetDevices(
|
||||
|
||||
struct TransferFromDeviceState {
|
||||
std::atomic<int64_t> remaining_transfers;
|
||||
SE_Status* overall_status =
|
||||
TF_Status* overall_status =
|
||||
tpu::ExecutorApiFn()->TpuStatus_NewFn(); // OK or the first error
|
||||
std::function<void(Status)> done;
|
||||
|
||||
void TransferFinished(SE_Status* status) {
|
||||
void TransferFinished(TF_Status* status) {
|
||||
if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
|
||||
tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
|
||||
std::swap(overall_status, status);
|
||||
@ -182,7 +182,7 @@ struct TransferFromDeviceState {
|
||||
}
|
||||
};
|
||||
|
||||
void TransferLiteralFromDeviceTrampoline(void* ctx, SE_Status* status) {
|
||||
void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
|
||||
reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user