Added support of int types to Texture2DDescriptor.

PiperOrigin-RevId: 333550534
Change-Id: Ic580aa6fd19ec9fa20563c2018f8833a03ac18bc
This commit is contained in:
Raman Sarokin 2020-09-24 10:56:23 -07:00 committed by TensorFlower Gardener
parent c31acd156c
commit ff9594c4f7
5 changed files with 81 additions and 19 deletions

View File

@ -204,8 +204,9 @@ absl::Status LinearStorage::CreateFromTensorLinearDescriptor(
return CreateCLBuffer(context->context(), depth_ * float4_size, read_only,
data_ptr, &memory_);
} else {
return CreateFloatRGBAImage2D(context->context(), depth_, 1,
desc.element_type, data_ptr, &memory_);
return CreateRGBAImage2D(context->context(), depth_, 1,
DataTypeToChannelType(desc.element_type), data_ptr,
&memory_);
}
}

View File

@ -24,10 +24,9 @@ namespace {
absl::Status CreateTexture2D(int width, int height, DataType type, void* data,
CLContext* context, Texture2D* result) {
cl_mem texture;
RETURN_IF_ERROR(CreateFloatRGBAImage2D(context->context(), width, height,
type, data, &texture));
cl_channel_type channel_type =
type == DataType::FLOAT32 ? CL_FLOAT : CL_HALF_FLOAT;
cl_channel_type channel_type = DataTypeToChannelType(type);
RETURN_IF_ERROR(CreateRGBAImage2D(context->context(), width, height,
channel_type, data, &texture));
*result = Texture2D(texture, width, height, channel_type);
return absl::OkStatus();
@ -37,6 +36,8 @@ absl::Status CreateTexture2D(int width, int height, DataType type, void* data,
Texture2DDescriptor::Texture2DDescriptor(Texture2DDescriptor&& desc)
: GPUObjectDescriptor(std::move(desc)),
element_type(desc.element_type),
normalized(desc.normalized),
normalized_type(desc.normalized_type),
size(desc.size),
data(std::move(desc.data)) {}
@ -44,6 +45,8 @@ Texture2DDescriptor& Texture2DDescriptor::operator=(
Texture2DDescriptor&& desc) {
if (this != &desc) {
std::swap(element_type, desc.element_type);
std::swap(normalized, desc.normalized);
std::swap(normalized_type, desc.normalized_type);
std::swap(size, desc.size);
data = std::move(desc.data);
GPUObjectDescriptor::operator=(std::move(desc));
@ -80,8 +83,38 @@ absl::Status Texture2DDescriptor::PerformReadSelector(
absl::StrCat("Texture2DDescriptor Read require two arguments, but ",
args.size(), " was passed"));
}
const std::string read =
element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
std::string read;
switch (element_type) {
case DataType::FLOAT32:
read = "read_imagef";
break;
case DataType::FLOAT16:
read = "read_imageh";
break;
case DataType::INT8:
case DataType::INT16:
case DataType::INT32:
if (normalized) {
read = normalized_type == DataType::FLOAT16 ? "read_imageh"
: "read_imagef";
} else {
read = "read_imagei";
}
break;
case DataType::UINT8:
case DataType::UINT16:
case DataType::UINT32:
if (normalized) {
read = normalized_type == DataType::FLOAT16 ? "read_imageh"
: "read_imagef";
} else {
read = "read_imageui";
}
break;
default:
read = "unknown_type";
break;
}
*result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0],
", " + args[1] + "))");
return absl::OkStatus();
@ -145,13 +178,12 @@ absl::Status Texture2D::CreateFromTexture2DDescriptor(
const Texture2DDescriptor& desc, CLContext* context) {
width_ = desc.size.x;
height_ = desc.size.y;
channel_type_ =
desc.element_type == DataType::FLOAT32 ? CL_FLOAT : CL_HALF_FLOAT;
channel_type_ = DataTypeToChannelType(desc.element_type, desc.normalized);
uint8_t* data_ptr = desc.data.empty()
? nullptr
: const_cast<unsigned char*>(desc.data.data());
return CreateFloatRGBAImage2D(context->context(), desc.size.x, desc.size.y,
desc.element_type, data_ptr, &texture_);
return CreateRGBAImage2D(context->context(), desc.size.x, desc.size.y,
channel_type_, data_ptr, &texture_);
}
// Creates new 4-channel 2D texture with f32 elements

View File

@ -32,7 +32,11 @@ namespace gpu {
namespace cl {
struct Texture2DDescriptor : public GPUObjectDescriptor {
DataType element_type; // FLOAT32 or FLOAT16
DataType element_type;
bool normalized = false; // used with INT data types, if normalized, we read
// in kernel float data.
DataType normalized_type; // can be FLOAT32 or FLOAT16, using with normalized
// = true
// optional
int2 size = int2(0, 0);

View File

@ -184,8 +184,32 @@ absl::Status CreateCLBuffer(cl_context context, int size_in_bytes,
return absl::OkStatus();
}
absl::Status CreateFloatRGBAImage2D(cl_context context, int width, int height,
DataType type, void* data, cl_mem* result) {
cl_channel_type DataTypeToChannelType(DataType type, bool normalized) {
switch (type) {
case DataType::FLOAT32:
return CL_FLOAT;
case DataType::FLOAT16:
return CL_HALF_FLOAT;
case DataType::INT8:
return normalized ? CL_SNORM_INT8 : CL_SIGNED_INT8;
case DataType::UINT8:
return normalized ? CL_UNORM_INT8 : CL_UNSIGNED_INT8;
case DataType::INT16:
return normalized ? CL_SNORM_INT16 : CL_SIGNED_INT16;
case DataType::UINT16:
return normalized ? CL_UNORM_INT16 : CL_UNSIGNED_INT16;
case DataType::INT32:
return CL_SIGNED_INT32;
case DataType::UINT32:
return CL_UNSIGNED_INT32;
default:
return CL_FLOAT;
}
}
absl::Status CreateRGBAImage2D(cl_context context, int width, int height,
cl_channel_type channel_type, void* data,
cl_mem* result) {
cl_image_desc desc;
desc.image_type = CL_MEM_OBJECT_IMAGE2D;
desc.image_width = width;
@ -199,8 +223,7 @@ absl::Status CreateFloatRGBAImage2D(cl_context context, int width, int height,
cl_image_format format;
format.image_channel_order = CL_RGBA;
format.image_channel_data_type =
type == DataType::FLOAT32 ? CL_FLOAT : CL_HALF_FLOAT;
format.image_channel_data_type = channel_type;
cl_mem_flags flags = CL_MEM_READ_WRITE;
if (data) {

View File

@ -52,8 +52,10 @@ void CopyLinearFLT4(const tflite::gpu::Tensor<Linear, S>& src,
absl::Status CreateCLBuffer(cl_context context, int size_in_bytes,
bool read_only, void* data, cl_mem* result);
absl::Status CreateFloatRGBAImage2D(cl_context context, int width, int height,
DataType type, void* data, cl_mem* result);
cl_channel_type DataTypeToChannelType(DataType type, bool normalized = false);
absl::Status CreateRGBAImage2D(cl_context context, int width, int height,
cl_channel_type channel_type, void* data,
cl_mem* result);
} // namespace cl
} // namespace gpu