Added support of Windows to opencl_wrapper.

PiperOrigin-RevId: 316169152
Change-Id: Idd2a61a80c04fcee610cec0d4a48090ac9d0e967
This commit is contained in:
Raman Sarokin 2020-06-12 13:43:03 -07:00 committed by TensorFlower Gardener
parent bdbf103c4d
commit a4c8a190f8
2 changed files with 35 additions and 1 deletions
tensorflow/lite/delegates/gpu/cl

View File

@ -15,7 +15,15 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#if defined(_WIN32)
#define __WINDOWS__
#endif
#ifdef __WINDOWS__
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <string>
@ -33,12 +41,34 @@ namespace cl {
} else { \
function = reinterpret_cast<PFN_##function>(dlsym(libopencl, #function)); \
}
#elif defined(__WINDOWS__)
#define LoadFunction(function) \
function = \
reinterpret_cast<PFN_##function>(GetProcAddress(libopencl, #function));
#else
#define LoadFunction(function) \
function = reinterpret_cast<PFN_##function>(dlsym(libopencl, #function));
#endif
#ifdef __WINDOWS__
void LoadOpenCLFunctions(HMODULE libopencl);
#else
void LoadOpenCLFunctions(void* libopencl, bool is_pixel);
#endif
absl::Status LoadOpenCL() {
#ifdef __WINDOWS__
HMODULE libopencl = LoadLibraryA("OpenCL.dll");
if (libopencl) {
LoadOpenCLFunctions(libopencl);
return absl::OkStatus();
} else {
DWORD error_code = GetLastError();
return absl::UnknownError(absl::StrCat(
"Can not open OpenCL library on this device, error code - ",
error_code));
}
#else
void* libopencl = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL);
if (libopencl) {
LoadOpenCLFunctions(libopencl, false);
@ -60,8 +90,12 @@ absl::Status LoadOpenCL() {
#endif
return absl::UnknownError(
absl::StrCat("Can not open OpenCL library on this device - ", error));
#endif
}
#ifdef __WINDOWS__
void LoadOpenCLFunctions(HMODULE libopencl) {
#else
void LoadOpenCLFunctions(void* libopencl, bool is_pixel) {
#ifdef __ANDROID__
typedef void* (*loadOpenCLPointer_t)(const char* name);
@ -70,6 +104,7 @@ void LoadOpenCLFunctions(void* libopencl, bool is_pixel) {
loadOpenCLPointer = reinterpret_cast<loadOpenCLPointer_t>(
dlsym(libopencl, "loadOpenCLPointer"));
}
#endif
#endif
LoadFunction(clGetPlatformIDs);

View File

@ -28,7 +28,6 @@ namespace gpu {
namespace cl {
absl::Status LoadOpenCL();
void LoadOpenCLFunctions(void *libopencl, bool is_pixel);
typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)(
cl_uint /* num_entries */, cl_platform_id * /* platforms */,