From 5078cab51cea45364602f9d2a1d30057799af4b0 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 26 Dec 2019 11:01:18 -0800 Subject: [PATCH] Hook out compile, execute, load/unload APIs to external TPU driver PiperOrigin-RevId: 287204382 Change-Id: Ia1ae8ec3179c09e393aec69041a41c241177efe7 --- .../xla/python/tpu_driver/client/c_api.h | 23 +++- .../python/tpu_driver/external_tpu_driver.cc | 124 +++++++++++++++--- 2 files changed, 128 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h index 1558d9c5580..228128c62e1 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h @@ -68,6 +68,12 @@ typedef struct TpuStatus { char* msg; } TpuStatus; +typedef struct CompiledProgramShape { + struct TpuStatus* status; + void* bytes; + int32_t size; +} CompiledProgramShape; + typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn); typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker); typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver); @@ -85,7 +91,7 @@ typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)( int32_t eventc, struct TpuEvent** eventv); typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)( - struct TpuDriver* driver, int32_t core_id, + struct TpuDriver* driver, struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc, struct TpuEvent** eventv); @@ -121,6 +127,13 @@ typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)( struct TpuDriver* driver, struct TpuBufferHandle* src, struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv); +typedef struct CompiledProgramShape*( + PrototypeTpuDriver_GetCompiledProgramShape)( + struct TpuCompiledProgramHandle* handle); + +typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)( + struct CompiledProgramShape* shape); + typedef void(PrototypeTpuDriver_EventAddCallback)( struct TpuEvent* event, void (*callback_fn)(struct TpuStatus*, void* additional_info), @@ -156,6 +169,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice TpuDriver_TransferFromDevice; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice TpuDriver_TransferFromDeviceToDevice; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape + TpuDriver_GetCompiledProgramShape; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape + TpuDriver_FreeCompiledProgramShape; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback TpuDriver_EventAddCallback; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait; @@ -182,6 +199,10 @@ struct TpuDriverFn { TpuDriver_TransferFromDevice; // NOLINT PrototypeTpuDriver_TransferFromDeviceToDevice* TpuDriver_TransferFromDeviceToDevice; // NOLINT + PrototypeTpuDriver_GetCompiledProgramShape* + TpuDriver_GetCompiledProgramShape; // NOLINT + PrototypeTpuDriver_FreeCompiledProgramShape* + TpuDriver_FreeCompiledProgramShape; // NOLINT PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT diff --git a/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc index 84b25251074..8a8e868b2b8 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc @@ -109,10 +109,13 @@ class ExternalBufferHandle : public BufferHandle { class ExternalCompiledProgramHandle : public CompiledProgramHandle { public: - std::shared_ptr OnReady() override { - LOG(FATAL) << "Unimplemented"; - return std::shared_ptr(); - } + explicit ExternalCompiledProgramHandle(::TpuDriverFn* driver_fn, + ::TpuCompiledProgramHandle* handle) + : handle_(handle), + driver_fn_(driver_fn), + event_(new ExternalEvent(driver_fn, handle->event)) {} + + std::shared_ptr OnReady() override { return event_; } int64_t size_in_bytes() override { LOG(FATAL) << "Unimplemented."; @@ -120,22 +123,42 @@ class ExternalCompiledProgramHandle : public CompiledProgramHandle { } xla::Status program_shape(xla::ProgramShapeProto* program_shape) override { - LOG(FATAL) << "Unimplemented."; - return xla::Unimplemented("%s", "Unimplemented."); + struct CompiledProgramShape* shape = + driver_fn_->TpuDriver_GetCompiledProgramShape(handle_); + program_shape->ParseFromArray(shape->bytes, shape->size); + + auto status = xla::Status(tensorflow::error::Code(shape->status->code), + absl::StrFormat("%s", shape->status->msg)); + driver_fn_->TpuDriver_FreeCompiledProgramShape(shape); + + return status; } + + private: + ::TpuCompiledProgramHandle* handle_; + ::TpuDriverFn* driver_fn_; + std::shared_ptr event_; + + friend ExternalTpuDriver; }; class ExternalLoadedProgramHandle : public LoadedProgramHandle { public: - std::shared_ptr OnReady() override { - LOG(FATAL) << "Unimplemented"; - return std::shared_ptr(); - } + explicit ExternalLoadedProgramHandle(::TpuDriverFn* driver_fn, + ::TpuLoadedProgramHandle* handle) + : handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {} + std::shared_ptr OnReady() override { return event_; } int64_t size_in_bytes() override { LOG(FATAL) << "Unimplemented."; return 0; } + + private: + ::TpuLoadedProgramHandle* handle_; + std::shared_ptr event_; + + friend ExternalTpuDriver; }; class ExternalTpuDriver : public TpuDriver { @@ -246,28 +269,93 @@ class ExternalTpuDriver : public TpuDriver { std::unique_ptr CompileProgram( const xla::HloProto& source, int32_t num_replicas, absl::Span wait_for) override { - LOG(FATAL) << "Unimplemented."; - return nullptr; + auto tpu_events = MakeEventArray(wait_for); + + struct HloProto hlo; + hlo.size = source.ByteSizeLong(); + hlo.bytes = malloc(hlo.size); + if (!source.SerializeToArray(hlo.bytes, hlo.size)) { + LOG(ERROR) << "Unable to serialize HLO to array."; + return nullptr; + } + + auto handle = absl::make_unique( + &driver_fn_, + driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas, + wait_for.size(), tpu_events)); + + free(hlo.bytes); + delete tpu_events; + return handle; } std::unique_ptr LoadProgram( int32_t core_id, const CompiledProgramHandle* handle, absl::Span wait_for) override { - LOG(FATAL) << "Unimplemented."; - return nullptr; + auto tpu_events = MakeEventArray(wait_for); + + auto loaded_handle = absl::make_unique( + &driver_fn_, + driver_fn_.TpuDriver_LoadProgram( + driver_, core_id, + static_cast(handle)->handle_, + wait_for.size(), tpu_events)); + + delete tpu_events; + return loaded_handle; } + std::shared_ptr UnloadProgram( std::unique_ptr handle, absl::Span wait_for) override { - LOG(FATAL) << "Unimplemented."; - return nullptr; + auto tpu_events = MakeEventArray(wait_for); + auto event = std::make_shared( + &driver_fn_, + driver_fn_.TpuDriver_UnloadProgram( + driver_, + static_cast(handle.get())->handle_, + wait_for.size(), tpu_events)); + delete tpu_events; + return event; } + std::shared_ptr ExecuteProgram( LoadedProgramHandle* program, absl::Span inputs, absl::Span outputs, const xla::DeviceAssignmentProto& device_assignment, absl::Span wait_for) override { - LOG(FATAL) << "Unimplemented."; - return nullptr; + auto tpu_events = MakeEventArray(wait_for); + + struct DeviceAssignmentProto da_proto; + da_proto.size = device_assignment.ByteSizeLong(); + da_proto.bytes = malloc(da_proto.size); + if (!device_assignment.SerializeToArray(da_proto.bytes, da_proto.size)) { + LOG(ERROR) << "Unable to serialize device assignment to array."; + return nullptr; + } + + std::vector<::TpuBufferHandle*> inputv; + inputv.reserve(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + inputv.push_back( + static_cast(inputs[i])->handle_); + } + std::vector<::TpuBufferHandle*> outputv; + outputv.reserve(outputs.size()); + for (int i = 0; i < outputs.size(); i++) { + outputv.push_back( + static_cast(outputs[i])->handle_); + } + + auto event = std::make_shared( + &driver_fn_, + driver_fn_.TpuDriver_ExecuteProgram( + driver_, + static_cast(program)->handle_, + inputs.size(), inputv.data(), outputs.size(), outputv.data(), + da_proto, wait_for.size(), tpu_events)); + + free(da_proto.bytes); + return event; } std::unique_ptr GetLinearizer() override { return nullptr; }