Update the Configure TPU, Wait For TPU, and Initialize TPU APIs to the backward-compatible API style
PiperOrigin-RevId: 342301514
Change-Id: I867bd45628db09df8854685d70328391a8ffb6ed
parent 55a470ba5d
commit 0c444b107f
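For readers outside the TPU team: the "backward compatible API style" named in the title replaces long positional argument lists at the C API boundary with a single params struct whose first two members are always `struct_size` and `priv`. The caller fills `struct_size` with the size of the struct it was compiled against, so the library can later append fields without breaking older callers. The sketch below only illustrates that convention; `Example_DoWork` and its fields are hypothetical and are not part of this commit.

#include <cstddef>  // offsetof
#include <cstdint>
#include <cstdio>

// Hypothetical params struct in the same style as the *_DoWork_Params structs
// added by this commit: struct_size and priv first, inputs next, outputs last.
typedef struct Example_DoWork_Params {
  int32_t struct_size;
  void* priv;

  size_t input_size;
  const char* input;

  int32_t verbosity;  // imagine this field was appended in a later revision
} Example_DoWork_Params;

#define Example_DoWork_Params_SIZE (sizeof(struct Example_DoWork_Params))

// The library reads a field only if the caller's struct_size proves the field
// existed in the header the caller was compiled against.
void Example_DoWork(Example_DoWork_Params* params) {
  std::printf("input: %.*s\n", static_cast<int>(params->input_size),
              params->input);
  if (params->struct_size >=
      static_cast<int32_t>(offsetof(Example_DoWork_Params, verbosity) +
                           sizeof(int32_t))) {
    std::printf("verbosity: %d\n", params->verbosity);
  }
}

int main() {
  Example_DoWork_Params params;
  params.struct_size = Example_DoWork_Params_SIZE;  // current caller
  params.priv = nullptr;
  params.input_size = 5;
  params.input = "hello";
  params.verbosity = 2;
  Example_DoWork(&params);

  // A caller compiled before `verbosity` existed would report a smaller
  // struct_size; the library then skips that field instead of reading garbage.
  params.struct_size =
      static_cast<int32_t>(offsetof(Example_DoWork_Params, verbosity));
  Example_DoWork(&params);
  return 0;
}

This is also why each call site in the diff below sets `params.priv = nullptr;` explicitly: the struct is plain C data with no constructor, so the caller must initialize every field by hand.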
@@ -207,10 +207,20 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   });
 
   auto* mesh_common_state = mesh_state->mesh_common_state();
-  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(
-      num_hosts, num_devices_per_host,
-      const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
-      &tpu_topology_output_size, &tpu_topology_output, status);
+
+  WaitForDistributedTpuOp_DoWork_Params params;
+  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
+  params.priv = nullptr;
+  params.num_hosts = num_hosts;
+  params.num_cores_per_host = num_devices_per_host;
+  params.host_ordinal_to_global_core_id_map =
+      const_cast<const int32_t**>(mapping_arg.data());
+  params.tpu_mesh_common_state = mesh_common_state;
+  params.tpu_topology_output_size = &tpu_topology_output_size;
+  params.tpu_topology_output = &tpu_topology_output;
+  params.status = status;
+
+  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);
 
   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
 
@@ -284,10 +294,19 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
     TF_DeleteStatus(status);
     tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
   });
-  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
-      tpu_host_config.size(), tpu_host_config.data(),
-      enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
-      &device_id_output, status);
+
+  InitializeHostForDistributedTpuOp_DoWork_Params params;
+  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
+  params.priv = nullptr;
+  params.tpu_host_config_size = tpu_host_config.size();
+  params.tpu_host_config = tpu_host_config.data();
+  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
+  params.is_master_worker = is_master_worker;
+  params.core_id_output_size = &device_id_output_size;
+  params.core_id_output = &device_id_output;
+  params.status = status;
+
+  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
 
   if (local_compilation_cache != nullptr) {
@@ -152,10 +152,19 @@ Status ConstructTpuPodState(
     tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
   });
   size_t host_config_output_size;
-  tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
-      num_devices_per_host.size(), num_devices_per_host.data(),
-      server_address.size(), server_address.data(), &host_config_output_size,
-      &host_config_output, status);
+
+  ConfigureDistributedTpuOp_DoWork_Params params;
+  params.struct_size = ConfigureDistributedTpuOp_DoWork_Params_SIZE;
+  params.priv = nullptr;
+  params.num_cores_per_host_size = num_devices_per_host.size();
+  params.num_cores_per_host = num_devices_per_host.data();
+  params.server_address_size = server_address.size();
+  params.server_address = server_address.data();
+  params.host_config_output_size = &host_config_output_size;
+  params.host_config_output = &host_config_output;
+  params.status = status;
+
+  tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(&params);
 
   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
   *host_config_proto = std::string(host_config_output, host_config_output_size);
@@ -244,6 +244,7 @@ xla::Status UpdateDynamicInputs(
   TpuExecute_RuntimeInputToPaddedData_Params params;
   params.struct_size =
      TpuExecute_RuntimeInputToPaddedData_Params_SIZE;
+  params.priv = nullptr;
   params.runtime_input_ptr = raw_input_runtime->data();
   params.runtime_input_size = raw_input_runtime->size();
   params.padded_data_ptr = padded_data->data();
@@ -107,6 +107,7 @@ TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
 typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params {
   int32_t struct_size;
   void* priv;
+
   const XLA_TpuProgram* program;
   SE_DeviceMemoryBase* arguments;
   size_t arguments_len;
@@ -134,6 +135,7 @@ TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
 typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
   int32_t struct_size;
   void* priv;
+
   uint32_t* runtime_input_ptr;
   size_t runtime_input_size;
   int8_t* padded_data_ptr;
@@ -150,23 +152,65 @@ typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
 TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
     TpuExecute_RuntimeInputToPaddedData_Params* params);
 
+typedef struct ConfigureDistributedTpuOp_DoWork_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t num_cores_per_host_size;
+  const int32_t* num_cores_per_host;
+  size_t server_address_size;
+  const char* server_address;
+
+  size_t* host_config_output_size;  // out
+  char** host_config_output;        // out
+  TF_Status* status;                // out
+} ConfigureDistributedTpuOp_DoWork_Params;
+
+#define ConfigureDistributedTpuOp_DoWork_Params_SIZE \
+  (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params))
+
 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
-    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
-    size_t server_address_size, const char* server_address,
-    size_t* host_config_output_size, char** host_config_output,
-    TF_Status* status);
+    ConfigureDistributedTpuOp_DoWork_Params* params);
+
+typedef struct WaitForDistributedTpuOp_DoWork_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t num_hosts;
+  size_t num_cores_per_host;
+  const int32_t** host_ordinal_to_global_core_id_map;
+  tensorflow::TpuMeshCommonState* tpu_mesh_common_state;
+
+  size_t* tpu_topology_output_size;  // out
+  char** tpu_topology_output;        // out
+  TF_Status* status;                 // out
+} WaitForDistributedTpuOp_DoWork_Params;
+
+#define WaitForDistributedTpuOp_DoWork_Params_SIZE \
+  (sizeof(struct WaitForDistributedTpuOp_DoWork_Params))
 
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
-    const size_t num_hosts, const size_t num_cores_per_host,
-    const int32_t** host_ordinal_to_global_core_id_map,
-    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
-    size_t* tpu_topology_output_size, char** tpu_topology_output,
-    TF_Status* status);
+    WaitForDistributedTpuOp_DoWork_Params* params);
+
+typedef struct InitializeHostForDistributedTpuOp_DoWork_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t tpu_host_config_size;
+  const char* tpu_host_config;
+  bool enable_whole_mesh_compilations;
+  bool is_master_worker;
+
+  size_t* core_id_output_size;  // out
+  int32_t** core_id_output;     // out
+  TF_Status* status;            // out
+} InitializeHostForDistributedTpuOp_DoWork_Params;
+
+#define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \
+  (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params))
 
 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
-    const size_t tpu_host_config_size, const char* tpu_host_config,
-    const bool enable_whole_mesh_compilations, bool is_master_worker,
-    size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);
+    InitializeHostForDistributedTpuOp_DoWork_Params* params);
 
 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
     const size_t tpu_topology_size, const char* tpu_topology,
@@ -79,6 +79,7 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
 
   TpuExecutable_LoadProgramAndEnqueueToStream_Params params;
   params.struct_size = TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE;
+  params.priv = nullptr;
   params.program = core_program_;
   params.arguments = arguments_bases;
   params.arguments_len = arguments.size();