From 064994341a8c9636d888358da12cfefb90f95d00 Mon Sep 17 00:00:00 2001
From: Frank Chen
Date: Fri, 26 Jun 2020 22:08:49 -0700
Subject: [PATCH] Add TPUMeshStateInterface integration to TPU

PiperOrigin-RevId: 318595505
Change-Id: I214c3bd9efcf62bf3b474de9d300b1a74b66c6c1
---
 .../core/tpu/kernels/tpu_configuration_ops.cc | 16 +++++++++++++++-
 tensorflow/core/tpu/tpu_config_c_api.h        |  6 +++++-
 tensorflow/core/tpu/tpu_library_init_fns.inc  |  1 +
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
index 7d3814ad3c3..065a7f77dd6 100644
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -174,9 +174,10 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
   core::ScopedUnref mesh_state_unref(mesh_state);
 
+  auto* mesh_common_state = mesh_state->mesh_common_state();
   tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
       num_hosts, num_devices_per_host,
-      const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
+      const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
       &tpu_topology_output_size, &tpu_topology_output, status);
 
   Tensor* ctx_output;
@@ -210,12 +211,25 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "InitializeHostForDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
 
+  auto* rmgr = GetTPUConfigResourceMgr();
   auto tpu_host_config = ctx->input(0).scalar<tstring>()();
 
   size_t device_id_output_size;
   int32_t* device_id_output;
   TF_Status* status = TF_NewStatus();
 
+  bool is_master_worker =
+      tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
+  if (!is_master_worker) {
+    // Reset the mesh interface if we are not the master.
+    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
+    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
+    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
+                                     tpu::kTpuMeshStateInterfaceResourceName,
+                                     mesh_state_interface));
+  }
+
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
       enable_whole_mesh_compilations_, &device_id_output_size,
diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h
index 9c0ed203d4b..8530df5ac26 100644
--- a/tensorflow/core/tpu/tpu_config_c_api.h
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -25,6 +25,7 @@ limitations under the License.
 typedef struct TpuSerializedProto TpuSerializedProto;
 
 namespace tensorflow {
+class TpuMeshCommonState;
 namespace tpu {
 class TpuMeshStateInterface;
 }  // namespace tpu
@@ -40,7 +41,7 @@ TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
     const int32_t** host_ordinal_to_global_core_id_map,
-    tensorflow::tpu::TpuMeshStateInterface* tpu_mesh_state_interface,
+    tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
     size_t* tpu_topology_output_size, char** tpu_topology_output,
     TF_Status* status);
@@ -60,6 +61,8 @@ TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
 
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
+
+TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();
 }
 
 struct TfTpu_ConfigApiFn {
@@ -71,6 +74,7 @@ struct TfTpu_ConfigApiFn {
   TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
 };
 
 #endif  // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 16d06539349..f8bde09e728 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -11,6 +11,7 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
   TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
 
   return tensorflow::Status::OK();
 }