Add TpuHostLocationExternal::Cores() and C API plumbing.
PiperOrigin-RevId: 329997425 Change-Id: Ib89ab30c1b01627edc5ca2dac471f622891b099e
This commit is contained in:
parent
8ea12c187f
commit
565f670d9f
@ -124,6 +124,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuCoreLocation_Id);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuHostLocation_Id);
|
||||
TFTPU_SET_FN(executor_fn, TpuHostLocation_NumCores);
|
||||
TFTPU_SET_FN(executor_fn, TpuHostLocation_Cores);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_New);
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_Free);
|
||||
|
@ -255,6 +255,12 @@ int TpuCoreLocation_Index(SE_TpuTopology_Core* tpu_core_location);
|
||||
int TpuCoreLocation_Id(SE_TpuTopology_Core* tpu_core_location);
|
||||
|
||||
int TpuHostLocation_Id(SE_TpuTopology_Host* tpu_host_location);
|
||||
int TpuHostLocation_NumCores(SE_TpuTopology_Host* tpu_host_location,
|
||||
TpuCoreTypeEnum tpu_core_type);
|
||||
// 'cores' should be a preallocated array of size TpuHostLocation_NumCores.
|
||||
void TpuHostLocation_Cores(SE_TpuTopology_Host* tpu_host_location,
|
||||
TpuCoreTypeEnum tpu_core_type,
|
||||
SE_TpuTopology_Core** cores);
|
||||
|
||||
// C API for XLA::Compiler interface
|
||||
|
||||
@ -419,6 +425,8 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCoreLocation_Id);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_Id);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_NumCores);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuHostLocation_Cores);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_New);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Free);
|
||||
|
@ -46,6 +46,21 @@ int32 TpuHostLocationExternal::Id() const {
|
||||
return tpu::ExecutorApiFn()->TpuHostLocation_IdFn(host_location_);
|
||||
}
|
||||
|
||||
std::vector<TpuCoreLocationExternal> TpuHostLocationExternal::Cores(
|
||||
TpuCoreTypeEnum core_type) const {
|
||||
int num_cores = tpu::ExecutorApiFn()->TpuHostLocation_NumCoresFn(
|
||||
host_location_, core_type);
|
||||
std::vector<SE_TpuTopology_Core*> core_ptrs(num_cores);
|
||||
tpu::ExecutorApiFn()->TpuHostLocation_CoresFn(host_location_, core_type,
|
||||
core_ptrs.data());
|
||||
std::vector<TpuCoreLocationExternal> result;
|
||||
result.reserve(num_cores);
|
||||
for (SE_TpuTopology_Core* ptr : core_ptrs) {
|
||||
result.emplace_back(ptr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int32 TpuTopologyExternal::LogicalDevicesPerHost(
|
||||
TpuCoreTypeEnum core_type) const {
|
||||
return tpu::ExecutorApiFn()->TpuTopology_LogicalDevicesPerHostFn(topology_,
|
||||
|
@ -51,6 +51,7 @@ class TpuHostLocationExternal {
|
||||
explicit TpuHostLocationExternal(SE_TpuTopology_Host* host_location)
|
||||
: host_location_(host_location) {}
|
||||
int32 Id() const;
|
||||
std::vector<TpuCoreLocationExternal> Cores(TpuCoreTypeEnum core_type) const;
|
||||
|
||||
SE_TpuTopology_Host* impl() const { return host_location_; }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user