Add TPUMeshStateInterface integration to TPU

PiperOrigin-RevId: 318595505
Change-Id: I214c3bd9efcf62bf3b474de9d300b1a74b66c6c1
This commit is contained in:
Frank Chen 2020-06-26 22:08:49 -07:00 committed by TensorFlower Gardener
parent 897803b00c
commit 064994341a
3 changed files with 21 additions and 2 deletions

View File

@ -174,9 +174,10 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
core::ScopedUnref mesh_state_unref(mesh_state);
auto* mesh_common_state = mesh_state->mesh_common_state();
tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
num_hosts, num_devices_per_host,
const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
&tpu_topology_output_size, &tpu_topology_output, status);
Tensor* ctx_output;
@ -210,12 +211,25 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "InitializeHostForDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
auto* rmgr = GetTPUConfigResourceMgr();
auto tpu_host_config = ctx->input(0).scalar<tstring>()();
size_t device_id_output_size;
int32_t* device_id_output;
TF_Status* status = TF_NewStatus();
bool is_master_worker =
tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
if (!is_master_worker) {
// Reset the mesh interface if we are not the master.
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshStateInterfaceResourceName));
auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
tpu::kTpuMeshStateInterfaceResourceName,
mesh_state_interface));
}
tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, &device_id_output_size,

View File

@ -25,6 +25,7 @@ limitations under the License.
typedef struct TpuSerializedProto TpuSerializedProto;
namespace tensorflow {
class TpuMeshCommonState;
namespace tpu {
class TpuMeshStateInterface;
} // namespace tpu
@ -40,7 +41,7 @@ TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
const size_t num_hosts, const size_t num_cores_per_host,
const int32_t** host_ordinal_to_global_core_id_map,
tensorflow::tpu::TpuMeshStateInterface* tpu_mesh_state_interface,
tensorflow::TpuMeshCommonState* tpu_mesh_common_state,
size_t* tpu_topology_output_size, char** tpu_topology_output,
TF_Status* status);
@ -60,6 +61,8 @@ TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();
}
struct TfTpu_ConfigApiFn {
@ -71,6 +74,7 @@ struct TfTpu_ConfigApiFn {
TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
};
#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_

View File

@ -11,6 +11,7 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeCharArray);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_FreeInt32Array);
TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
return tensorflow::Status::OK();
}