Add TPU runtime version.
PiperOrigin-RevId: 351862770 Change-Id: Ia7de40826afabd8946b43272e8c1fb2cd574d290
This commit is contained in:
parent
7e6ea327fe
commit
ec6cfabc82
@ -14,6 +14,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
|
||||
TFTPU_SET_FN(executor_fn, TpuPlatform_GetTopologyPtr);
|
||||
TFTPU_SET_FN(executor_fn, TpuPlatform_GetHostLocation);
|
||||
TFTPU_SET_FN(executor_fn, TpuPlatform_GetRuntimeVersion);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_Init);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_Free);
|
||||
|
@ -41,6 +41,13 @@ enum TpuVersionEnum {
|
||||
kTpuV4,
|
||||
};
|
||||
|
||||
typedef struct TpuRuntimeVersion {
|
||||
// The three version numbers are: major, minor, patch
|
||||
int version[3];
|
||||
const char* metadata;
|
||||
size_t metadata_size;
|
||||
} TpuRuntimeVersion;
|
||||
|
||||
typedef struct SE_Platform SE_Platform;
|
||||
typedef struct SE_StreamExecutor SE_StreamExecutor;
|
||||
typedef struct SE_Stream SE_Stream;
|
||||
|
@ -41,6 +41,7 @@ int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
|
||||
bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(SE_Platform* platform);
|
||||
SE_TpuTopology* TpuPlatform_GetTopologyPtr(SE_Platform* platform);
|
||||
SE_TpuTopology_Host* TpuPlatform_GetHostLocation(SE_Platform* platform);
|
||||
TpuRuntimeVersion TpuPlatform_GetRuntimeVersion(SE_Platform* platform);
|
||||
|
||||
void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
|
||||
SE_DeviceOptions* device_options, TF_Status* status);
|
||||
@ -353,6 +354,7 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetTopologyPtr);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetHostLocation);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetRuntimeVersion);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Init);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Free);
|
||||
|
@ -128,6 +128,10 @@ const tensorflow::tpu::TpuHostLocationExternal TpuPlatform::GetTpuHostLocation()
|
||||
tpu::ExecutorApiFn()->TpuPlatform_GetHostLocationFn(platform_));
|
||||
}
|
||||
|
||||
TpuRuntimeVersion TpuPlatform::version() const {
|
||||
return tpu::ExecutorApiFn()->TpuPlatform_GetRuntimeVersionFn(platform_);
|
||||
}
|
||||
|
||||
void TpuPlatform::InsertEvent(stream_executor::internal::EventInterface* key,
|
||||
SE_Event* val) {
|
||||
tensorflow::mutex_lock lock(event_map_mu_);
|
||||
|
@ -66,6 +66,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
|
||||
const tensorflow::tpu::TpuHostLocationExternal GetTpuHostLocation()
|
||||
const override;
|
||||
|
||||
TpuRuntimeVersion version() const override;
|
||||
|
||||
bool Initialized() const override;
|
||||
|
||||
Status Initialize(
|
||||
|
@ -57,6 +57,8 @@ class TpuPlatformInterface : public stream_executor::Platform {
|
||||
|
||||
virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0;
|
||||
|
||||
virtual TpuRuntimeVersion version() const = 0;
|
||||
|
||||
TpuTopologyExternal topology() {
|
||||
return TpuTopologyExternal(GetTopologyPtr());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user