Add TPUMeshStateInterface integration to TPU

PiperOrigin-RevId: 318595505
Change-Id: I214c3bd9efcf62bf3b474de9d300b1a74b66c6c1
parent 897803b00c
commit 064994341a
@@ -174,9 +174,10 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
   core::ScopedUnref mesh_state_unref(mesh_state);
 
+  auto* mesh_common_state = mesh_state->mesh_common_state();
   tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
       num_hosts, num_devices_per_host,
-      const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
+      const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
       &tpu_topology_output_size, &tpu_topology_output, status);
 
   Tensor* ctx_output;
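Note: GetTpuMeshStateInterface is called above but defined outside this diff. A minimal sketch of what such a helper plausibly looks like, assuming it is a plain ResourceMgr lookup under the well-known resource name that appears in the next hunk; the body is illustrative, not the verbatim TensorFlow helper:

tensorflow::Status GetTpuMeshStateInterface(
    const tensorflow::ResourceMgr* rmgr,
    tensorflow::tpu::TpuMeshStateInterface** state) {
  // Look up the mesh-state resource registered by
  // InitializeHostForDistributedTpuOp (next hunk). On success the caller
  // holds a reference, which is why the op wraps it in core::ScopedUnref.
  if (!rmgr->Lookup(rmgr->default_container(),
                    tensorflow::tpu::kTpuMeshStateInterfaceResourceName,
                    state)
           .ok()) {
    return tensorflow::errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return tensorflow::Status::OK();
}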
@@ -210,12 +211,25 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "InitializeHostForDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
 
   auto* rmgr = GetTPUConfigResourceMgr();
   auto tpu_host_config = ctx->input(0).scalar<tstring>()();
 
   size_t device_id_output_size;
   int32_t* device_id_output;
   TF_Status* status = TF_NewStatus();
+
+  bool is_master_worker =
+      tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
+  if (!is_master_worker) {
+    // Reset the mesh interface if we are not the master.
+    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
+    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
+    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
+                                     tpu::kTpuMeshStateInterfaceResourceName,
+                                     mesh_state_interface));
+  }
+
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
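Similarly, DeleteIfExists<T> is not part of this hunk. A minimal sketch, assuming it is a ResourceMgr::Delete wrapper that treats a missing resource as success, so the reset above can run unconditionally on non-master workers (illustrative only):

template <typename T>
tensorflow::Status DeleteIfExists(tensorflow::ResourceMgr* rmgr,
                                  const char* resource_name) {
  // Delete the named resource from the default container; a NotFound error
  // simply means there was nothing to reset.
  tensorflow::Status status =
      rmgr->Delete<T>(rmgr->default_container(), resource_name);
  if (status.ok() || tensorflow::errors::IsNotFound(status)) {
    return tensorflow::Status::OK();
  }
  return status;
}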
@@ -25,6 +25,7 @@ limitations under the License.
 typedef struct TpuSerializedProto TpuSerializedProto;
 
 namespace tensorflow {
+class TpuMeshCommonState;
 namespace tpu {
 class TpuMeshStateInterface;
 }  // namespace tpu
@@ -40,7 +41,7 @@ TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
     const int32_t** host_ordinal_to_global_core_id_map,
-    tensorflow::tpu::TpuMeshStateInterface* tpu_mesh_state_interface,
+    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
     size_t* tpu_topology_output_size, char** tpu_topology_output,
     TF_Status* status);
 
@@ -60,6 +61,8 @@ TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
 
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
+
+TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();
 }
 
 struct TfTpu_ConfigApiFn {
@@ -71,6 +74,7 @@ struct TfTpu_ConfigApiFn {
   TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
 };
 
 #endif  // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
@@ -11,6 +11,7 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
   TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
 
   return tensorflow::Status::OK();
 }
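For readers unfamiliar with the TFTPU macros: a plausible expansion, assuming the usual decltype-plus-dlsym pattern (the real definitions live elsewhere in the tree and may differ in detail). TFTPU_ADD_FN_IN_STRUCT declares a function-pointer member named after the exported C symbol, and TFTPU_SET_FN resolves that symbol from the dynamically loaded TPU library:

#include <dlfcn.h>

// Declares e.g. "decltype(TpuConfigurationApi_HasTPUPodState)*
// TpuConfigurationApi_HasTPUPodStateFn;" inside TfTpu_ConfigApiFn.
#define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn;

// Resolves the symbol by name; relies on a "library_handle" variable being
// in scope at the expansion site, as in SetTpuConfigStructFns above.
#define TFTPU_SET_FN(Struct, FnName)                        \
  Struct->FnName##Fn = reinterpret_cast<decltype(FnName)*>( \
      dlsym(library_handle, #FnName))

Under this pattern, the three edits above (export the symbol, add it to TfTpu_ConfigApiFn, set it in SetTpuConfigStructFns) are exactly what is needed before kernels can call tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn(), as the InitializeHostForDistributedTpuOp hunk does.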