Create dummy topology types for libtpu C API.
This gives us better typechecking than just using void*. PiperOrigin-RevId: 325919251 Change-Id: Idc375ed151a2b0b2cf7d62b7be9cf2fafe48e934
This commit is contained in:
parent
d10c814ce1
commit
82c043fee5
@ -294,6 +294,7 @@ cc_library(
|
||||
hdrs = ["tpu_platform_interface.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api_decl",
|
||||
":tpu_topology_external",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/stream_executor",
|
||||
|
@ -253,6 +253,10 @@ typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;
|
||||
|
||||
typedef void (*XLA_CallbackFn)(void*);
|
||||
typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*);
|
||||
|
||||
typedef struct SE_TpuTopology SE_TpuTopology;
|
||||
typedef struct SE_TpuTopology_Core SE_TpuTopology_Core;
|
||||
typedef struct SE_TpuTopology_Core SE_TpuTopology_Host;
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DECL_H_
|
||||
|
@ -63,8 +63,10 @@ struct SE_DeviceOptions {
|
||||
stream_executor::DeviceOptions options;
|
||||
};
|
||||
|
||||
// Ignored -- these are just used to enforce the interface types
|
||||
struct XLA_TransferManager {};
|
||||
|
||||
struct XLA_ComputationPlacer {};
|
||||
struct SE_TpuTopology {};
|
||||
struct SE_TpuTopology_Core {};
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_DEFN_H_
|
||||
|
@ -39,8 +39,8 @@ SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
|
||||
int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
|
||||
int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
|
||||
bool TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopy(SE_Platform* platform);
|
||||
void* TpuPlatform_GetTopologyPtr(SE_Platform* platform);
|
||||
void* TpuPlatform_GetHostLocation(SE_Platform* platform);
|
||||
SE_TpuTopology* TpuPlatform_GetTopologyPtr(SE_Platform* platform);
|
||||
SE_TpuTopology_Host* TpuPlatform_GetHostLocation(SE_Platform* platform);
|
||||
|
||||
void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
|
||||
SE_DeviceOptions* device_options, SE_Status* status);
|
||||
@ -193,29 +193,32 @@ void TpuTransferManager_FreeBuffers(char** buffers_array, int64_t* buffers_size,
|
||||
XLA_ComputationPlacer* TpuComputationPlacer_New();
|
||||
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
|
||||
|
||||
int TpuTopology_LogicalDevicesPerHost(void* tpu_topology,
|
||||
int TpuTopology_LogicalDevicesPerHost(SE_TpuTopology* tpu_topology,
|
||||
TpuCoreTypeEnum tpu_core_type);
|
||||
int TpuTopology_LogicalDevicesPerChip(void* tpu_topology,
|
||||
int TpuTopology_LogicalDevicesPerChip(SE_TpuTopology* tpu_topology,
|
||||
TpuCoreTypeEnum tpu_core_type);
|
||||
int TpuTopology_ChipBounds_X(void* tpu_topology);
|
||||
int TpuTopology_ChipBounds_Y(void* tpu_topology);
|
||||
int TpuTopology_ChipBounds_Z(void* tpu_topology);
|
||||
bool TpuTopology_HasChip(void* tpu_topology, int x, int y, int z);
|
||||
void* TpuTopology_Core(void* tpu_topology, int x, int y, int z,
|
||||
TpuCoreTypeEnum tpu_core_type, int index);
|
||||
int TpuTopology_NumCores(void* tpu_topology, TpuCoreTypeEnum tpu_core_type);
|
||||
int TpuTopology_ChipBounds_X(SE_TpuTopology* tpu_topology);
|
||||
int TpuTopology_ChipBounds_Y(SE_TpuTopology* tpu_topology);
|
||||
int TpuTopology_ChipBounds_Z(SE_TpuTopology* tpu_topology);
|
||||
bool TpuTopology_HasChip(SE_TpuTopology* tpu_topology, int x, int y, int z);
|
||||
SE_TpuTopology_Core* TpuTopology_Core(SE_TpuTopology* tpu_topology, int x,
|
||||
int y, int z,
|
||||
TpuCoreTypeEnum tpu_core_type, int index);
|
||||
int TpuTopology_NumCores(SE_TpuTopology* tpu_topology,
|
||||
TpuCoreTypeEnum tpu_core_type);
|
||||
// 'cores' should be a preallocated array of size TpuTopology_NumCores.
|
||||
void TpuTopology_Cores(void* tpu_topology, TpuCoreTypeEnum tpu_core_type,
|
||||
void** cores);
|
||||
int TpuTopology_IdForHost(void* tpu_topology, int x, int y, int z);
|
||||
void TpuCoreLocation_ChipCoordinates(void* tpu_core_location, int* x, int* y,
|
||||
int* z);
|
||||
void TpuCoreLocation_HostCoordinates(void* tpu_core_location, int* x, int* y,
|
||||
int* z);
|
||||
int TpuCoreLocation_Index(void* tpu_core_location);
|
||||
int TpuCoreLocation_Id(void* tpu_core_location);
|
||||
void TpuTopology_Cores(SE_TpuTopology* tpu_topology,
|
||||
TpuCoreTypeEnum tpu_core_type,
|
||||
SE_TpuTopology_Core** cores);
|
||||
int TpuTopology_IdForHost(SE_TpuTopology* tpu_topology, int x, int y, int z);
|
||||
void TpuCoreLocation_ChipCoordinates(SE_TpuTopology_Core* tpu_core_location,
|
||||
int* x, int* y, int* z);
|
||||
void TpuCoreLocation_HostCoordinates(SE_TpuTopology_Core* tpu_core_location,
|
||||
int* x, int* y, int* z);
|
||||
int TpuCoreLocation_Index(SE_TpuTopology_Core* tpu_core_location);
|
||||
int TpuCoreLocation_Id(SE_TpuTopology_Core* tpu_core_location);
|
||||
|
||||
int TpuHostLocation_Id(void* tpu_host_location);
|
||||
int TpuHostLocation_Id(SE_TpuTopology_Host* tpu_host_location);
|
||||
|
||||
// C API for XLA::Compiler interface
|
||||
|
||||
|
@ -18,12 +18,15 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_topology.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
typedef void* TpuTopologyPtr;
|
||||
// TODO(skyewm): get rid of TpuTopologyPtr and either use SE_TpuTopology* or
|
||||
// return a TpuTopologyExternal.
|
||||
typedef SE_TpuTopology* TpuTopologyPtr;
|
||||
|
||||
class TpuPlatformInterface : public stream_executor::Platform {
|
||||
public:
|
||||
|
@ -79,12 +79,12 @@ std::vector<TpuCoreLocationExternal> TpuTopologyExternal::cores(
|
||||
TpuCoreTypeEnum core_type) const {
|
||||
int num_cores =
|
||||
tpu::ExecutorApiFn()->TpuTopology_NumCoresFn(topology_, core_type);
|
||||
std::vector<void*> core_ptrs(num_cores);
|
||||
std::vector<SE_TpuTopology_Core*> core_ptrs(num_cores);
|
||||
tpu::ExecutorApiFn()->TpuTopology_CoresFn(topology_, core_type,
|
||||
core_ptrs.data());
|
||||
std::vector<TpuCoreLocationExternal> result;
|
||||
result.reserve(num_cores);
|
||||
for (void* ptr : core_ptrs) {
|
||||
for (SE_TpuTopology_Core* ptr : core_ptrs) {
|
||||
result.emplace_back(ptr);
|
||||
}
|
||||
return result;
|
||||
|
@ -33,27 +33,27 @@ struct TpuDimensionsExternal {
|
||||
class TpuCoreLocationExternal {
|
||||
public:
|
||||
TpuCoreLocationExternal() : core_location_(nullptr) {}
|
||||
explicit TpuCoreLocationExternal(void* core_location)
|
||||
explicit TpuCoreLocationExternal(SE_TpuTopology_Core* core_location)
|
||||
: core_location_(core_location) {}
|
||||
TpuDimensionsExternal chip_coordinates() const;
|
||||
TpuDimensionsExternal host_coordinates() const;
|
||||
int32 index() const;
|
||||
int32 Id() const;
|
||||
|
||||
void* impl() const { return core_location_; }
|
||||
SE_TpuTopology_Core* impl() const { return core_location_; }
|
||||
|
||||
private:
|
||||
void* core_location_;
|
||||
SE_TpuTopology_Core* core_location_;
|
||||
};
|
||||
|
||||
class TpuHostLocationExternal {
|
||||
public:
|
||||
explicit TpuHostLocationExternal(void* host_location)
|
||||
explicit TpuHostLocationExternal(SE_TpuTopology_Host* host_location)
|
||||
: host_location_(host_location) {}
|
||||
int32 Id() const;
|
||||
|
||||
private:
|
||||
void* host_location_;
|
||||
SE_TpuTopology_Host* host_location_;
|
||||
};
|
||||
|
||||
struct TpuTopologyChipBoundsExternal {
|
||||
@ -64,7 +64,8 @@ struct TpuTopologyChipBoundsExternal {
|
||||
|
||||
class TpuTopologyExternal {
|
||||
public:
|
||||
explicit TpuTopologyExternal(void* topology) : topology_(topology) {}
|
||||
explicit TpuTopologyExternal(SE_TpuTopology* topology)
|
||||
: topology_(topology) {}
|
||||
int32 LogicalDevicesPerHost(TpuCoreTypeEnum core_type) const;
|
||||
int32 LogicalDevicesPerChip(TpuCoreTypeEnum core_type) const;
|
||||
TpuTopologyChipBoundsExternal chip_bounds() const;
|
||||
@ -75,7 +76,7 @@ class TpuTopologyExternal {
|
||||
int IdForHost(TpuDimensionsExternal host) const;
|
||||
|
||||
private:
|
||||
void* topology_;
|
||||
SE_TpuTopology* topology_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
|
Loading…
Reference in New Issue
Block a user