DataType utility function moved from cl/kernels/util to data_type.

PiperOrigin-RevId: 272046539
This commit is contained in:
A. Unique TensorFlower 2019-09-30 13:01:11 -07:00 committed by TensorFlower Gardener
parent 99e48c99d1
commit 7a9b0ec9bd
5 changed files with 45 additions and 26 deletions

View File

@ -87,11 +87,13 @@ class FromTensorConverter : public OpenClConverterImpl {
const TensorObjectDef& input_def, const TensorObjectDef& input_def,
const TensorObjectDef& output_def) const { const TensorObjectDef& output_def) const {
return std::make_pair( return std::make_pair(
"__global " + GetDataType4(output_def.object_def.data_type) + "* dst", "__global " + ToCLDataType(output_def.object_def.data_type, 4) +
"* dst",
"dst[(d * size.y + y) * size.x + x] = " + "dst[(d * size.y + y) * size.x + x] = " +
(output_def.object_def.data_type == input_def.object_def.data_type (output_def.object_def.data_type == input_def.object_def.data_type
? "input;" ? "input;"
: "convert_" + GetDataType4(output_def.object_def.data_type) + : "convert_" +
ToCLDataType(output_def.object_def.data_type, 4) +
"(input);")); "(input);"));
} }
@ -99,7 +101,7 @@ class FromTensorConverter : public OpenClConverterImpl {
const TensorObjectDef& input_def, const TensorObjectDef& input_def,
const TensorObjectDef& output_def) const { const TensorObjectDef& output_def) const {
return std::make_pair( return std::make_pair(
"__global " + GetDataType(output_def.object_def.data_type) + "* dst", "__global " + ToCLDataType(output_def.object_def.data_type) + "* dst",
R"( R"(
int c = d * 4; int c = d * 4;
int index = (y * size.x + x) * size.z + c; int index = (y * size.x + x) * size.z + c;
@ -143,7 +145,7 @@ __kernel void from_tensor()" +
int y = get_global_id(1); int y = get_global_id(1);
int d = get_global_id(2); int d = get_global_id(2);
if (x >= size.x || y >= size.y || d >= size.w) return; if (x >= size.x || y >= size.y || d >= size.w) return;
)" + GetDataType4(input_def.object_def.data_type) + )" + ToCLDataType(input_def.object_def.data_type, 4) +
" input = " + src_tensor.Read3D("x", "y", "d") + ";\n" + " input = " + src_tensor.Read3D("x", "y", "d") + ";\n" +
params_kernel.second + "\n}"; params_kernel.second + "\n}";
queue_ = environment->queue(); queue_ = environment->queue();
@ -199,11 +201,11 @@ class ToTensorConverter : public OpenClConverterImpl {
const TensorObjectDef& input_def, const TensorObjectDef& input_def,
const TensorObjectDef& output_def) const { const TensorObjectDef& output_def) const {
return std::make_pair( return std::make_pair(
"__global " + GetDataType4(input_def.object_def.data_type) + "* src", "__global " + ToCLDataType(input_def.object_def.data_type, 4) + "* src",
output_def.object_def.data_type == input_def.object_def.data_type output_def.object_def.data_type == input_def.object_def.data_type
? "result = src[(d * size.y + y) * size.x + x];" ? "result = src[(d * size.y + y) * size.x + x];"
: "result = convert_" + : "result = convert_" +
GetDataType4(output_def.object_def.data_type) + ToCLDataType(output_def.object_def.data_type, 4) +
"(src[(d * size.y + y) * size.x + x]);"); "(src[(d * size.y + y) * size.x + x]);");
} }
@ -211,7 +213,7 @@ class ToTensorConverter : public OpenClConverterImpl {
const TensorObjectDef& input_def, const TensorObjectDef& input_def,
const TensorObjectDef& output_def) const { const TensorObjectDef& output_def) const {
return std::make_pair( return std::make_pair(
"__global " + GetDataType(input_def.object_def.data_type) + "* src", "__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
R"(int c = d * 4; R"(int c = d * 4;
int index = (y * size.x + x) * size.z + c; int index = (y * size.x + x) * size.z + c;
result.x = src[index]; result.x = src[index];
@ -246,7 +248,7 @@ __kernel void to_tensor()" +
int d = get_global_id(2); int d = get_global_id(2);
if (x >= size.x || y >= size.y || d >= size.w) return; if (x >= size.x || y >= size.y || d >= size.w) return;
)" + GetDataType4(output_def.object_def.data_type) + )" + ToCLDataType(output_def.object_def.data_type, 4) +
" result;\n" + params_kernel.second + "\n " + " result;\n" + params_kernel.second + "\n " +
dst_tensor.Write3D("result", "x", "y", "d") + ";\n}"; dst_tensor.Write3D("result", "x", "y", "d") + ";\n}";
queue_ = environment->queue(); queue_ = environment->queue();

View File

@ -125,19 +125,6 @@ std::string GetCommonDefines(CalculationsPrecision precision) {
return result; return result;
} }
std::string GetDataType(DataType type) {
switch (type) {
case DataType::FLOAT16:
return "half";
case DataType::FLOAT32:
return "float";
default:
return "error";
}
}
std::string GetDataType4(DataType type) { return GetDataType(type) + "4"; }
TensorCodeGenerator::TensorCodeGenerator(const std::string& name, TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
const std::string& uniform_size_name, const std::string& uniform_size_name,
const TensorDescriptor& descriptor) const TensorDescriptor& descriptor)
@ -148,7 +135,7 @@ TensorCodeGenerator::TensorCodeGenerator(const std::string& name,
std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const { std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
switch (descriptor_.storage_type) { switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER: case TensorStorageType::BUFFER:
return absl::StrCat("__global ", GetDataType4(descriptor_.data_type), return absl::StrCat("__global ", ToCLDataType(descriptor_.data_type, 4),
"* ", tensor_name_); "* ", tensor_name_);
case TensorStorageType::TEXTURE_2D: case TensorStorageType::TEXTURE_2D:
case TensorStorageType::SINGLE_TEXTURE_2D: case TensorStorageType::SINGLE_TEXTURE_2D:
@ -157,7 +144,7 @@ std::string TensorCodeGenerator::GetDeclaration(AccessType access_type) const {
return GetImageModifier(access_type) + " image2d_array_t " + tensor_name_; return GetImageModifier(access_type) + " image2d_array_t " + tensor_name_;
case TensorStorageType::IMAGE_BUFFER: case TensorStorageType::IMAGE_BUFFER:
if (access_type == AccessType::WRITE) { if (access_type == AccessType::WRITE) {
return absl::StrCat("__global ", GetDataType4(descriptor_.data_type), return absl::StrCat("__global ", ToCLDataType(descriptor_.data_type, 4),
"* ", tensor_name_); "* ", tensor_name_);
} else { } else {
return GetImageModifier(access_type) + " image1d_buffer_t " + return GetImageModifier(access_type) + " image1d_buffer_t " +

View File

@ -36,9 +36,6 @@ namespace cl {
std::string GetCommonDefines(CalculationsPrecision precision); std::string GetCommonDefines(CalculationsPrecision precision);
std::string GetDataType(DataType type);
std::string GetDataType4(DataType type);
enum class TextureAddressMode { enum class TextureAddressMode {
DONT_CARE, // translated to CLK_ADDRESS_NONE DONT_CARE, // translated to CLK_ADDRESS_NONE
ZERO, // translated to CLK_ADDRESS_CLAMP ZERO, // translated to CLK_ADDRESS_CLAMP

View File

@ -74,5 +74,36 @@ std::string ToString(DataType data_type) {
return "undefined"; return "undefined";
} }
std::string ToCLDataType(DataType data_type, int vec_size) {
const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
switch (data_type) {
case DataType::FLOAT16:
return "half" + postfix;
case DataType::FLOAT32:
return "float" + postfix;
case DataType::FLOAT64:
return "double" + postfix;
case DataType::INT16:
return "short" + postfix;
case DataType::INT32:
return "int" + postfix;
case DataType::INT64:
return "long" + postfix;
case DataType::INT8:
return "char" + postfix;
case DataType::UINT16:
return "ushort" + postfix;
case DataType::UINT32:
return "uint" + postfix;
case DataType::UINT64:
return "ulong" + postfix;
case DataType::UINT8:
return "uchar" + postfix;
case DataType::UNKNOWN:
return "unknown";
}
return "undefined";
}
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -41,6 +41,8 @@ size_t SizeOf(DataType type);
std::string ToString(DataType t); std::string ToString(DataType t);
std::string ToCLDataType(DataType data_type, int vec_size = 1);
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite