DataType utility function moved from cl/kernels/util to data_type.
PiperOrigin-RevId: 272046539
This commit is contained in:
parent
99e48c99d1
commit
7a9b0ec9bd
@ -87,11 +87,13 @@ class FromTensorConverter : public OpenClConverterImpl {
|
||||
const TensorObjectDef& input_def,
|
||||
const TensorObjectDef& output_def) const {
|
||||
return std::make_pair(
|
||||
"__global " + GetDataType4(output_def.object_def.data_type) + "* dst",
|
||||
"__global " + ToCLDataType(output_def.object_def.data_type, 4) +
|
||||
"* dst",
|
||||
"dst[(d * size.y + y) * size.x + x] = " +
|
||||
(output_def.object_def.data_type == input_def.object_def.data_type
|
||||
? "input;"
|
||||
: "convert_" + GetDataType4(output_def.object_def.data_type) +
|
||||
: "convert_" +
|
||||
ToCLDataType(output_def.object_def.data_type, 4) +
|
||||
"(input);"));
|
||||
}
|
||||
|
||||
@ -99,7 +101,7 @@ class FromTensorConverter : public OpenClConverterImpl {
|
||||
const TensorObjectDef& input_def,
|
||||
const TensorObjectDef& output_def) const {
|
||||
return std::make_pair(
|
||||
"__global " + GetDataType(output_def.object_def.data_type) + "* dst",
|
||||
"__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
|
||||
R"(
|
||||
int c = d * 4;
|
||||
int index = (y * size.x + x) * size.z + c;
|
||||
@ -143,7 +145,7 @@ __kernel void from_tensor()" +
|
||||
int y = get_global_id(1);
|
||||
int d = get_global_id(2);
|
||||
if (x >= size.x || y >= size.y || d >= size.w) return;
|
||||
)" + GetDataType4(input_def.object_def.data_type) +
|
||||
)" + ToCLDataType(input_def.object_def.data_type, 4) +
|
||||
" input = " + src_tensor.Read3D("x", "y", "d") + ";\n" +
|
||||
params_kernel.second + "\n}";
|
||||
queue_ = environment->queue();
|
||||
@ -199,11 +201,11 @@ class ToTensorConverter : public OpenClConverterImpl {
|
||||
const TensorObjectDef& input_def,
|
||||
const TensorObjectDef& output_def) const {
|
||||
return std::make_pair(
|
||||
"__global " + GetDataType4(input_def.object_def.data_type) + "* src",
|
||||
"__global " + ToCLDataType(input_def.object_def.data_type, 4) + "* src",
|
||||
output_def.object_def.data_type == input_def.object_def.data_type
|
||||
? "result = src[(d * size.y + y) * size.x + x];"
|
||||
: "result = convert_" +
|
||||
GetDataType4(output_def.object_def.data_type) +
|
||||
ToCLDataType(output_def.object_def.data_type, 4) +
|
||||
"(src[(d * size.y + y) * size.x + x]);");
|
||||
}
|
||||
|
||||
@ -211,7 +213,7 @@ class ToTensorConverter : public OpenClConverterImpl {
|
||||
const TensorObjectDef& input_def,
|
||||
const TensorObjectDef& output_def) const {
|
||||
return std::make_pair(
|
||||
"__global " + GetDataType(input_def.object_def.data_type) + "* src",
|
||||
"__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
|
||||
R"(int c = d * 4;
|
||||
int index = (y * size.x + x) * size.z + c;
|
||||
result.x = src[index];
|
||||
@ -246,7 +248,7 @@ __kernel void to_tensor()" +
|
||||
int d = get_global_id(2);
|
||||
|
||||
if (x >= size.x || y >= size.y || d >= size.w) return;
|
||||
)" + GetDataType4(output_def.object_def.data_type) +
|
||||
)" + ToCLDataType(output_def.object_def.data_type, 4) +
|
||||
" result;\n" + params_kernel.second + "\n " +
|
||||
dst_tensor.Write3D("result", "x", "y", "d") + ";\n}";
|
||||
queue_ = environment->queue();
|
||||
|
@ -125,19 +125,6 @@ std::string GetCommonDefines(CalculationsPrecision precision) {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string GetDataType(DataType type) {
|
||||
switch (type) {
|
||||
case DataType::FLOAT16:
|
||||
return "half";
|
||||
case DataType::FLOAT32:
|
||||
return "float";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetDataType4(DataType type) { return GetDataType(type) + "4"; }
|
||||
|
||||
TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
|
||||
const std::string& uniform_size_name,
|
||||
const TensorDescriptor& descriptor)
|
||||
@ -148,7 +135,7 @@ TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
|
||||
std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
|
||||
switch (descriptor_.storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
return absl::StrCat("__global ", GetDataType4(descriptor_.data_type),
|
||||
return absl::StrCat("__global ", ToCLDataType(descriptor_.data_type, 4),
|
||||
"* ", tensor_name_);
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||
@ -157,7 +144,7 @@ std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
|
||||
return GetImageModifier(access_type) + " image2d_array_t " + tensor_name_;
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
if (access_type == AccessType::WRITE) {
|
||||
return absl::StrCat("__global ", GetDataType4(descriptor_.data_type),
|
||||
return absl::StrCat("__global ", ToCLDataType(descriptor_.data_type, 4),
|
||||
"* ", tensor_name_);
|
||||
} else {
|
||||
return GetImageModifier(access_type) + " image1d_buffer_t " +
|
||||
|
@ -36,9 +36,6 @@ namespace cl {
|
||||
|
||||
std::string GetCommonDefines(CalculationsPrecision precision);
|
||||
|
||||
std::string GetDataType(DataType type);
|
||||
std::string GetDataType4(DataType type);
|
||||
|
||||
enum class TextureAddressMode {
|
||||
DONT_CARE, // translated to CLK_ADDRESS_NONE
|
||||
ZERO, // translated to CLK_ADDRESS_CLAMP
|
||||
|
@ -74,5 +74,36 @@ std::string ToString(DataType data_type) {
|
||||
return "undefined";
|
||||
}
|
||||
|
||||
std::string ToCLDataType(DataType data_type, int vec_size) {
|
||||
const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
|
||||
switch (data_type) {
|
||||
case DataType::FLOAT16:
|
||||
return "half" + postfix;
|
||||
case DataType::FLOAT32:
|
||||
return "float" + postfix;
|
||||
case DataType::FLOAT64:
|
||||
return "double" + postfix;
|
||||
case DataType::INT16:
|
||||
return "short" + postfix;
|
||||
case DataType::INT32:
|
||||
return "int" + postfix;
|
||||
case DataType::INT64:
|
||||
return "long" + postfix;
|
||||
case DataType::INT8:
|
||||
return "char" + postfix;
|
||||
case DataType::UINT16:
|
||||
return "ushort" + postfix;
|
||||
case DataType::UINT32:
|
||||
return "uint" + postfix;
|
||||
case DataType::UINT64:
|
||||
return "ulong" + postfix;
|
||||
case DataType::UINT8:
|
||||
return "uchar" + postfix;
|
||||
case DataType::UNKNOWN:
|
||||
return "unknown";
|
||||
}
|
||||
return "undefined";
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -41,6 +41,8 @@ size_t SizeOf(DataType type);
|
||||
|
||||
std::string ToString(DataType t);
|
||||
|
||||
std::string ToCLDataType(DataType data_type, int vec_size = 1);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user