Hook out compile, execute, load/unload APIs to external TPU driver

PiperOrigin-RevId: 287204382
Change-Id: Ia1ae8ec3179c09e393aec69041a41c241177efe7
This commit is contained in:
parent
400e246b7a
commit
5078cab51c
@ -68,6 +68,12 @@ typedef struct TpuStatus {
|
|||||||
char* msg;
|
char* msg;
|
||||||
} TpuStatus;
|
} TpuStatus;
|
||||||
|
|
||||||
|
typedef struct CompiledProgramShape {
|
||||||
|
struct TpuStatus* status;
|
||||||
|
void* bytes;
|
||||||
|
int32_t size;
|
||||||
|
} CompiledProgramShape;
|
||||||
|
|
||||||
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
|
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
|
||||||
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
|
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
|
||||||
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
|
||||||
@ -85,7 +91,7 @@ typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)(
|
|||||||
int32_t eventc, struct TpuEvent** eventv);
|
int32_t eventc, struct TpuEvent** eventv);
|
||||||
|
|
||||||
typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)(
|
typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)(
|
||||||
struct TpuDriver* driver, int32_t core_id,
|
struct TpuDriver* driver,
|
||||||
struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc,
|
struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc,
|
||||||
struct TpuEvent** eventv);
|
struct TpuEvent** eventv);
|
||||||
|
|
||||||
@ -121,6 +127,13 @@ typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)(
|
|||||||
struct TpuDriver* driver, struct TpuBufferHandle* src,
|
struct TpuDriver* driver, struct TpuBufferHandle* src,
|
||||||
struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv);
|
struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv);
|
||||||
|
|
||||||
|
typedef struct CompiledProgramShape*(
|
||||||
|
PrototypeTpuDriver_GetCompiledProgramShape)(
|
||||||
|
struct TpuCompiledProgramHandle* handle);
|
||||||
|
|
||||||
|
typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)(
|
||||||
|
struct CompiledProgramShape* shape);
|
||||||
|
|
||||||
typedef void(PrototypeTpuDriver_EventAddCallback)(
|
typedef void(PrototypeTpuDriver_EventAddCallback)(
|
||||||
struct TpuEvent* event,
|
struct TpuEvent* event,
|
||||||
void (*callback_fn)(struct TpuStatus*, void* additional_info),
|
void (*callback_fn)(struct TpuStatus*, void* additional_info),
|
||||||
@ -156,6 +169,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice
|
|||||||
TpuDriver_TransferFromDevice;
|
TpuDriver_TransferFromDevice;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice
|
||||||
TpuDriver_TransferFromDeviceToDevice;
|
TpuDriver_TransferFromDeviceToDevice;
|
||||||
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape
|
||||||
|
TpuDriver_GetCompiledProgramShape;
|
||||||
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape
|
||||||
|
TpuDriver_FreeCompiledProgramShape;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback
|
||||||
TpuDriver_EventAddCallback;
|
TpuDriver_EventAddCallback;
|
||||||
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait;
|
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait;
|
||||||
@ -182,6 +199,10 @@ struct TpuDriverFn {
|
|||||||
TpuDriver_TransferFromDevice; // NOLINT
|
TpuDriver_TransferFromDevice; // NOLINT
|
||||||
PrototypeTpuDriver_TransferFromDeviceToDevice*
|
PrototypeTpuDriver_TransferFromDeviceToDevice*
|
||||||
TpuDriver_TransferFromDeviceToDevice; // NOLINT
|
TpuDriver_TransferFromDeviceToDevice; // NOLINT
|
||||||
|
PrototypeTpuDriver_GetCompiledProgramShape*
|
||||||
|
TpuDriver_GetCompiledProgramShape; // NOLINT
|
||||||
|
PrototypeTpuDriver_FreeCompiledProgramShape*
|
||||||
|
TpuDriver_FreeCompiledProgramShape; // NOLINT
|
||||||
PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT
|
PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT
|
||||||
PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT
|
PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT
|
||||||
PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT
|
PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT
|
||||||
|
@ -109,10 +109,13 @@ class ExternalBufferHandle : public BufferHandle {
|
|||||||
|
|
||||||
class ExternalCompiledProgramHandle : public CompiledProgramHandle {
|
class ExternalCompiledProgramHandle : public CompiledProgramHandle {
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<Event> OnReady() override {
|
explicit ExternalCompiledProgramHandle(::TpuDriverFn* driver_fn,
|
||||||
LOG(FATAL) << "Unimplemented";
|
::TpuCompiledProgramHandle* handle)
|
||||||
return std::shared_ptr<Event>();
|
: handle_(handle),
|
||||||
}
|
driver_fn_(driver_fn),
|
||||||
|
event_(new ExternalEvent(driver_fn, handle->event)) {}
|
||||||
|
|
||||||
|
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||||
|
|
||||||
int64_t size_in_bytes() override {
|
int64_t size_in_bytes() override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
LOG(FATAL) << "Unimplemented.";
|
||||||
@ -120,22 +123,42 @@ class ExternalCompiledProgramHandle : public CompiledProgramHandle {
|
|||||||
}
|
}
|
||||||
|
|
||||||
xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
|
xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
struct CompiledProgramShape* shape =
|
||||||
return xla::Unimplemented("%s", "Unimplemented.");
|
driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
|
||||||
|
program_shape->ParseFromArray(shape->bytes, shape->size);
|
||||||
|
|
||||||
|
auto status = xla::Status(tensorflow::error::Code(shape->status->code),
|
||||||
|
absl::StrFormat("%s", shape->status->msg));
|
||||||
|
driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
|
||||||
|
|
||||||
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
::TpuCompiledProgramHandle* handle_;
|
||||||
|
::TpuDriverFn* driver_fn_;
|
||||||
|
std::shared_ptr<ExternalEvent> event_;
|
||||||
|
|
||||||
|
friend ExternalTpuDriver;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ExternalLoadedProgramHandle : public LoadedProgramHandle {
|
class ExternalLoadedProgramHandle : public LoadedProgramHandle {
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<Event> OnReady() override {
|
explicit ExternalLoadedProgramHandle(::TpuDriverFn* driver_fn,
|
||||||
LOG(FATAL) << "Unimplemented";
|
::TpuLoadedProgramHandle* handle)
|
||||||
return std::shared_ptr<Event>();
|
: handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
|
||||||
}
|
std::shared_ptr<Event> OnReady() override { return event_; }
|
||||||
|
|
||||||
int64_t size_in_bytes() override {
|
int64_t size_in_bytes() override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
LOG(FATAL) << "Unimplemented.";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
::TpuLoadedProgramHandle* handle_;
|
||||||
|
std::shared_ptr<ExternalEvent> event_;
|
||||||
|
|
||||||
|
friend ExternalTpuDriver;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ExternalTpuDriver : public TpuDriver {
|
class ExternalTpuDriver : public TpuDriver {
|
||||||
@ -246,28 +269,93 @@ class ExternalTpuDriver : public TpuDriver {
|
|||||||
std::unique_ptr<CompiledProgramHandle> CompileProgram(
|
std::unique_ptr<CompiledProgramHandle> CompileProgram(
|
||||||
const xla::HloProto& source, int32_t num_replicas,
|
const xla::HloProto& source, int32_t num_replicas,
|
||||||
absl::Span<Event* const> wait_for) override {
|
absl::Span<Event* const> wait_for) override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
auto tpu_events = MakeEventArray(wait_for);
|
||||||
return nullptr;
|
|
||||||
|
struct HloProto hlo;
|
||||||
|
hlo.size = source.ByteSizeLong();
|
||||||
|
hlo.bytes = malloc(hlo.size);
|
||||||
|
if (!source.SerializeToArray(hlo.bytes, hlo.size)) {
|
||||||
|
LOG(ERROR) << "Unable to serialize HLO to array.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto handle = absl::make_unique<ExternalCompiledProgramHandle>(
|
||||||
|
&driver_fn_,
|
||||||
|
driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
|
||||||
|
wait_for.size(), tpu_events));
|
||||||
|
|
||||||
|
free(hlo.bytes);
|
||||||
|
delete tpu_events;
|
||||||
|
return handle;
|
||||||
}
|
}
|
||||||
std::unique_ptr<LoadedProgramHandle> LoadProgram(
|
std::unique_ptr<LoadedProgramHandle> LoadProgram(
|
||||||
int32_t core_id, const CompiledProgramHandle* handle,
|
int32_t core_id, const CompiledProgramHandle* handle,
|
||||||
absl::Span<Event* const> wait_for) override {
|
absl::Span<Event* const> wait_for) override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
auto tpu_events = MakeEventArray(wait_for);
|
||||||
return nullptr;
|
|
||||||
|
auto loaded_handle = absl::make_unique<ExternalLoadedProgramHandle>(
|
||||||
|
&driver_fn_,
|
||||||
|
driver_fn_.TpuDriver_LoadProgram(
|
||||||
|
driver_, core_id,
|
||||||
|
static_cast<const ExternalCompiledProgramHandle*>(handle)->handle_,
|
||||||
|
wait_for.size(), tpu_events));
|
||||||
|
|
||||||
|
delete tpu_events;
|
||||||
|
return loaded_handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Event> UnloadProgram(
|
std::shared_ptr<Event> UnloadProgram(
|
||||||
std::unique_ptr<LoadedProgramHandle> handle,
|
std::unique_ptr<LoadedProgramHandle> handle,
|
||||||
absl::Span<Event* const> wait_for) override {
|
absl::Span<Event* const> wait_for) override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
auto tpu_events = MakeEventArray(wait_for);
|
||||||
return nullptr;
|
auto event = std::make_shared<ExternalEvent>(
|
||||||
|
&driver_fn_,
|
||||||
|
driver_fn_.TpuDriver_UnloadProgram(
|
||||||
|
driver_,
|
||||||
|
static_cast<ExternalLoadedProgramHandle*>(handle.get())->handle_,
|
||||||
|
wait_for.size(), tpu_events));
|
||||||
|
delete tpu_events;
|
||||||
|
return event;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Event> ExecuteProgram(
|
std::shared_ptr<Event> ExecuteProgram(
|
||||||
LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
|
LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
|
||||||
absl::Span<BufferHandle* const> outputs,
|
absl::Span<BufferHandle* const> outputs,
|
||||||
const xla::DeviceAssignmentProto& device_assignment,
|
const xla::DeviceAssignmentProto& device_assignment,
|
||||||
absl::Span<Event* const> wait_for) override {
|
absl::Span<Event* const> wait_for) override {
|
||||||
LOG(FATAL) << "Unimplemented.";
|
auto tpu_events = MakeEventArray(wait_for);
|
||||||
return nullptr;
|
|
||||||
|
struct DeviceAssignmentProto da_proto;
|
||||||
|
da_proto.size = device_assignment.ByteSizeLong();
|
||||||
|
da_proto.bytes = malloc(da_proto.size);
|
||||||
|
if (!device_assignment.SerializeToArray(da_proto.bytes, da_proto.size)) {
|
||||||
|
LOG(ERROR) << "Unable to serialize device assignment to array.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<::TpuBufferHandle*> inputv;
|
||||||
|
inputv.reserve(inputs.size());
|
||||||
|
for (int i = 0; i < inputs.size(); i++) {
|
||||||
|
inputv.push_back(
|
||||||
|
static_cast<ExternalBufferHandle* const>(inputs[i])->handle_);
|
||||||
|
}
|
||||||
|
std::vector<::TpuBufferHandle*> outputv;
|
||||||
|
outputv.reserve(outputs.size());
|
||||||
|
for (int i = 0; i < outputs.size(); i++) {
|
||||||
|
outputv.push_back(
|
||||||
|
static_cast<ExternalBufferHandle* const>(outputs[i])->handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto event = std::make_shared<ExternalEvent>(
|
||||||
|
&driver_fn_,
|
||||||
|
driver_fn_.TpuDriver_ExecuteProgram(
|
||||||
|
driver_,
|
||||||
|
static_cast<ExternalLoadedProgramHandle*>(program)->handle_,
|
||||||
|
inputs.size(), inputv.data(), outputs.size(), outputv.data(),
|
||||||
|
da_proto, wait_for.size(), tpu_events));
|
||||||
|
|
||||||
|
free(da_proto.bytes);
|
||||||
|
return event;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
|
std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
|
||||||
|
Loading…
Reference in New Issue
Block a user