Add shape-based allocation and linearizer to external TPU driver
PiperOrigin-RevId: 288938928 Change-Id: I27b6b88d01880db216810f826c69ae40cb7cbaaf
This commit is contained in:
parent
62bc37e41e
commit
e98a887ebe
@ -74,6 +74,11 @@ typedef struct CompiledProgramShape {
|
||||
int32_t size;
|
||||
} CompiledProgramShape;
|
||||
|
||||
// A serialized shape handed by value across the driver C API.
//
// The C++ wrapper in this change fills it by serializing an xla::ShapeProto
// into a malloc()'d buffer and releases that buffer with free() once the
// driver call has returned.
typedef struct TpuAllocationShape {
  void* bytes;   // Buffer holding the serialized shape; may be NULL.
  int32_t size;  // Number of valid bytes in `bytes`.
} TpuAllocationShape;
|
||||
|
||||
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
|
||||
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
|
||||
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
||||
@ -81,6 +86,17 @@ typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
||||
// TODO(frankchn): Make this not a hard-coded constant.
|
||||
const int32_t MemoryRegion_HBM = 1;
|
||||
|
||||
// Function type of the driver entry point that reports a byte count for
// `shape` as a signed 64-bit value. `shape` is passed by value; the C++
// wrapper in this change frees `shape.bytes` itself after the call returns.
// NOTE(review): presumably the result is the size of the linearized (on-device
// layout) buffer for the shape -- confirm against the driver implementation.
typedef int64_t(PrototypeTpuDriver_ComputeLinearizedBytesFromShape)(
    struct TpuDriver* driver, const struct TpuAllocationShape shape);
|
||||
|
||||
// Function type of the driver entry point that linearizes data described by
// `shape` from `src` into `dst`. The returned TpuStatus is read (code, msg) by
// the C++ wrapper in this change and then released with TpuDriver_FreeStatus.
// NOTE(review): required sizes of `dst`/`src` are not visible here --
// presumably `dst` must hold ComputeLinearizedBytesFromShape(shape) bytes.
typedef struct TpuStatus*(PrototypeTpuDriver_LinearizeShape)(
    struct TpuDriver* driver, void* dst, const void* src,
    const struct TpuAllocationShape shape);
|
||||
|
||||
// Function type of the driver entry point that performs the inverse of
// linearization: data described by `shape` is converted from `src` into `dst`.
// The returned TpuStatus is read (code, msg) by the C++ wrapper in this change
// and then released with TpuDriver_FreeStatus.
// NOTE(review): required sizes of `dst`/`src` are not visible here -- confirm.
typedef struct TpuStatus*(PrototypeTpuDriver_DelinearizeShape)(
    struct TpuDriver* driver, void* dst, const void* src,
    const struct TpuAllocationShape shape);
|
||||
|
||||
typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)(
|
||||
struct TpuDriver* driver, const struct HloProto hlo_proto,
|
||||
int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv);
|
||||
@ -118,6 +134,11 @@ typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)(
|
||||
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
|
||||
int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv);
|
||||
|
||||
// Function type of the driver entry point that allocates a buffer described by
// a serialized shape rather than by an explicit byte count (compare
// PrototypeTpuDriver_Allocate, which takes `num_bytes`).
// `eventv` is an array of `eventc` events; the C++ wrapper builds it from its
// `wait_for` list, so presumably the allocation waits on them -- confirm.
typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateShape)(
    struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
    const struct TpuAllocationShape shape, int32_t eventc,
    struct TpuEvent** eventv);
|
||||
|
||||
typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)(
|
||||
struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle,
|
||||
int32_t eventc, struct TpuEvent** eventv);
|
||||
@ -158,6 +179,12 @@ typedef const char*(PrototypeTpuDriver_Version)();
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ComputeLinearizedBytesFromShape
|
||||
TpuDriver_ComputeLinearizedBytesFromShape;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LinearizeShape
|
||||
TpuDriver_LinearizeShape;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_DelinearizeShape
|
||||
TpuDriver_DelinearizeShape;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram
|
||||
TpuDriver_CompileProgram;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText
|
||||
@ -171,6 +198,8 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple
|
||||
TpuDriver_AllocateTuple;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateShape
|
||||
TpuDriver_AllocateShape;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate;
|
||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice
|
||||
TpuDriver_TransferToDevice;
|
||||
@ -196,6 +225,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version;
|
||||
struct TpuDriverFn {
|
||||
PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT
|
||||
PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT
|
||||
PrototypeTpuDriver_ComputeLinearizedBytesFromShape*
|
||||
TpuDriver_ComputeLinearizedBytesFromShape; // NOLINT
|
||||
PrototypeTpuDriver_LinearizeShape* TpuDriver_LinearizeShape; // NOLINT
|
||||
PrototypeTpuDriver_DelinearizeShape* TpuDriver_DelinearizeShape; // NOLINT
|
||||
PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT
|
||||
PrototypeTpuDriver_CompileProgramFromText*
|
||||
TpuDriver_CompileProgramFromText; // NOLINT
|
||||
@ -204,6 +237,7 @@ struct TpuDriverFn {
|
||||
PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT
|
||||
PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT
|
||||
PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT
|
||||
PrototypeTpuDriver_AllocateShape* TpuDriver_AllocateShape; // NOLINT
|
||||
PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT
|
||||
PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT
|
||||
PrototypeTpuDriver_TransferFromDevice*
|
||||
|
@ -27,6 +27,19 @@
|
||||
namespace tpu_driver {
|
||||
namespace {
|
||||
|
||||
// Serializes `shape` into a freshly malloc()'d buffer and wraps it in the
// C-API TpuAllocationShape struct.
//
// The caller owns the returned `bytes` buffer and must release it with free().
// On any failure (serialized size does not fit in int32_t, allocation failure,
// or serialization failure) an error is logged and {bytes = nullptr, size = 0}
// is returned, so calling free() on the result is always safe.
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
  ::TpuAllocationShape result;
  result.bytes = nullptr;
  result.size = 0;

  // ByteSizeLong() returns size_t, but the C struct only carries an int32_t.
  // Reject sizes that would be truncated or wrap negative instead of silently
  // narrowing them.
  const size_t serialized_size = shape.ByteSizeLong();
  const int32_t size32 = static_cast<int32_t>(serialized_size);
  if (size32 < 0 || static_cast<size_t>(size32) != serialized_size) {
    LOG(ERROR) << "Serialized shape is too large for TpuAllocationShape.";
    return result;
  }

  void* buffer = malloc(serialized_size);
  // malloc(0) may legitimately return nullptr, so only treat a null result as
  // an allocation failure when bytes were actually requested.
  if (buffer == nullptr && serialized_size > 0) {
    LOG(ERROR) << "Unable to allocate buffer for serialized shape.";
    return result;
  }

  if (!shape.SerializeToArray(buffer, size32)) {
    LOG(ERROR) << "Unable to serialize shape to array.";
    free(buffer);
    return result;
  }

  result.bytes = buffer;
  result.size = size32;
  return result;
}
|
||||
|
||||
class ExternalTpuDriver;
|
||||
|
||||
class ExternalEvent : public Event {
|
||||
@ -161,6 +174,51 @@ class ExternalLoadedProgramHandle : public LoadedProgramHandle {
|
||||
friend ExternalTpuDriver;
|
||||
};
|
||||
|
||||
class ExternalTpuLinearizer : public TpuLinearizer {
|
||||
public:
|
||||
explicit ExternalTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
|
||||
: driver_(driver), driver_fn_(driver_fn) {}
|
||||
|
||||
int64_t ComputeLinearizedBytesFromShape(
|
||||
const xla::ShapeProto& shape) override {
|
||||
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||
uint64_t size =
|
||||
driver_fn_->TpuDriver_ComputeLinearizedBytesFromShape(driver_, shape_);
|
||||
free(shape_.bytes);
|
||||
return size;
|
||||
}
|
||||
|
||||
xla::Status LinearizeShape(void* dst, const void* src,
|
||||
const xla::ShapeProto& shape) override {
|
||||
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||
|
||||
auto tpu_status =
|
||||
driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
|
||||
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||
absl::StrFormat("%s", tpu_status->msg));
|
||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||
free(shape_.bytes);
|
||||
return status;
|
||||
}
|
||||
|
||||
xla::Status DelinearizeShape(void* dst, const void* src,
|
||||
const xla::ShapeProto& shape) override {
|
||||
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||
|
||||
auto tpu_status =
|
||||
driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
|
||||
auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
|
||||
absl::StrFormat("%s", tpu_status->msg));
|
||||
driver_fn_->TpuDriver_FreeStatus(tpu_status);
|
||||
free(shape_.bytes);
|
||||
return status;
|
||||
}
|
||||
|
||||
private:
|
||||
::TpuDriver* driver_;
|
||||
::TpuDriverFn* driver_fn_;
|
||||
};
|
||||
|
||||
class ExternalTpuDriver : public TpuDriver {
|
||||
public:
|
||||
explicit ExternalTpuDriver(const std::string& so_path) {
|
||||
@ -201,8 +259,17 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
std::unique_ptr<BufferHandle> Allocate(
|
||||
int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
|
||||
absl::Span<Event* const> wait_for) override {
|
||||
LOG(FATAL) << "Unimplemented.";
|
||||
return nullptr;
|
||||
auto tpu_events = MakeEventArray(wait_for);
|
||||
|
||||
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
|
||||
auto bh = absl::make_unique<ExternalBufferHandle>(
|
||||
&driver_fn_,
|
||||
driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_,
|
||||
wait_for.size(), tpu_events));
|
||||
|
||||
free(shape_.bytes);
|
||||
delete[] tpu_events;
|
||||
return bh;
|
||||
}
|
||||
|
||||
std::unique_ptr<BufferHandle> AllocateTuple(
|
||||
@ -366,7 +433,9 @@ class ExternalTpuDriver : public TpuDriver {
|
||||
return event;
|
||||
}
|
||||
|
||||
std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
|
||||
std::unique_ptr<TpuLinearizer> GetLinearizer() override {
|
||||
return std::make_unique<ExternalTpuLinearizer>(driver_, &driver_fn_);
|
||||
}
|
||||
|
||||
private:
|
||||
::TpuDriverFn driver_fn_;
|
||||
|
@ -33,7 +33,7 @@ DriverRegistryMap* GetDriverRegistryMap() {
|
||||
return driver_registry;
|
||||
}
|
||||
|
||||
uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
||||
int64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
||||
switch (primitive_type) {
|
||||
case xla::PrimitiveType::PRED:
|
||||
return sizeof(int8_t);
|
||||
@ -96,12 +96,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
||||
config.worker());
|
||||
}
|
||||
|
||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||
if (shape.tuple_shapes_size() > 0) {
|
||||
LOG(FATAL) << "Tuples are not supported at the moment.";
|
||||
}
|
||||
|
||||
uint64_t num_elems = 1;
|
||||
int64_t num_elems = 1;
|
||||
for (auto dim : shape.dimensions()) {
|
||||
num_elems *= dim;
|
||||
}
|
||||
|
@ -42,7 +42,7 @@
|
||||
|
||||
namespace tpu_driver {
|
||||
|
||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
||||
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
||||
|
||||
// Represents the deferred completion of a scheduled operation.
|
||||
//
|
||||
@ -120,10 +120,10 @@ class TpuLinearizer {
|
||||
public:
|
||||
virtual ~TpuLinearizer() {}
|
||||
|
||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||
return ::tpu_driver::ComputeBytesFromShape(shape);
|
||||
}
|
||||
virtual uint64_t ComputeLinearizedBytesFromShape(
|
||||
virtual int64_t ComputeLinearizedBytesFromShape(
|
||||
const xla::ShapeProto& shape) = 0;
|
||||
|
||||
virtual xla::Status LinearizeShape(void* dst, const void* src,
|
||||
|
Loading…
x
Reference in New Issue
Block a user