diff --git a/tensorflow/core/tpu/tpu_executor_init_fns.inc b/tensorflow/core/tpu/tpu_executor_init_fns.inc index e849fad6c65..4970292c499 100644 --- a/tensorflow/core/tpu/tpu_executor_init_fns.inc +++ b/tensorflow/core/tpu/tpu_executor_init_fns.inc @@ -124,6 +124,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) { TFTPU_SET_FN(executor_fn, TpuCoreLocation_Id); TFTPU_SET_FN(executor_fn, TpuHostLocation_Id); + TFTPU_SET_FN(executor_fn, TpuHostLocation_NumCores); + TFTPU_SET_FN(executor_fn, TpuHostLocation_Cores); TFTPU_SET_FN(executor_fn, TpuCompiler_New); TFTPU_SET_FN(executor_fn, TpuCompiler_Free); diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h index 5d8375aa1de..b59a8f2ad08 100644 --- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h +++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h @@ -255,6 +255,12 @@ int TpuCoreLocation_Index(SE_TpuTopology_Core* tpu_core_location); int TpuCoreLocation_Id(SE_TpuTopology_Core* tpu_core_location); int TpuHostLocation_Id(SE_TpuTopology_Host* tpu_host_location); +int TpuHostLocation_NumCores(SE_TpuTopology_Host* tpu_host_location, + TpuCoreTypeEnum tpu_core_type); +// 'cores' should be a preallocated array of size TpuHostLocation_NumCores. +void TpuHostLocation_Cores(SE_TpuTopology_Host* tpu_host_location, + TpuCoreTypeEnum tpu_core_type, + SE_TpuTopology_Core** cores); // C API for XLA::Compiler interface @@ -419,6 +425,8 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Id); TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_Id); + TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_NumCores); + TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_Cores); TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New); TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free); diff --git a/tensorflow/stream_executor/tpu/tpu_topology.cc b/tensorflow/stream_executor/tpu/tpu_topology.cc index c86b399b34e..659cace57a0 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.cc +++ b/tensorflow/stream_executor/tpu/tpu_topology.cc @@ -46,6 +46,21 @@ int32 TpuHostLocationExternal::Id() const { return tpu::ExecutorApiFn()->TpuHostLocation_IdFn(host_location_); } +std::vector TpuHostLocationExternal::Cores( + TpuCoreTypeEnum core_type) const { + int num_cores = tpu::ExecutorApiFn()->TpuHostLocation_NumCoresFn( + host_location_, core_type); + std::vector core_ptrs(num_cores); + tpu::ExecutorApiFn()->TpuHostLocation_CoresFn(host_location_, core_type, + core_ptrs.data()); + std::vector result; + result.reserve(num_cores); + for (SE_TpuTopology_Core* ptr : core_ptrs) { + result.emplace_back(ptr); + } + return result; +} + int32 TpuTopologyExternal::LogicalDevicesPerHost( TpuCoreTypeEnum core_type) const { return tpu::ExecutorApiFn()->TpuTopology_LogicalDevicesPerHostFn(topology_, diff --git a/tensorflow/stream_executor/tpu/tpu_topology.h b/tensorflow/stream_executor/tpu/tpu_topology.h index 1b22efa3613..7a92353993b 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.h +++ b/tensorflow/stream_executor/tpu/tpu_topology.h @@ -51,6 +51,7 @@ class TpuHostLocationExternal { explicit TpuHostLocationExternal(SE_TpuTopology_Host* host_location) : host_location_(host_location) {} int32 Id() const; + std::vector Cores(TpuCoreTypeEnum core_type) const; SE_TpuTopology_Host* impl() const { return host_location_; }