diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h index 21107113f67..d282724eda3 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h @@ -74,6 +74,11 @@ typedef struct CompiledProgramShape { int32_t size; } CompiledProgramShape; +typedef struct TpuAllocationShape { + void* bytes; + int32_t size; +} TpuAllocationShape; + typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn); typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker); typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver); @@ -81,6 +86,17 @@ typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver); // TODO(frankchn): Make this not a hard-coded constant. const int32_t MemoryRegion_HBM = 1; +typedef int64_t(PrototypeTpuDriver_ComputeLinearizedBytesFromShape)( + struct TpuDriver* driver, const struct TpuAllocationShape shape); + +typedef struct TpuStatus*(PrototypeTpuDriver_LinearizeShape)( + struct TpuDriver* driver, void* dst, const void* src, + const struct TpuAllocationShape shape); + +typedef struct TpuStatus*(PrototypeTpuDriver_DelinearizeShape)( + struct TpuDriver* driver, void* dst, const void* src, + const struct TpuAllocationShape shape); + typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)( struct TpuDriver* driver, const struct HloProto hlo_proto, int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv); @@ -118,6 +134,11 @@ typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)( struct TpuDriver* driver, int32_t core_id, int32_t memory_region, int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv); +typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateShape)( + struct TpuDriver* driver, int32_t core_id, int32_t memory_region, + const struct TpuAllocationShape shape, int32_t eventc, + struct TpuEvent** eventv); + typedef struct 
TpuEvent*(PrototypeTpuDriver_Deallocate)( struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle, int32_t eventc, struct TpuEvent** eventv); @@ -158,6 +179,12 @@ typedef const char*(PrototypeTpuDriver_Version)(); TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ComputeLinearizedBytesFromShape + TpuDriver_ComputeLinearizedBytesFromShape; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LinearizeShape + TpuDriver_LinearizeShape; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_DelinearizeShape + TpuDriver_DelinearizeShape; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram TpuDriver_CompileProgram; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgramFromText @@ -171,6 +198,8 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple TpuDriver_AllocateTuple; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate; +TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateShape + TpuDriver_AllocateShape; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate; TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice TpuDriver_TransferToDevice; @@ -196,6 +225,10 @@ TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version; struct TpuDriverFn { PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT + PrototypeTpuDriver_ComputeLinearizedBytesFromShape* + TpuDriver_ComputeLinearizedBytesFromShape; // NOLINT + PrototypeTpuDriver_LinearizeShape* TpuDriver_LinearizeShape; // NOLINT + PrototypeTpuDriver_DelinearizeShape* TpuDriver_DelinearizeShape; // NOLINT PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT 
PrototypeTpuDriver_CompileProgramFromText* TpuDriver_CompileProgramFromText; // NOLINT @@ -204,6 +237,7 @@ struct TpuDriverFn { PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT + PrototypeTpuDriver_AllocateShape* TpuDriver_AllocateShape; // NOLINT PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT PrototypeTpuDriver_TransferFromDevice* diff --git a/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc index f513941a2b3..6744664c621 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc @@ -27,6 +27,19 @@ namespace tpu_driver { namespace { +::TpuAllocationShape GetTpuAllocationShape(const xla::ShapeProto& shape) { + ::TpuAllocationShape shape_; + shape_.size = shape.ByteSizeLong(); + shape_.bytes = malloc(shape_.size); + if (!shape.SerializeToArray(shape_.bytes, shape_.size)) { + LOG(ERROR) << "Unable to serialize shape to array."; + free(shape_.bytes); + shape_.size = 0; + shape_.bytes = nullptr; + } + return shape_; +} + class ExternalTpuDriver; class ExternalEvent : public Event { @@ -161,6 +174,51 @@ class ExternalLoadedProgramHandle : public LoadedProgramHandle { friend ExternalTpuDriver; }; +class ExternalTpuLinearizer : public TpuLinearizer { + public: + explicit ExternalTpuLinearizer(::TpuDriver* driver, ::TpuDriverFn* driver_fn) + : driver_(driver), driver_fn_(driver_fn) {} + + int64_t ComputeLinearizedBytesFromShape( + const xla::ShapeProto& shape) override { + ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape); + int64_t size = + driver_fn_->TpuDriver_ComputeLinearizedBytesFromShape(driver_, shape_); + free(shape_.bytes); + return size; + } + + 
xla::Status LinearizeShape(void* dst, const void* src, + const xla::ShapeProto& shape) override { + ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape); + + auto tpu_status = + driver_fn_->TpuDriver_LinearizeShape(driver_, dst, src, shape_); + auto status = xla::Status(tensorflow::error::Code(tpu_status->code), + absl::StrFormat("%s", tpu_status->msg)); + driver_fn_->TpuDriver_FreeStatus(tpu_status); + free(shape_.bytes); + return status; + } + + xla::Status DelinearizeShape(void* dst, const void* src, + const xla::ShapeProto& shape) override { + ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape); + + auto tpu_status = + driver_fn_->TpuDriver_DelinearizeShape(driver_, dst, src, shape_); + auto status = xla::Status(tensorflow::error::Code(tpu_status->code), + absl::StrFormat("%s", tpu_status->msg)); + driver_fn_->TpuDriver_FreeStatus(tpu_status); + free(shape_.bytes); + return status; + } + + private: + ::TpuDriver* driver_; + ::TpuDriverFn* driver_fn_; +}; + class ExternalTpuDriver : public TpuDriver { public: explicit ExternalTpuDriver(const std::string& so_path) { @@ -201,8 +259,17 @@ class ExternalTpuDriver : public TpuDriver { std::unique_ptr Allocate( int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape, absl::Span wait_for) override { - LOG(FATAL) << "Unimplemented."; - return nullptr; + auto tpu_events = MakeEventArray(wait_for); + + ::TpuAllocationShape shape_ = GetTpuAllocationShape(shape); + auto bh = absl::make_unique( + &driver_fn_, + driver_fn_.TpuDriver_AllocateShape(driver_, core_id, region, shape_, + wait_for.size(), tpu_events)); + + free(shape_.bytes); + delete[] tpu_events; + return bh; } std::unique_ptr AllocateTuple( @@ -366,7 +433,9 @@ class ExternalTpuDriver : public TpuDriver { return event; } - std::unique_ptr GetLinearizer() override { return nullptr; } + std::unique_ptr GetLinearizer() override { + return std::make_unique(driver_, &driver_fn_); + } private: ::TpuDriverFn driver_fn_; diff --git 
a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc index 1920cf75e26..ecf70b56c14 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc @@ -33,7 +33,7 @@ DriverRegistryMap* GetDriverRegistryMap() { return driver_registry; } -uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) { +int64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) { switch (primitive_type) { case xla::PrimitiveType::PRED: return sizeof(int8_t); @@ -96,12 +96,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) { config.worker()); } -uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { +int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { if (shape.tuple_shapes_size() > 0) { LOG(FATAL) << "Tuples are not supported at the moment."; } - uint64_t num_elems = 1; + int64_t num_elems = 1; for (auto dim : shape.dimensions()) { num_elems *= dim; } diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h index dc28ad1f0b4..9127f0342fa 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h @@ -42,7 +42,7 @@ namespace tpu_driver { -uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape); +int64_t ComputeBytesFromShape(const xla::ShapeProto& shape); // Represents the deferred completion of a scheduled operation. 
// @@ -120,10 +120,10 @@ class TpuLinearizer { public: virtual ~TpuLinearizer() {} - uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { + int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { return ::tpu_driver::ComputeBytesFromShape(shape); } - virtual uint64_t ComputeLinearizedBytesFromShape( + virtual int64_t ComputeLinearizedBytesFromShape( const xla::ShapeProto& shape) = 0; virtual xla::Status LinearizeShape(void* dst, const void* src,