DataType utility function moved from cl/kernels/util to data_type.
PiperOrigin-RevId: 272046539
This commit is contained in:
parent
99e48c99d1
commit
7a9b0ec9bd
@ -87,11 +87,13 @@ class FromTensorConverter : public OpenClConverterImpl {
|
|||||||
const TensorObjectDef& input_def,
|
const TensorObjectDef& input_def,
|
||||||
const TensorObjectDef& output_def) const {
|
const TensorObjectDef& output_def) const {
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
"__global " + GetDataType4(output_def.object_def.data_type) + "* dst",
|
"__global " + ToCLDataType(output_def.object_def.data_type, 4) +
|
||||||
|
"* dst",
|
||||||
"dst[(d * size.y + y) * size.x + x] = " +
|
"dst[(d * size.y + y) * size.x + x] = " +
|
||||||
(output_def.object_def.data_type == input_def.object_def.data_type
|
(output_def.object_def.data_type == input_def.object_def.data_type
|
||||||
? "input;"
|
? "input;"
|
||||||
: "convert_" + GetDataType4(output_def.object_def.data_type) +
|
: "convert_" +
|
||||||
|
ToCLDataType(output_def.object_def.data_type, 4) +
|
||||||
"(input);"));
|
"(input);"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,7 +101,7 @@ class FromTensorConverter : public OpenClConverterImpl {
|
|||||||
const TensorObjectDef& input_def,
|
const TensorObjectDef& input_def,
|
||||||
const TensorObjectDef& output_def) const {
|
const TensorObjectDef& output_def) const {
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
"__global " + GetDataType(output_def.object_def.data_type) + "* dst",
|
"__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
|
||||||
R"(
|
R"(
|
||||||
int c = d * 4;
|
int c = d * 4;
|
||||||
int index = (y * size.x + x) * size.z + c;
|
int index = (y * size.x + x) * size.z + c;
|
||||||
@ -143,7 +145,7 @@ __kernel void from_tensor()" +
|
|||||||
int y = get_global_id(1);
|
int y = get_global_id(1);
|
||||||
int d = get_global_id(2);
|
int d = get_global_id(2);
|
||||||
if (x >= size.x || y >= size.y || d >= size.w) return;
|
if (x >= size.x || y >= size.y || d >= size.w) return;
|
||||||
)" + GetDataType4(input_def.object_def.data_type) +
|
)" + ToCLDataType(input_def.object_def.data_type, 4) +
|
||||||
" input = " + src_tensor.Read3D("x", "y", "d") + ";\n" +
|
" input = " + src_tensor.Read3D("x", "y", "d") + ";\n" +
|
||||||
params_kernel.second + "\n}";
|
params_kernel.second + "\n}";
|
||||||
queue_ = environment->queue();
|
queue_ = environment->queue();
|
||||||
@ -199,11 +201,11 @@ class ToTensorConverter : public OpenClConverterImpl {
|
|||||||
const TensorObjectDef& input_def,
|
const TensorObjectDef& input_def,
|
||||||
const TensorObjectDef& output_def) const {
|
const TensorObjectDef& output_def) const {
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
"__global " + GetDataType4(input_def.object_def.data_type) + "* src",
|
"__global " + ToCLDataType(input_def.object_def.data_type, 4) + "* src",
|
||||||
output_def.object_def.data_type == input_def.object_def.data_type
|
output_def.object_def.data_type == input_def.object_def.data_type
|
||||||
? "result = src[(d * size.y + y) * size.x + x];"
|
? "result = src[(d * size.y + y) * size.x + x];"
|
||||||
: "result = convert_" +
|
: "result = convert_" +
|
||||||
GetDataType4(output_def.object_def.data_type) +
|
ToCLDataType(output_def.object_def.data_type, 4) +
|
||||||
"(src[(d * size.y + y) * size.x + x]);");
|
"(src[(d * size.y + y) * size.x + x]);");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,7 +213,7 @@ class ToTensorConverter : public OpenClConverterImpl {
|
|||||||
const TensorObjectDef& input_def,
|
const TensorObjectDef& input_def,
|
||||||
const TensorObjectDef& output_def) const {
|
const TensorObjectDef& output_def) const {
|
||||||
return std::make_pair(
|
return std::make_pair(
|
||||||
"__global " + GetDataType(input_def.object_def.data_type) + "* src",
|
"__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
|
||||||
R"(int c = d * 4;
|
R"(int c = d * 4;
|
||||||
int index = (y * size.x + x) * size.z + c;
|
int index = (y * size.x + x) * size.z + c;
|
||||||
result.x = src[index];
|
result.x = src[index];
|
||||||
@ -246,7 +248,7 @@ __kernel void to_tensor()" +
|
|||||||
int d = get_global_id(2);
|
int d = get_global_id(2);
|
||||||
|
|
||||||
if (x >= size.x || y >= size.y || d >= size.w) return;
|
if (x >= size.x || y >= size.y || d >= size.w) return;
|
||||||
)" + GetDataType4(output_def.object_def.data_type) +
|
)" + ToCLDataType(output_def.object_def.data_type, 4) +
|
||||||
" result;\n" + params_kernel.second + "\n " +
|
" result;\n" + params_kernel.second + "\n " +
|
||||||
dst_tensor.Write3D("result", "x", "y", "d") + ";\n}";
|
dst_tensor.Write3D("result", "x", "y", "d") + ";\n}";
|
||||||
queue_ = environment->queue();
|
queue_ = environment->queue();
|
||||||
|
@ -125,19 +125,6 @@ std::string GetCommonDefines(CalculationsPrecision precision) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetDataType(DataType type) {
|
|
||||||
switch (type) {
|
|
||||||
case DataType::FLOAT16:
|
|
||||||
return "half";
|
|
||||||
case DataType::FLOAT32:
|
|
||||||
return "float";
|
|
||||||
default:
|
|
||||||
return "error";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetDataType4(DataType type) { return GetDataType(type) + "4"; }
|
|
||||||
|
|
||||||
TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
|
TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
|
||||||
const std::string& uniform_size_name,
|
const std::string& uniform_size_name,
|
||||||
const TensorDescriptor& descriptor)
|
const TensorDescriptor& descriptor)
|
||||||
@ -148,7 +135,7 @@ TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
|
|||||||
std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
|
std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
|
||||||
switch (descriptor_.storage_type) {
|
switch (descriptor_.storage_type) {
|
||||||
case TensorStorageType::BUFFER:
|
case TensorStorageType::BUFFER:
|
||||||
return absl::StrCat("__global ", GetDataType4(descriptor_.data_type),
|
return absl::StrCat("__global ", ToCLDataType(descriptor_.data_type, 4),
|
||||||
"* ", tensor_name_);
|
"* ", tensor_name_);
|
||||||
case TensorStorageType::TEXTURE_2D:
|
case TensorStorageType::TEXTURE_2D:
|
||||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||||
@ -157,7 +144,7 @@ std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
|
|||||||
return GetImageModifier(access_type) + " image2d_array_t " + tensor_name_;
|
return GetImageModifier(access_type) + " image2d_array_t " + tensor_name_;
|
||||||
case TensorStorageType::IMAGE_BUFFER:
|
case TensorStorageType::IMAGE_BUFFER:
|
||||||
if (access_type == AccessType::WRITE) {
|
if (access_type == AccessType::WRITE) {
|
||||||
return absl::StrCat("__global ", GetDataType4(descriptor_.data_type),
|
return absl::StrCat("__global ", ToCLDataType(descriptor_.data_type, 4),
|
||||||
"* ", tensor_name_);
|
"* ", tensor_name_);
|
||||||
} else {
|
} else {
|
||||||
return GetImageModifier(access_type) + " image1d_buffer_t " +
|
return GetImageModifier(access_type) + " image1d_buffer_t " +
|
||||||
|
@ -36,9 +36,6 @@ namespace cl {
|
|||||||
|
|
||||||
std::string GetCommonDefines(CalculationsPrecision precision);
|
std::string GetCommonDefines(CalculationsPrecision precision);
|
||||||
|
|
||||||
std::string GetDataType(DataType type);
|
|
||||||
std::string GetDataType4(DataType type);
|
|
||||||
|
|
||||||
enum class TextureAddressMode {
|
enum class TextureAddressMode {
|
||||||
DONT_CARE, // translated to CLK_ADDRESS_NONE
|
DONT_CARE, // translated to CLK_ADDRESS_NONE
|
||||||
ZERO, // translated to CLK_ADDRESS_CLAMP
|
ZERO, // translated to CLK_ADDRESS_CLAMP
|
||||||
|
@ -74,5 +74,36 @@ std::string ToString(DataType data_type) {
|
|||||||
return "undefined";
|
return "undefined";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ToCLDataType(DataType data_type, int vec_size) {
|
||||||
|
const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::FLOAT16:
|
||||||
|
return "half" + postfix;
|
||||||
|
case DataType::FLOAT32:
|
||||||
|
return "float" + postfix;
|
||||||
|
case DataType::FLOAT64:
|
||||||
|
return "double" + postfix;
|
||||||
|
case DataType::INT16:
|
||||||
|
return "short" + postfix;
|
||||||
|
case DataType::INT32:
|
||||||
|
return "int" + postfix;
|
||||||
|
case DataType::INT64:
|
||||||
|
return "long" + postfix;
|
||||||
|
case DataType::INT8:
|
||||||
|
return "char" + postfix;
|
||||||
|
case DataType::UINT16:
|
||||||
|
return "ushort" + postfix;
|
||||||
|
case DataType::UINT32:
|
||||||
|
return "uint" + postfix;
|
||||||
|
case DataType::UINT64:
|
||||||
|
return "ulong" + postfix;
|
||||||
|
case DataType::UINT8:
|
||||||
|
return "uchar" + postfix;
|
||||||
|
case DataType::UNKNOWN:
|
||||||
|
return "unknown";
|
||||||
|
}
|
||||||
|
return "undefined";
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -41,6 +41,8 @@ size_t SizeOf(DataType type);
|
|||||||
|
|
||||||
std::string ToString(DataType t);
|
std::string ToString(DataType t);
|
||||||
|
|
||||||
|
std::string ToCLDataType(DataType data_type, int vec_size = 1);
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user