Add TPU runtime version.

PiperOrigin-RevId: 351862770
Change-Id: Ia7de40826afabd8946b43272e8c1fb2cd574d290
This commit is contained in:
Skye Wanderman-Milne 2021-01-14 13:14:57 -08:00 committed by TensorFlower Gardener
parent 7e6ea327fe
commit ec6cfabc82
6 changed files with 18 additions and 0 deletions

View File

@ -14,6 +14,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
TFTPU_SET_FN(executor_fn, TpuPlatform_GetTopologyPtr);
TFTPU_SET_FN(executor_fn, TpuPlatform_GetHostLocation);
TFTPU_SET_FN(executor_fn, TpuPlatform_GetRuntimeVersion);
TFTPU_SET_FN(executor_fn, TpuExecutor_Init);
TFTPU_SET_FN(executor_fn, TpuExecutor_Free);

View File

@ -41,6 +41,13 @@ enum TpuVersionEnum {
kTpuV4,
};
typedef struct TpuRuntimeVersion {
// The three version numbers are: major, minor, patch
int version[3];
const char* metadata;
size_t metadata_size;
} TpuRuntimeVersion;
typedef struct SE_Platform SE_Platform;
typedef struct SE_StreamExecutor SE_StreamExecutor;
typedef struct SE_Stream SE_Stream;

View File

@ -41,6 +41,7 @@ int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(SE_Platform* platform);
SE_TpuTopology* TpuPlatform_GetTopologyPtr(SE_Platform* platform);
SE_TpuTopology_Host* TpuPlatform_GetHostLocation(SE_Platform* platform);
TpuRuntimeVersion TpuPlatform_GetRuntimeVersion(SE_Platform* platform);
void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
SE_DeviceOptions* device_options, TF_Status* status);
@ -353,6 +354,7 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetTopologyPtr);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetHostLocation);
TFTPU_ADD_FN_IN_STRUCT(TpuPlatform_GetRuntimeVersion);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Init);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_Free);

View File

@ -128,6 +128,10 @@ const tensorflow::tpu::TpuHostLocationExternal TpuPlatform::GetTpuHostLocation()
tpu::ExecutorApiFn()->TpuPlatform_GetHostLocationFn(platform_));
}
TpuRuntimeVersion TpuPlatform::version() const {
return tpu::ExecutorApiFn()->TpuPlatform_GetRuntimeVersionFn(platform_);
}
void TpuPlatform::InsertEvent(stream_executor::internal::EventInterface* key,
SE_Event* val) {
tensorflow::mutex_lock lock(event_map_mu_);

View File

@ -66,6 +66,8 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
const tensorflow::tpu::TpuHostLocationExternal GetTpuHostLocation()
const override;
TpuRuntimeVersion version() const override;
bool Initialized() const override;
Status Initialize(

View File

@ -57,6 +57,8 @@ class TpuPlatformInterface : public stream_executor::Platform {
virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0;
virtual TpuRuntimeVersion version() const = 0;
TpuTopologyExternal topology() {
return TpuTopologyExternal(GetTopologyPtr());
}