Add TPUMeshStateInterface integration to TPU
PiperOrigin-RevId: 318595505
Change-Id: I214c3bd9efcf62bf3b474de9d300b1a74b66c6c1

parent 897803b00c
commit 064994341a
@@ -174,9 +174,10 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
   core::ScopedUnref mesh_state_unref(mesh_state);
 
+  auto* mesh_common_state = mesh_state->mesh_common_state();
   tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
       num_hosts, num_devices_per_host,
-      const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
+      const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
       &tpu_topology_output_size, &tpu_topology_output, status);
 
   Tensor* ctx_output;
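The kernel now unwraps the underlying TpuMeshCommonState from the mesh-state resource before crossing the C API. A minimal sketch of the accessor this implies on TpuMeshStateInterface; the member layout and ownership here are assumptions, not part of this diff:

    // Sketch only: assumed shape of the interface's accessor.
    namespace tensorflow {
    class TpuMeshCommonState;  // opaque runtime state, defined inside the TPU library

    namespace tpu {
    class TpuMeshStateInterface {  // the real class is a ResourceBase subclass
     public:
      // Exposes the raw runtime state so it can be passed across the C API.
      tensorflow::TpuMeshCommonState* mesh_common_state() const {
        return mesh_common_state_;
      }

     private:
      tensorflow::TpuMeshCommonState* mesh_common_state_ = nullptr;  // assumed not owned
    };
    }  // namespace tpu
    }  // namespace tensorflow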
@@ -210,12 +211,25 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "InitializeHostForDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
 
+  auto* rmgr = GetTPUConfigResourceMgr();
   auto tpu_host_config = ctx->input(0).scalar<tstring>()();
 
   size_t device_id_output_size;
   int32_t* device_id_output;
   TF_Status* status = TF_NewStatus();
 
+  bool is_master_worker =
+      tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
+  if (!is_master_worker) {
+    // Reset the mesh interface if we are not the master.
+    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
+    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
+    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
+                                     tpu::kTpuMeshStateInterfaceResourceName,
+                                     mesh_state_interface));
+  }
+
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
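The non-master branch above clears any stale mesh-state resource via DeleteIfExists before registering a fresh TpuMeshStateInterface. A plausible sketch of that helper, assuming it treats a missing resource as success; the real definition is not part of this hunk:

    #include "tensorflow/core/framework/resource_mgr.h"
    #include "tensorflow/core/platform/errors.h"

    // Sketch: delete the named resource if present. A resource that was never
    // created is not an error here, since the goal is only to clear stale state.
    template <typename T>
    tensorflow::Status DeleteIfExists(tensorflow::ResourceMgr* rmgr,
                                      const char* resource_name) {
      tensorflow::Status status =
          rmgr->Delete<T>(rmgr->default_container(), resource_name);
      if (status.ok() || tensorflow::errors::IsNotFound(status)) {
        return tensorflow::Status::OK();
      }
      return status;
    }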
@@ -25,6 +25,7 @@ limitations under the License.
 typedef struct TpuSerializedProto TpuSerializedProto;
 
 namespace tensorflow {
+class TpuMeshCommonState;
 namespace tpu {
 class TpuMeshStateInterface;
 }  // namespace tpu
@@ -40,7 +41,7 @@ TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
     const int32_t** host_ordinal_to_global_core_id_map,
-    tensorflow::tpu::TpuMeshStateInterface* tpu_mesh_state_interface,
+    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
     size_t* tpu_topology_output_size, char** tpu_topology_output,
     TF_Status* status);
 
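Note the signature change: the C API now takes tensorflow::TpuMeshCommonState*, which is only forward-declared in this header (see the hunk at -25,6 above), rather than the C++ TpuMeshStateInterface. An opaque pointer keeps the boundary C-compatible: callers can hold and pass the handle without ever seeing its definition. A small illustration of the pattern, with hypothetical names:

    // Shared header (hypothetical): the host only sees an opaque handle.
    struct MeshState;                      // definition lives in the TPU library
    extern "C" void DoWork(MeshState* s);  // host passes the pointer through untouched

    // TPU library (hypothetical): only this side knows the layout.
    struct MeshState {
      int num_cores;
    };
    extern "C" void DoWork(MeshState* s) { /* operate on s->num_cores here */ }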
@@ -60,6 +61,8 @@ TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
 
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
+
+TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();
 }
 
 struct TfTpu_ConfigApiFn {
@@ -71,6 +74,7 @@ struct TfTpu_ConfigApiFn {
   TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
 };
 
 #endif  // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
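TFTPU_ADD_FN_IN_STRUCT declares a function-pointer member for each exported symbol, which is how the TpuConfigurationApi_HasTPUPodStateFn used by the kernel above comes into existence. The macro definition is not shown in this diff; a sketch of its assumed shape:

    // Assumed expansion: one function-pointer member per exported C symbol.
    #define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn;

    // The added line would then expand to roughly:
    //   bool (*TpuConfigurationApi_HasTPUPodStateFn)();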
@@ -11,6 +11,7 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
   TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
 
   return tensorflow::Status::OK();
 }
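TFTPU_SET_FN is the loader-side counterpart: it binds each struct member to the symbol of the same name in the dynamically loaded TPU library. A hedged sketch of such a binder using dlsym; the real macro may differ in its error handling:

    #include <dlfcn.h>

    // Assumed shape: resolve the symbol, fail with NotFound if it is absent.
    #define TFTPU_SET_FN(Struct, FnName)                                       \
      Struct->FnName##Fn =                                                     \
          reinterpret_cast<decltype(FnName)*>(dlsym(library_handle, #FnName)); \
      if (Struct->FnName##Fn == nullptr) {                                     \
        return tensorflow::errors::NotFound(#FnName " not found in library");  \
      }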