Update the Configure TPU, Wait For TPU, and Initialize TPU APIs to the backward-compatible API style
PiperOrigin-RevId: 342301514
Change-Id: I867bd45628db09df8854685d70328391a8ffb6ed
parent 55a470ba5d
commit 0c444b107f
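The change moves these C-API entry points from long positional argument lists to a single params struct whose first two fields are struct_size and priv. Fields can then be appended to the struct in later revisions without breaking existing callers, because the implementation can check the caller-supplied struct_size before touching newer fields. Below is a minimal sketch of that pattern, not part of this commit; Example_DoWork_Params, example_field, and new_field are hypothetical names used only for illustration.

// Illustrative sketch only (not part of this commit): how a params-struct
// C API stays backward compatible. All names here are hypothetical.
#include <cstddef>
#include <cstdint>

typedef struct Example_DoWork_Params {
  int32_t struct_size;  // Filled by the caller with the size it compiled against.
  void* priv;           // Reserved for implementation-private data.

  int32_t example_field;  // Field present in the first revision.
  int32_t new_field;      // Field appended in a later revision.
} Example_DoWork_Params;

#define Example_DoWork_Params_SIZE (sizeof(struct Example_DoWork_Params))

void Example_DoWork(Example_DoWork_Params* params) {
  int32_t value = params->example_field;
  // Only read new_field if the caller's struct_size proves it exists, so
  // binaries built against the older, shorter struct keep working.
  const size_t new_field_end =
      offsetof(Example_DoWork_Params, new_field) + sizeof(params->new_field);
  if (static_cast<size_t>(params->struct_size) >= new_field_end) {
    value += params->new_field;
  }
  (void)value;
}

A caller sets params.struct_size = Example_DoWork_Params_SIZE, fills in the fields it knows about, and passes &params. That is exactly the shape of the call sites updated in the hunks below.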
@@ -207,10 +207,20 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   });
 
   auto* mesh_common_state = mesh_state->mesh_common_state();
-  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(
-      num_hosts, num_devices_per_host,
-      const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
-      &tpu_topology_output_size, &tpu_topology_output, status);
+
+  WaitForDistributedTpuOp_DoWork_Params params;
+  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
+  params.priv = nullptr;
+  params.num_hosts = num_hosts;
+  params.num_cores_per_host = num_devices_per_host;
+  params.host_ordinal_to_global_core_id_map =
+      const_cast<const int32_t**>(mapping_arg.data());
+  params.tpu_mesh_common_state = mesh_common_state;
+  params.tpu_topology_output_size = &tpu_topology_output_size;
+  params.tpu_topology_output = &tpu_topology_output;
+  params.status = status;
+
+  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);
 
   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
@@ -284,10 +294,19 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
     TF_DeleteStatus(status);
     tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
   });
-  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
-      tpu_host_config.size(), tpu_host_config.data(),
-      enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
-      &device_id_output, status);
+
+  InitializeHostForDistributedTpuOp_DoWork_Params params;
+  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
+  params.priv = nullptr;
+  params.tpu_host_config_size = tpu_host_config.size();
+  params.tpu_host_config = tpu_host_config.data();
+  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
+  params.is_master_worker = is_master_worker;
+  params.core_id_output_size = &device_id_output_size;
+  params.core_id_output = &device_id_output;
+  params.status = status;
+
+  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
 
   if (local_compilation_cache != nullptr) {
@@ -152,10 +152,19 @@ Status ConstructTpuPodState(
     tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
   });
   size_t host_config_output_size;
-  tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
-      num_devices_per_host.size(), num_devices_per_host.data(),
-      server_address.size(), server_address.data(), &host_config_output_size,
-      &host_config_output, status);
+
+  ConfigureDistributedTpuOp_DoWork_Params params;
+  params.struct_size = ConfigureDistributedTpuOp_DoWork_Params_SIZE;
+  params.priv = nullptr;
+  params.num_cores_per_host_size = num_devices_per_host.size();
+  params.num_cores_per_host = num_devices_per_host.data();
+  params.server_address_size = server_address.size();
+  params.server_address = server_address.data();
+  params.host_config_output_size = &host_config_output_size;
+  params.host_config_output = &host_config_output;
+  params.status = status;
+
+  tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(&params);
   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
   *host_config_proto = std::string(host_config_output, host_config_output_size);
 
@@ -244,6 +244,7 @@ xla::Status UpdateDynamicInputs(
       TpuExecute_RuntimeInputToPaddedData_Params params;
       params.struct_size =
          TpuExecute_RuntimeInputToPaddedData_Params_SIZE;
+      params.priv = nullptr;
      params.runtime_input_ptr = raw_input_runtime->data();
      params.runtime_input_size = raw_input_runtime->size();
      params.padded_data_ptr = padded_data->data();
@@ -107,6 +107,7 @@ TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
 typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params {
   int32_t struct_size;
+  void* priv;
 
   const XLA_TpuProgram* program;
   SE_DeviceMemoryBase* arguments;
   size_t arguments_len;
@@ -134,6 +135,7 @@ TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
 typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
   int32_t struct_size;
+  void* priv;
 
   uint32_t* runtime_input_ptr;
   size_t runtime_input_size;
   int8_t* padded_data_ptr;
@@ -150,23 +152,65 @@ typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
 TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
     TpuExecute_RuntimeInputToPaddedData_Params* params);
 
+typedef struct ConfigureDistributedTpuOp_DoWork_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t num_cores_per_host_size;
+  const int32_t* num_cores_per_host;
+  size_t server_address_size;
+  const char* server_address;
+
+  size_t* host_config_output_size;  // out
+  char** host_config_output;        // out
+  TF_Status* status;                // out
+} ConfigureDistributedTpuOp_DoWork_Params;
+
+#define ConfigureDistributedTpuOp_DoWork_Params_SIZE \
+  (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params))
+
 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
-    const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
-    size_t server_address_size, const char* server_address,
-    size_t* host_config_output_size, char** host_config_output,
-    TF_Status* status);
+    ConfigureDistributedTpuOp_DoWork_Params* params);
 
+typedef struct WaitForDistributedTpuOp_DoWork_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t num_hosts;
+  size_t num_cores_per_host;
+  const int32_t** host_ordinal_to_global_core_id_map;
+  tensorflow::TpuMeshCommonState* tpu_mesh_common_state;
+
+  size_t* tpu_topology_output_size;  // out
+  char** tpu_topology_output;        // out
+  TF_Status* status;                 // out
+} WaitForDistributedTpuOp_DoWork_Params;
+
+#define WaitForDistributedTpuOp_DoWork_Params_SIZE \
+  (sizeof(struct WaitForDistributedTpuOp_DoWork_Params))
+
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
-    const size_t num_hosts, const size_t num_cores_per_host,
-    const int32_t** host_ordinal_to_global_core_id_map,
-    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
-    size_t* tpu_topology_output_size, char** tpu_topology_output,
-    TF_Status* status);
+    WaitForDistributedTpuOp_DoWork_Params* params);
 
+typedef struct InitializeHostForDistributedTpuOp_DoWork_Params {
+  int32_t struct_size;
+  void* priv;
+
+  size_t tpu_host_config_size;
+  const char* tpu_host_config;
+  bool enable_whole_mesh_compilations;
+  bool is_master_worker;
+
+  size_t* core_id_output_size;  // out
+  int32_t** core_id_output;     // out
+  TF_Status* status;            // out
+} InitializeHostForDistributedTpuOp_DoWork_Params;
+
+#define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \
+  (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params))
+
 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
-    const size_t tpu_host_config_size, const char* tpu_host_config,
-    const bool enable_whole_mesh_compilations, bool is_master_worker,
-    size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);
+    InitializeHostForDistributedTpuOp_DoWork_Params* params);
 
 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
     const size_t tpu_topology_size, const char* tpu_topology,
@@ -79,6 +79,7 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
 
   TpuExecutable_LoadProgramAndEnqueueToStream_Params params;
   params.struct_size = TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE;
+  params.priv = nullptr;
   params.program = core_program_;
   params.arguments = arguments_bases;
   params.arguments_len = arguments.size();