Hook out compile, execute, load/unload APIs to external TPU driver

PiperOrigin-RevId: 287204382
Change-Id: Ia1ae8ec3179c09e393aec69041a41c241177efe7
This commit is contained in:
Frank Chen 2019-12-26 11:01:18 -08:00 committed by TensorFlower Gardener
parent 400e246b7a
commit 5078cab51c
2 changed files with 128 additions and 19 deletions

View File

@ -68,6 +68,12 @@ typedef struct TpuStatus {
char* msg;
} TpuStatus;
// Result of TpuDriver_GetCompiledProgramShape. The C++ wrapper visible in this
// change reads `status` to build its return status, parses `bytes`/`size` as an
// xla::ProgramShapeProto, and then hands the struct back to
// TpuDriver_FreeCompiledProgramShape.
// NOTE(review): presumably the driver owns this memory until that free call —
// confirm against the driver implementation.
typedef struct CompiledProgramShape {
// Outcome of the query; the caller reads its `code` and `msg` fields.
struct TpuStatus* status;
// Serialized payload handed to ParseFromArray by the caller.
void* bytes;
// Length of `bytes`, passed as the size argument to ParseFromArray.
int32_t size;
} CompiledProgramShape;
// Function-type prototypes for driver lifecycle entry points.
// NOTE(review): the behavior notes below are inferred from names and
// signatures only — confirm against the driver implementation.
// Receives the TpuDriverFn table; presumably fills in its function pointers.
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
// Takes a worker string and returns a driver handle.
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
// Takes a driver handle; presumably releases what Open returned.
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
@ -85,7 +91,7 @@ typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)(
int32_t eventc, struct TpuEvent** eventv);
typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)(
struct TpuDriver* driver, int32_t core_id,
struct TpuDriver* driver,
struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc,
struct TpuEvent** eventv);
@ -121,6 +127,13 @@ typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)(
struct TpuDriver* driver, struct TpuBufferHandle* src,
struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv);
// Prototype for querying the shape of a compiled program. Takes a compiled
// program handle and returns a CompiledProgramShape pointer.
typedef struct CompiledProgramShape*(
PrototypeTpuDriver_GetCompiledProgramShape)(
struct TpuCompiledProgramHandle* handle);
// Prototype for releasing a CompiledProgramShape; the C++ wrapper calls it on
// the pointer obtained from TpuDriver_GetCompiledProgramShape once it has
// finished reading it.
typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)(
struct CompiledProgramShape* shape);
typedef void(PrototypeTpuDriver_EventAddCallback)(
struct TpuEvent* event,
void (*callback_fn)(struct TpuStatus*, void* additional_info),
@ -156,6 +169,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice
TpuDriver_TransferFromDevice;
// Exported driver entry points, each declared with its Prototype* function
// type.
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice
TpuDriver_TransferFromDeviceToDevice;
// Compiled-program shape query and the matching release call.
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape
TpuDriver_GetCompiledProgramShape;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape
TpuDriver_FreeCompiledProgramShape;
// Event entry points.
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback
TpuDriver_EventAddCallback;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait;
@ -182,6 +199,10 @@ struct TpuDriverFn {
TpuDriver_TransferFromDevice; // NOLINT
// Function-pointer entries of the driver table; the C++ wrapper calls the
// driver through these (e.g. driver_fn_->TpuDriver_GetCompiledProgramShape).
PrototypeTpuDriver_TransferFromDeviceToDevice*
TpuDriver_TransferFromDeviceToDevice; // NOLINT
// Compiled-program shape query and the matching release call.
PrototypeTpuDriver_GetCompiledProgramShape*
TpuDriver_GetCompiledProgramShape; // NOLINT
PrototypeTpuDriver_FreeCompiledProgramShape*
TpuDriver_FreeCompiledProgramShape; // NOLINT
// Event entry points.
PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT
PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT
PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT

View File

@ -109,10 +109,13 @@ class ExternalBufferHandle : public BufferHandle {
class ExternalCompiledProgramHandle : public CompiledProgramHandle {
public:
// Wraps the raw compiled-program handle returned by the driver. `driver_fn` is
// kept so program_shape() can call back into the driver, and the handle's
// `event` is wrapped in an ExternalEvent that OnReady() hands out.
explicit ExternalCompiledProgramHandle(::TpuDriverFn* driver_fn,
                                       ::TpuCompiledProgramHandle* handle)
    : handle_(handle),
      driver_fn_(driver_fn),
      event_(std::make_shared<ExternalEvent>(driver_fn, handle->event)) {}

// Returns the event wrapping the handle's `event`. This is the only OnReady
// definition: the earlier "Unimplemented" stub was a duplicate member that
// aborted the process.
std::shared_ptr<Event> OnReady() override { return event_; }
int64_t size_in_bytes() override {
LOG(FATAL) << "Unimplemented.";
@ -120,22 +123,42 @@ class ExternalCompiledProgramHandle : public CompiledProgramHandle {
}
// Fetches the compiled program's shape from the driver and parses it into
// `program_shape`.
//
// Returns the status reported by the driver. If the driver reports success but
// the returned bytes do not parse as an xla::ProgramShapeProto, returns an
// INTERNAL error rather than reporting OK for an unusable proto.
xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
  struct CompiledProgramShape* shape =
      driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);

  // Copy everything needed out of `shape` before handing it back to the
  // driver.
  auto status = xla::Status(tensorflow::error::Code(shape->status->code),
                            absl::StrFormat("%s", shape->status->msg));
  const bool parsed = program_shape->ParseFromArray(shape->bytes, shape->size);
  driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);

  if (status.ok() && !parsed) {
    return xla::Status(tensorflow::error::INTERNAL,
                       "Unable to parse compiled program shape.");
  }
  return status;
}
private:
// Raw compiled-program handle; passed back to the driver when querying the
// program shape and read by ExternalTpuDriver when loading the program.
::TpuCompiledProgramHandle* handle_;
// Driver function table used to call into the driver.
::TpuDriverFn* driver_fn_;
// Event wrapping the handle's `event`; returned by OnReady().
std::shared_ptr<ExternalEvent> event_;
// ExternalTpuDriver reads handle_ directly.
friend ExternalTpuDriver;
};
// LoadedProgramHandle backed by a raw ::TpuLoadedProgramHandle obtained through
// the driver function table.
class ExternalLoadedProgramHandle : public LoadedProgramHandle {
 public:
  // Wraps `handle` and exposes its `event` as an ExternalEvent.
  explicit ExternalLoadedProgramHandle(::TpuDriverFn* driver_fn,
                                       ::TpuLoadedProgramHandle* handle)
      : handle_(handle),
        event_(std::make_shared<ExternalEvent>(driver_fn, handle->event)) {}

  // Returns the event wrapping the handle's `event`. This is the only OnReady
  // definition: the earlier "Unimplemented" stub was a duplicate member that
  // aborted the process.
  std::shared_ptr<Event> OnReady() override { return event_; }

  // Not implemented: aborts via LOG(FATAL).
  int64_t size_in_bytes() override {
    LOG(FATAL) << "Unimplemented.";
    return 0;
  }

 private:
  // Raw handle passed to the driver's execute/unload calls.
  ::TpuLoadedProgramHandle* handle_;
  // Event returned by OnReady().
  std::shared_ptr<ExternalEvent> event_;

  // ExternalTpuDriver reads handle_ directly.
  friend ExternalTpuDriver;
};
class ExternalTpuDriver : public TpuDriver {
@ -246,28 +269,93 @@ class ExternalTpuDriver : public TpuDriver {
// Serializes `source` and asks the driver to compile it for `num_replicas`
// replicas after the events in `wait_for`.
//
// Returns a handle wrapping the driver's compiled-program handle, or nullptr
// if the HLO proto cannot be serialized.
std::unique_ptr<CompiledProgramHandle> CompileProgram(
    const xla::HloProto& source, int32_t num_replicas,
    absl::Span<Event* const> wait_for) override {
  // Serialize before building the event array so the failure path has only
  // one buffer to release.
  struct HloProto hlo;
  hlo.size = source.ByteSizeLong();
  hlo.bytes = malloc(hlo.size);
  if (!source.SerializeToArray(hlo.bytes, hlo.size)) {
    LOG(ERROR) << "Unable to serialize HLO to array.";
    free(hlo.bytes);  // Previously leaked on this early return.
    return nullptr;
  }

  auto tpu_events = MakeEventArray(wait_for);
  auto handle = absl::make_unique<ExternalCompiledProgramHandle>(
      &driver_fn_,
      driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
                                          wait_for.size(), tpu_events));

  // The serialized buffer and event array are released as soon as the call
  // returns.
  free(hlo.bytes);
  // NOTE(review): MakeEventArray is not visible here; if it allocates with
  // new[], this must be delete[] — confirm.
  delete tpu_events;
  return handle;
}
// Asks the driver to load the compiled program behind `handle` onto `core_id`
// after the events in `wait_for`, and wraps the resulting loaded-program
// handle.
//
// `handle` is downcast to ExternalCompiledProgramHandle without a runtime
// check.
std::unique_ptr<LoadedProgramHandle> LoadProgram(
    int32_t core_id, const CompiledProgramHandle* handle,
    absl::Span<Event* const> wait_for) override {
  auto tpu_events = MakeEventArray(wait_for);

  auto loaded_handle = absl::make_unique<ExternalLoadedProgramHandle>(
      &driver_fn_,
      driver_fn_.TpuDriver_LoadProgram(
          driver_, core_id,
          static_cast<const ExternalCompiledProgramHandle*>(handle)->handle_,
          wait_for.size(), tpu_events));

  // NOTE(review): MakeEventArray is not visible here; if it allocates with
  // new[], this must be delete[] — confirm.
  delete tpu_events;
  return loaded_handle;
}
// Asks the driver to unload the program behind `handle` after the events in
// `wait_for`, and returns an event wrapping the driver's unload event.
//
// `handle` is downcast to ExternalLoadedProgramHandle without a runtime check.
// The wrapper object is destroyed when this function returns; only its raw
// driver handle is passed on.
std::shared_ptr<Event> UnloadProgram(
    std::unique_ptr<LoadedProgramHandle> handle,
    absl::Span<Event* const> wait_for) override {
  auto tpu_events = MakeEventArray(wait_for);

  auto event = std::make_shared<ExternalEvent>(
      &driver_fn_,
      driver_fn_.TpuDriver_UnloadProgram(
          driver_,
          static_cast<ExternalLoadedProgramHandle*>(handle.get())->handle_,
          wait_for.size(), tpu_events));

  // NOTE(review): MakeEventArray is not visible here; if it allocates with
  // new[], this must be delete[] — confirm.
  delete tpu_events;
  return event;
}
// Asks the driver to run the loaded `program` with the given input and output
// buffers and serialized `device_assignment`, after the events in `wait_for`.
//
// Returns an event wrapping the driver's execute event, or nullptr if the
// device assignment cannot be serialized. `program` and every buffer handle
// are downcast to their External* types without a runtime check.
std::shared_ptr<Event> ExecuteProgram(
    LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
    absl::Span<BufferHandle* const> outputs,
    const xla::DeviceAssignmentProto& device_assignment,
    absl::Span<Event* const> wait_for) override {
  // Serialize before building the event array so the failure path has only
  // one buffer to release.
  struct DeviceAssignmentProto da_proto;
  da_proto.size = device_assignment.ByteSizeLong();
  da_proto.bytes = malloc(da_proto.size);
  if (!device_assignment.SerializeToArray(da_proto.bytes, da_proto.size)) {
    LOG(ERROR) << "Unable to serialize device assignment to array.";
    free(da_proto.bytes);  // Previously leaked on this early return.
    return nullptr;
  }

  // Unwrap the buffer handles into the raw arrays the C API expects.
  std::vector<::TpuBufferHandle*> inputv;
  inputv.reserve(inputs.size());
  for (BufferHandle* const input : inputs) {
    inputv.push_back(static_cast<ExternalBufferHandle* const>(input)->handle_);
  }
  std::vector<::TpuBufferHandle*> outputv;
  outputv.reserve(outputs.size());
  for (BufferHandle* const output : outputs) {
    outputv.push_back(
        static_cast<ExternalBufferHandle* const>(output)->handle_);
  }

  auto tpu_events = MakeEventArray(wait_for);
  auto event = std::make_shared<ExternalEvent>(
      &driver_fn_,
      driver_fn_.TpuDriver_ExecuteProgram(
          driver_,
          static_cast<ExternalLoadedProgramHandle*>(program)->handle_,
          inputs.size(), inputv.data(), outputs.size(), outputv.data(),
          da_proto, wait_for.size(), tpu_events));

  free(da_proto.bytes);
  // The event array was previously never released here, unlike in the sibling
  // Compile/Load/Unload methods.
  // NOTE(review): MakeEventArray is not visible here; if it allocates with
  // new[], this must be delete[] — confirm.
  delete tpu_events;
  return event;
}
std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }