Add shape-based allocation and linearizer to external TPU driver

PiperOrigin-RevId: 288938928
Change-Id: I27b6b88d01880db216810f826c69ae40cb7cbaaf
This commit is contained in:
Frank Chen 2020-01-09 11:38:19 -08:00 committed by TensorFlower Gardener
parent 62bc37e41e
commit e98a887ebe
4 changed files with 112 additions and 9 deletions

View File

@ -74,6 +74,11 @@ typedef struct CompiledProgramShape {
int32_t size;
} CompiledProgramShape;
// Serialized shape passed by value across the driver C API.
// The client code in this change fills it from an xla::ShapeProto: `bytes`
// points at the serialized proto and `size` is its length in bytes. That
// client mallocs the buffer before each driver call and frees it afterwards.
typedef struct TpuAllocationShape {
void* bytes;  // Start of the serialized shape buffer (may be null on failure).
int32_t size;  // Number of valid bytes at `bytes`.
} TpuAllocationShape;
// Function types (not pointers) for the driver lifecycle entry points.
// Initialize receives the TpuDriverFn function table — presumably to populate
// it; confirm against the driver implementation.
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
// Open takes a `worker` string and returns a TpuDriver handle.
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
// Close takes a TpuDriver handle previously obtained from Open.
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
@ -81,6 +86,17 @@ typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
// Identifier for the HBM memory region. Presumably the value to pass as the
// `memory_region` argument of the allocation entry points — confirm.
// TODO(frankchn): Make this not a hard-coded constant.
const int32_t MemoryRegion_HBM = 1;
// Function type for the driver entry point that reports, as an int64_t, the
// linearized byte count for `shape` (a serialized shape passed by value).
// NOTE(review): how errors are signalled in the return value is not specified
// here — confirm against the driver implementation.
typedef int64_t(PrototypeTpuDriver_ComputeLinearizedBytesFromShape)(
struct TpuDriver* driver, const struct TpuAllocationShape shape);
// Function type for the driver entry point that linearizes the data at `src`
// into `dst` according to `shape`. Returns a TpuStatus*; the client wrapper in
// this change reads its `code`/`msg` and then releases it with
// TpuDriver_FreeStatus.
// NOTE(review): required sizes of `dst` and `src` are not stated here — confirm.
typedef struct TpuStatus*(PrototypeTpuDriver_LinearizeShape)(
struct TpuDriver* driver, void* dst, const void* src,
const struct TpuAllocationShape shape);
// Function type for the driver entry point that delinearizes the data at
// `src` into `dst` according to `shape` (presumably the inverse of
// LinearizeShape — confirm). Returns a TpuStatus*; the client wrapper in this
// change releases it with TpuDriver_FreeStatus.
typedef struct TpuStatus*(PrototypeTpuDriver_DelinearizeShape)(
struct TpuDriver* driver, void* dst, const void* src,
const struct TpuAllocationShape shape);
// Function type for the driver entry point that compiles `hlo_proto` for
// `num_replicas` replicas and returns a TpuCompiledProgramHandle*.
// `eventv` points at `eventc` events — presumably the events the operation
// must wait for; confirm against the driver implementation.
typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)(
struct TpuDriver* driver, const struct HloProto hlo_proto,
int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv);
@ -118,6 +134,11 @@ typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)(
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv);
// Function type for the driver entry point that allocates a buffer described
// by a serialized `shape` (rather than a raw byte count) on core `core_id` in
// `memory_region`, returning a TpuBufferHandle*.
// `eventv` points at `eventc` events; the client in this change passes the
// events the allocation should wait for.
typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateShape)(
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
const struct TpuAllocationShape shape, int32_t eventc,
struct TpuEvent** eventv);
// Function type for the driver entry point that releases `buffer_handle` and
// returns a TpuEvent*. `eventv` points at `eventc` events — presumably the
// events to wait for before deallocating; confirm.
typedef struct TpuEvent*(PrototypeTpuDriver_Deallocate)(
struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle,
int32_t eventc, struct TpuEvent** eventv);
@ -158,6 +179,12 @@ typedef const char*(PrototypeTpuDriver_Version)();
// Exported driver entry points. Each declaration names a function whose
// signature is given by the PrototypeTpuDriver_* function type of the same
// suffix.
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close;
// Shape-based linearizer entry points.
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ComputeLinearizedBytesFromShape
TpuDriver_ComputeLinearizedBytesFromShape;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LinearizeShape
TpuDriver_LinearizeShape;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_DelinearizeShape
TpuDriver_DelinearizeShape;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram
TpuDriver_CompileProgram;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText
@ -171,6 +198,8 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram
// Exported buffer-management and transfer entry points; each has the
// signature of the PrototypeTpuDriver_* function type of the same suffix.
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple
TpuDriver_AllocateTuple;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate;
// Shape-based counterpart of TpuDriver_Allocate (takes a TpuAllocationShape
// instead of a byte count).
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateShape
TpuDriver_AllocateShape;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice
TpuDriver_TransferToDevice;
@ -196,6 +225,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version;
struct TpuDriverFn {
PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT
PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT
PrototypeTpuDriver_ComputeLinearizedBytesFromShape*
TpuDriver_ComputeLinearizedBytesFromShape; // NOLINT
PrototypeTpuDriver_LinearizeShape* TpuDriver_LinearizeShape; // NOLINT
PrototypeTpuDriver_DelinearizeShape* TpuDriver_DelinearizeShape; // NOLINT
PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT
PrototypeTpuDriver_CompileProgramFromText*
TpuDriver_CompileProgramFromText; // NOLINT
@ -204,6 +237,7 @@ struct TpuDriverFn {
PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT
PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT
PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT
PrototypeTpuDriver_AllocateShape* TpuDriver_AllocateShape; // NOLINT
PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT
PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT
PrototypeTpuDriver_TransferFromDevice*

View File

@ -27,6 +27,19 @@
namespace tpu_driver {
namespace {
// Serializes `shape` into a heap buffer wrapped in a ::TpuAllocationShape so
// it can be passed by value across the driver C API.
//
// On success `bytes` points at a malloc'd buffer holding the serialized proto
// and `size` is its length in bytes; the caller must release `bytes` with
// free(). On failure (the proto does not fit the int32_t `size` field, the
// allocation fails, or serialization fails) an error is logged and
// {nullptr, 0} is returned, so an unconditional free(bytes) in the caller
// remains safe.
::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) {
  ::TpuAllocationShape shape_;
  shape_.bytes = nullptr;
  shape_.size = 0;

  // Validate the size before it is narrowed: `size` is an int32_t, and a
  // truncated or negative value must never reach malloc/SerializeToArray.
  const size_t serialized_size = shape.ByteSizeLong();
  if (serialized_size > static_cast<size_t>(INT32_MAX)) {
    LOG(ERROR) << "Serialized shape is too large: " << serialized_size
               << " bytes.";
    return shape_;
  }

  void* buffer = malloc(serialized_size);
  // malloc(0) may legitimately return nullptr, so only a failed non-empty
  // allocation is treated as an error.
  if (buffer == nullptr && serialized_size > 0) {
    LOG(ERROR) << "Unable to allocate " << serialized_size
               << " bytes for serialized shape.";
    return shape_;
  }

  if (!shape.SerializeToArray(buffer, static_cast<int>(serialized_size))) {
    LOG(ERROR) << "Unable to serialize shape to array.";
    free(buffer);
    return shape_;
  }

  shape_.bytes = buffer;
  shape_.size = static_cast<int32_t>(serialized_size);
  return shape_;
}
class ExternalTpuDriver;
class ExternalEvent : public Event {
@ -161,6 +174,51 @@ class ExternalLoadedProgramHandle : public LoadedProgramHandle {
friend ExternalTpuDriver;
};
// TpuLinearizer implementation that forwards every call to the externally
// loaded TPU driver through its C function table.
//
// Each call serializes the xla::ShapeProto into a temporary
// ::TpuAllocationShape, hands it to the driver, and frees the temporary buffer
// before returning. The `driver` and `driver_fn` pointers are stored but never
// released here, so they must stay valid for the lifetime of this object.
class ExternalTpuLinearizer : public TpuLinearizer {
 public:
  explicit ExternalTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn)
      : driver_(driver), driver_fn_(driver_fn) {}

  // Returns the driver-reported linearized byte count for `shape`.
  int64_t ComputeLinearizedBytesFromShape(
      const xla::ShapeProto& shape) override {
    ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
    // The driver returns int64_t and so does this method; keep the value
    // signed end-to-end instead of round-tripping through an unsigned type.
    const int64_t size =
        driver_fn_->TpuDriver_ComputeLinearizedBytesFromShape(driver_, shape_);
    free(shape_.bytes);
    return size;
  }

  // Asks the driver to linearize `src` into `dst` according to `shape` and
  // returns the driver's status converted to an xla::Status.
  xla::Status LinearizeShape(void* dst, const void* src,
                             const xla::ShapeProto& shape) override {
    ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
    ::TpuStatus* tpu_status =
        driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_);
    free(shape_.bytes);
    return ConsumeStatus(tpu_status);
  }

  // Asks the driver to delinearize `src` into `dst` according to `shape` and
  // returns the driver's status converted to an xla::Status.
  xla::Status DelinearizeShape(void* dst, const void* src,
                               const xla::ShapeProto& shape) override {
    ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
    ::TpuStatus* tpu_status =
        driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_);
    free(shape_.bytes);
    return ConsumeStatus(tpu_status);
  }

 private:
  // Copies the code and message out of `tpu_status` into an xla::Status, then
  // hands `tpu_status` back to the driver via TpuDriver_FreeStatus.
  xla::Status ConsumeStatus(::TpuStatus* tpu_status) {
    auto status = xla::Status(tensorflow::error::Code(tpu_status->code),
                              absl::StrFormat("%s", tpu_status->msg));
    driver_fn_->TpuDriver_FreeStatus(tpu_status);
    return status;
  }

  ::TpuDriver* driver_;
  ::TpuDriverFn* driver_fn_;
};
class ExternalTpuDriver : public TpuDriver {
public:
explicit ExternalTpuDriver(const std::string& so_path) {
@ -201,8 +259,17 @@ class ExternalTpuDriver : public TpuDriver {
std::unique_ptr<BufferHandle> Allocate(
int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
absl::Span<Event* const> wait_for) override {
LOG(FATAL) << "Unimplemented.";
return nullptr;
auto tpu_events = MakeEventArray(wait_for);
::TpuAllocationShape shape_ = GetTpuAllocationShape(shape);
auto bh = absl::make_unique<ExternalBufferHandle>(
&driver_fn_,
driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_,
wait_for.size(), tpu_events));
free(shape_.bytes);
delete[] tpu_events;
return bh;
}
std::unique_ptr<BufferHandle> AllocateTuple(
@ -366,7 +433,9 @@ class ExternalTpuDriver : public TpuDriver {
return event;
}
std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
std::unique_ptr<TpuLinearizer> GetLinearizer() override {
return std::make_unique<ExternalTpuLinearizer>(driver_, &driver_fn_);
}
private:
::TpuDriverFn driver_fn_;

View File

@ -33,7 +33,7 @@ DriverRegistryMap* GetDriverRegistryMap() {
return driver_registry;
}
uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
int64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
switch (primitive_type) {
case xla::PrimitiveType::PRED:
return sizeof(int8_t);
@ -96,12 +96,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
config.worker());
}
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
if (shape.tuple_shapes_size() > 0) {
LOG(FATAL) << "Tuples are not supported at the moment.";
}
uint64_t num_elems = 1;
int64_t num_elems = 1;
for (auto dim : shape.dimensions()) {
num_elems *= dim;
}

View File

@ -42,7 +42,7 @@
namespace tpu_driver {
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
// Represents the deferred completion of a scheduled operation.
//
@ -120,10 +120,10 @@ class TpuLinearizer {
public:
virtual ~TpuLinearizer() {}
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
return ::tpu_driver::ComputeBytesFromShape(shape);
}
virtual uint64_t ComputeLinearizedBytesFromShape(
virtual int64_t ComputeLinearizedBytesFromShape(
const xla::ShapeProto& shape) = 0;
virtual xla::Status LinearizeShape(void* dst, const void* src,