Add TpuHostLocationExternal::Cores() and C API plumbing.

PiperOrigin-RevId: 329997425
Change-Id: Ib89ab30c1b01627edc5ca2dac471f622891b099e
This commit is contained in:
Skye Wanderman-Milne 2020-09-03 14:46:58 -07:00 committed by TensorFlower Gardener
parent 8ea12c187f
commit 565f670d9f
4 changed files with 26 additions and 0 deletions

View File

@ -124,6 +124,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuCoreLocation_Id);
TFTPU_SET_FN(executor_fn, TpuHostLocation_Id);
TFTPU_SET_FN(executor_fn, TpuHostLocation_NumCores);
TFTPU_SET_FN(executor_fn, TpuHostLocation_Cores);
TFTPU_SET_FN(executor_fn, TpuCompiler_New);
TFTPU_SET_FN(executor_fn, TpuCompiler_Free);

View File

@ -255,6 +255,12 @@ int TpuCoreLocation_Index(SE_TpuTopology_Core* tpu_core_location);
int TpuCoreLocation_Id(SE_TpuTopology_Core* tpu_core_location);
int TpuHostLocation_Id(SE_TpuTopology_Host* tpu_host_location);
int TpuHostLocation_NumCores(SE_TpuTopology_Host* tpu_host_location,
TpuCoreTypeEnum tpu_core_type);
// 'cores' should be a preallocated array of size TpuHostLocation_NumCores.
void TpuHostLocation_Cores(SE_TpuTopology_Host* tpu_host_location,
TpuCoreTypeEnum tpu_core_type,
SE_TpuTopology_Core** cores);
// C API for XLA::Compiler interface
@ -419,6 +425,8 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Id);
TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_Id);
TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_NumCores);
TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_Cores);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free);

View File

@ -46,6 +46,21 @@ int32 TpuHostLocationExternal::Id() const {
return tpu::ExecutorApiFn()->TpuHostLocation_IdFn(host_location_);
}
std::vector<TpuCoreLocationExternal> TpuHostLocationExternal::Cores(
TpuCoreTypeEnum core_type) const {
int num_cores = tpu::ExecutorApiFn()->TpuHostLocation_NumCoresFn(
host_location_, core_type);
std::vector<SE_TpuTopology_Core*> core_ptrs(num_cores);
tpu::ExecutorApiFn()->TpuHostLocation_CoresFn(host_location_, core_type,
core_ptrs.data());
std::vector<TpuCoreLocationExternal> result;
result.reserve(num_cores);
for (SE_TpuTopology_Core* ptr : core_ptrs) {
result.emplace_back(ptr);
}
return result;
}
int32 TpuTopologyExternal::LogicalDevicesPerHost(
TpuCoreTypeEnum core_type) const {
return tpu::ExecutorApiFn()->TpuTopology_LogicalDevicesPerHostFn(topology_,

View File

@ -51,6 +51,7 @@ class TpuHostLocationExternal {
explicit TpuHostLocationExternal(SE_TpuTopology_Host* host_location)
: host_location_(host_location) {}
int32 Id() const;
std::vector<TpuCoreLocationExternal> Cores(TpuCoreTypeEnum core_type) const;
SE_TpuTopology_Host* impl() const { return host_location_; }