Add shape-based allocation and linearizer to external TPU driver
PiperOrigin-RevId: 288938928 Change-Id: I27b6b88d01880db216810f826c69ae40cb7cbaaf
This commit is contained in:
parent
62bc37e41e
commit
e98a887ebe
@ -74,6 +74,11 @@ typedef struct CompiledProgramShape {
|
|||||||
int32_t size;
|
int32_t size;
|
||||||
} CompiledProgramShape;
|
} CompiledProgramShape;
|
||||||
|
|
||||||
|
typedef struct TpuAllocationShape {
|
||||||
|
void* bytes;
|
||||||
|
int32_t size;
|
||||||
|
} TpuAllocationShape;
|
||||||
|
|
||||||
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
|
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
|
||||||
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
|
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
|
||||||
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
||||||
@ -81,6 +86,17 @@ typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
|||||||
// TODO(frankchn): Make this not a hard-coded constant.
|
// TODO(frankchn): Make this not a hard-coded constant.
|
||||||
const int32_t MemoryRegion_HBM = 1;
|
const int32_t MemoryRegion_HBM = 1;
|
||||||
|
|
||||||
|
typedef int64_t(PrototypeTpuDriver_ComputeLinearizedBytesFromShape)(
|
||||||
|
struct TpuDriver* driver, const struct TpuAllocationShape shape);
|
||||||
|
|
||||||
|
typedef struct TpuStatus*(PrototypeTpuDriver_LinearizeShape)(
|
||||||
|
struct TpuDriver* driver, void* dst, const void* src,
|
||||||
|
const struct TpuAllocationShape shape);
|
||||||
|
|
||||||
|
typedef struct TpuStatus*(PrototypeTpuDriver_DelinearizeShape)(
|
||||||
|
struct TpuDriver* driver, void* dst, const void* src,
|
||||||
|
const struct TpuAllocationShape shape);
|
||||||
|
|
||||||
typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)(
|
typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)(
|
||||||
struct TpuDriver* driver, const struct HloProto hlo_proto,
|
struct TpuDriver* driver, const struct HloProto hlo_proto,
|
||||||
int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv);
|
int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv);
|
||||||
@ -118,6 +134,11 @@ typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)(
|
|||||||
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
|
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
|
||||||
int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv);
|
int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv);
|
||||||
|
|
||||||
|
typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateShape)(
|
||||||
|
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
|
||||||
|
const struct TpuAllocationShape shape, int32_t eventc,
|
||||||
|
struct TpuEvent** eventv);
|
||||||
|
|
||||||
typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)(
|
typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)(
|
||||||
struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle,
|
struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle,
|
||||||
int32_t eventc, struct TpuEvent** eventv);
|
int32_t eventc, struct TpuEvent** eventv);
|
||||||
@ -158,6 +179,12 @@ typedef const char*(PrototypeTpuDriver_Version)();
|
|||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize;
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open;
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close;
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close;
|
||||||
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ComputeLinearizedBytesFromShape
|
||||||
|
TpuDriver_ComputeLinearizedBytesFromShape;
|
||||||
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LinearizeShape
|
||||||
|
TpuDriver_LinearizeShape;
|
||||||
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_DelinearizeShape
|
||||||
|
TpuDriver_DelinearizeShape;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram
|
||||||
TpuDriver_CompileProgram;
|
TpuDriver_CompileProgram;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText
|
||||||
@ -171,6 +198,8 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram
|
|||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple
|
||||||
TpuDriver_AllocateTuple;
|
TpuDriver_AllocateTuple;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate;
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate;
|
||||||
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateShape
|
||||||
|
TpuDriver_AllocateShape;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate;
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice
|
||||||
TpuDriver_TransferToDevice;
|
TpuDriver_TransferToDevice;
|
||||||
@ -196,6 +225,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version;
|
|||||||
struct TpuDriverFn {
|
struct TpuDriverFn {
|
||||||
PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT
|
PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT
|
||||||
PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT
|
PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT
|
||||||
|
PrototypeTpuDriver_ComputeLinearizedBytesFromShape*
|
||||||
|
TpuDriver_ComputeLinearizedBytesFromShape; // NOLINT
|
||||||
|
PrototypeTpuDriver_LinearizeShape* TpuDriver_LinearizeShape; // NOLINT
|
||||||
|
PrototypeTpuDriver_DelinearizeShape* TpuDriver_DelinearizeShape; // NOLINT
|
||||||
PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT
|
PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT
|
||||||
PrototypeTpuDriver_CompileProgramFromText*
|
PrototypeTpuDriver_CompileProgramFromText*
|
||||||
TpuDriver_CompileProgramFromText; // NOLINT
|
TpuDriver_CompileProgramFromText; // NOLINT
|
||||||
@ -204,6 +237,7 @@ struct TpuDriverFn {
|
|||||||
PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT
|
PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT
|
||||||
PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT
|
PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT
|
||||||
PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT
|
PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT
|
||||||
|
PrototypeTpuDriver_AllocateShape* TpuDriver_AllocateShape; // NOLINT
|
||||||
PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT
|
PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT
|
||||||
PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT
|
PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT
|
||||||
PrototypeTpuDriver_TransferFromDevice*
|
PrototypeTpuDriver_TransferFromDevice*
|
||||||
|
@ -27,6 +27,19 @@
|
|||||||
namespace tpu_driver {
|
namespace tpu_driver {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
|
||||||
|
::TpuAllocationShape shape_;
|
||||||
|
shape_.size = shape.ByteSizeLong();
|
||||||
|
shape_.bytes = malloc(shape_.size);
|
||||||
|
if (!shape.SerializeToArray(shape_.bytes, shape_.size)) {
|
||||||
|
LOG(ERROR) << "Unable to serialize shape to array.";
|
||||||
|
free(shape_.bytes);
|
||||||
|
shape_.size = 0;
|
||||||
|
shape_.bytes = nullptr;
|
||||||
|
}
|
||||||
|
return shape_;
|
||||||
|
}
|
||||||
|
|
||||||
class ExternalTpuDriver;
|
class ExternalTpuDriver;
|
||||||
|
|
||||||
class ExternalEvent : public Event {
|
class ExternalEvent : public Event {
|
||||||
@ -161,6 +174,51 @@ class ExternalLoadedProgramHandle : public LoadedProgramHandle {
|
|||||||
friend ExternalTpuDriver;
|
friend ExternalTpuDriver;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ExternalTpuLinearizer : public TpuLinearizer {
|
||||||
|
public:
|
||||||
|
explicit ExternalTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
|
||||||
|
: driver_(driver), driver_fn_(driver_fn) {}
|
||||||
|
|
||||||
|
int64_t ComputeLinearizedBytesFromShape(
|
||||||
|
const xla::ShapeProto& shape) override {
|
||||||
|
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||||
|
uint64_t size =
|
||||||
|
driver_fn_->TpuDriver_ComputeLinearizedBytesFromShape(driver_, shape_);
|
||||||
|
free(shape_.bytes);
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::Status LinearizeShape(void* dst, const void* src,
|
||||||
|
const xla::ShapeProto& shape) override {
|
||||||
|
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||||
|
|
||||||
|
auto tpu_status =
|
||||||
|
driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
|
||||||
|
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||||
|
absl::StrFormat("%s", tpu_status->msg));
|
||||||
|
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||||
|
free(shape_.bytes);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::Status DelinearizeShape(void* dst, const void* src,
|
||||||
|
const xla::ShapeProto& shape) override {
|
||||||
|
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||||
|
|
||||||
|
auto tpu_status =
|
||||||
|
driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
|
||||||
|
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||||
|
absl::StrFormat("%s", tpu_status->msg));
|
||||||
|
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||||
|
free(shape_.bytes);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
::TpuDriver* driver_;
|
||||||
|
::TpuDriverFn* driver_fn_;
|
||||||
|
};
|
||||||
|
|
||||||
class ExternalTpuDriver : public TpuDriver {
|
class ExternalTpuDriver : public TpuDriver {
|
||||||
public:
|
public:
|
||||||
explicit ExternalTpuDriver(const std::string& so_path) {
|
explicit ExternalTpuDriver(const std::string& so_path) {
|
||||||
@ -201,8 +259,17 @@ class ExternalTpuDriver : public TpuDriver {
|
|||||||
std::unique_ptr<BufferHandle> Allocate(
|
std::unique_ptr<BufferHandle> Allocate(
|
||||||
int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
|
int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
|
||||||
absl::Span<Event* const> wait_for) override {
|
absl::Span<Event* const> wait_for) override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
auto tpu_events = MakeEventArray(wait_for);
|
||||||
return nullptr;
|
|
||||||
|
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||||
|
auto bh = absl::make_unique<ExternalBufferHandle>(
|
||||||
|
&driver_fn_,
|
||||||
|
driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_,
|
||||||
|
wait_for.size(), tpu_events));
|
||||||
|
|
||||||
|
free(shape_.bytes);
|
||||||
|
delete[] tpu_events;
|
||||||
|
return bh;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<BufferHandle> AllocateTuple(
|
std::unique_ptr<BufferHandle> AllocateTuple(
|
||||||
@ -366,7 +433,9 @@ class ExternalTpuDriver : public TpuDriver {
|
|||||||
return event;
|
return event;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
|
std::unique_ptr<TpuLinearizer> GetLinearizer() override {
|
||||||
|
return std::make_unique<ExternalTpuLinearizer>(driver_, &driver_fn_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
::TpuDriverFn driver_fn_;
|
::TpuDriverFn driver_fn_;
|
||||||
|
@ -33,7 +33,7 @@ DriverRegistryMap* GetDriverRegistryMap() {
|
|||||||
return driver_registry;
|
return driver_registry;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
int64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
||||||
switch (primitive_type) {
|
switch (primitive_type) {
|
||||||
case xla::PrimitiveType::PRED:
|
case xla::PrimitiveType::PRED:
|
||||||
return sizeof(int8_t);
|
return sizeof(int8_t);
|
||||||
@ -96,12 +96,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
|||||||
config.worker());
|
config.worker());
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||||
if (shape.tuple_shapes_size() > 0) {
|
if (shape.tuple_shapes_size() > 0) {
|
||||||
LOG(FATAL) << "Tuples are not supported at the moment.";
|
LOG(FATAL) << "Tuples are not supported at the moment.";
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t num_elems = 1;
|
int64_t num_elems = 1;
|
||||||
for (auto dim : shape.dimensions()) {
|
for (auto dim : shape.dimensions()) {
|
||||||
num_elems *= dim;
|
num_elems *= dim;
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@
|
|||||||
|
|
||||||
namespace tpu_driver {
|
namespace tpu_driver {
|
||||||
|
|
||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
||||||
|
|
||||||
// Represents the deferred completion of a scheduled operation.
|
// Represents the deferred completion of a scheduled operation.
|
||||||
//
|
//
|
||||||
@ -120,10 +120,10 @@ class TpuLinearizer {
|
|||||||
public:
|
public:
|
||||||
virtual ~TpuLinearizer() {}
|
virtual ~TpuLinearizer() {}
|
||||||
|
|
||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||||
return ::tpu_driver::ComputeBytesFromShape(shape);
|
return ::tpu_driver::ComputeBytesFromShape(shape);
|
||||||
}
|
}
|
||||||
virtual uint64_t ComputeLinearizedBytesFromShape(
|
virtual int64_t ComputeLinearizedBytesFromShape(
|
||||||
const xla::ShapeProto& shape) = 0;
|
const xla::ShapeProto& shape) = 0;
|
||||||
|
|
||||||
virtual xla::Status LinearizeShape(void* dst, const void* src,
|
virtual xla::Status LinearizeShape(void* dst, const void* src,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user