Supported F32_F16 precision in Winograd transformations.

PiperOrigin-RevId: 296084853
Change-Id: If7f1715d84eae34159cf403d1ad208f9d1aa7305
This commit is contained in:
Raman Sarokin 2020-02-19 16:30:48 -08:00 committed by TensorFlower Gardener
parent ba2cbe1e55
commit 3aecbb9fb1
3 changed files with 182 additions and 43 deletions

View File

@ -16,10 +16,12 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include <cmath>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
namespace tflite {
@ -225,6 +227,37 @@ std::string TensorCodeGenerator::ReadAsFloatWHDSB(
address_mode);
}
// Emits a read of the element at (x, y, slice), converted to `type` when it
// differs from the tensor's stored data type.
std::string TensorCodeGenerator::ReadAsTypeWHS(
    DataType type, const std::string& x, const std::string& y,
    const std::string& s, TextureAddressMode address_mode) const {
  const std::string address = GetGlobalAddressNoDeclarationWHS(x, y, s);
  return ReadAsType(type, address, address_mode);
}
// Emits a read of the element at (x, y, slice, batch), converted to `type`
// when it differs from the tensor's stored data type.
std::string TensorCodeGenerator::ReadAsTypeWHSB(
    DataType type, const std::string& x, const std::string& y,
    const std::string& s, const std::string& b,
    TextureAddressMode address_mode) const {
  const std::string address = GetGlobalAddressNoDeclarationWHSB(x, y, s, b);
  return ReadAsType(type, address, address_mode);
}
// Emits a read of the element at (x, y, z, slice), converted to `type` when
// it differs from the tensor's stored data type.
std::string TensorCodeGenerator::ReadAsTypeWHDS(
    DataType type, const std::string& x, const std::string& y,
    const std::string& z, const std::string& s,
    TextureAddressMode address_mode) const {
  const std::string address = GetGlobalAddressNoDeclarationWHDS(x, y, z, s);
  return ReadAsType(type, address, address_mode);
}
// Emits a read of the element at (x, y, z, slice, batch), converted to
// `type` when it differs from the tensor's stored data type.
std::string TensorCodeGenerator::ReadAsTypeWHDSB(
    DataType type, const std::string& x, const std::string& y,
    const std::string& z, const std::string& s, const std::string& b,
    TextureAddressMode address_mode) const {
  const std::string address =
      GetGlobalAddressNoDeclarationWHDSB(x, y, z, s, b);
  return ReadAsType(type, address, address_mode);
}
std::string TensorCodeGenerator::GetAddressWHS(const std::string& var_name,
const std::string& x,
const std::string& y,
@ -449,6 +482,39 @@ std::string TensorCodeGenerator::ReadAsFloat(
}
}
// Generates OpenCL source that reads one element at `global_address` and
// yields it as `type` (FLOAT16 or FLOAT32), independent of the tensor's
// stored data type. For buffers this may require an explicit convert_*
// call; for images the read_imageh/read_imagef builtin performs the
// conversion itself.
std::string TensorCodeGenerator::ReadAsType(
    DataType type, const std::string& global_address,
    TextureAddressMode address_mode) const {
  const bool as_half = type == DataType::FLOAT16;
  const std::string read_as = as_half ? "read_imageh" : "read_imagef";
  switch (descriptor_.storage_type) {
    case TensorStorageType::BUFFER: {
      std::string result =
          absl::StrCat(tensor_name_, "[", global_address, "]");
      if (type != descriptor_.data_type) {
        // Stored type differs from the requested one; wrap the access in an
        // explicit OpenCL convert_* call.
        const std::string conversion =
            as_half ? "convert_half4" : "convert_float4";
        result = absl::StrCat(conversion, "(", result, ")");
      }
      return result;
    }
    case TensorStorageType::TEXTURE_2D:
    case TensorStorageType::TEXTURE_3D:
    case TensorStorageType::SINGLE_TEXTURE_2D:
    case TensorStorageType::TEXTURE_ARRAY:
      // Sampled image reads take an explicit sampler/address mode.
      return absl::StrCat(read_as, "(", tensor_name_, ", ",
                          TextureAddressModeToString(address_mode), ", ",
                          global_address, ")");
    case TensorStorageType::IMAGE_BUFFER:
      // Image buffers are read without a sampler.
      return absl::StrCat(read_as, "(", tensor_name_, ", ", global_address,
                          ")");
    case TensorStorageType::UNKNOWN:
      return "";
  }
}
std::string TensorCodeGenerator::Write(
const std::string& var_name, const std::string& global_address) const {
switch (descriptor_.storage_type) {

View File

@ -138,6 +138,28 @@ class TensorCodeGenerator {
const std::string& s, const std::string& b,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
// Optimization for textures: in OpenCL, read_imagef/read_imageh can be
// applied to any texture type, so the value can be requested in either
// precision regardless of the tensor's stored data type.
// Each ReadAsType* overload reads the element at the given coordinates
// (W/H/D = spatial, S = slice, B = batch) converted to `type`.
std::string ReadAsTypeWHS(
DataType type, const std::string& x, const std::string& y,
const std::string& s,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
std::string ReadAsTypeWHSB(
DataType type, const std::string& x, const std::string& y,
const std::string& s, const std::string& b,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
std::string ReadAsTypeWHDS(
DataType type, const std::string& x, const std::string& y,
const std::string& z, const std::string& s,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
std::string ReadAsTypeWHDSB(
DataType type, const std::string& x, const std::string& y,
const std::string& z, const std::string& s, const std::string& b,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
std::string WriteWHS(const std::string& var_name, const std::string& x,
const std::string& y, const std::string& s) const;
@ -161,6 +183,9 @@ class TensorCodeGenerator {
std::string ReadAsFloat(
const std::string& global_address,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
// Reads the element at `global_address`, converting the value to `type`
// when it differs from the tensor's stored data type.
std::string ReadAsType(
DataType type, const std::string& global_address,
TextureAddressMode address_mode = TextureAddressMode::DONT_CARE) const;
std::string Write(const std::string& var_name,
const std::string& global_address) const;

View File

@ -21,6 +21,8 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -49,8 +51,22 @@ std::string GetWinograd4x4To36Code(
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
const bool is_buffer = src_tensor_type == TensorStorageType::BUFFER;
switch (op_def.precision) {
case CalculationsPrecision::F32:
case CalculationsPrecision::F32_F16:
c += "#define ACCUM_FLT float\n";
break;
case CalculationsPrecision::F16:
c += "#define ACCUM_FLT half\n";
break;
}
const DataType accum_type = op_def.precision == CalculationsPrecision::F16
? DataType::FLOAT16
: DataType::FLOAT32;
auto bt_mat = BtMatrixForWinograd4x4To6x6();
c += "constant FLT Bt[36] = {\n";
c += "constant ACCUM_FLT Bt[36] = {\n";
for (int y = 0; y < 6; ++y) {
c += "\t";
for (int x = 0; x < 6; ++x) {
@ -79,10 +95,12 @@ std::string GetWinograd4x4To36Code(
c += " }\n";
c += " int tile_x = (DST_X % tiles_x) * 4;\n";
c += " int tile_y = (DST_X / tiles_x) * 4;\n";
c += " FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " FLT bt_ar[6];\n";
c += " FLT4 t0 = " + bt_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ";\n";
c += " FLT4 t1 = " + bt_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ";\n";
c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " ACCUM_FLT bt_ar[6];\n";
c += " ACCUM_FLT4 t0 = TO_ACCUM_TYPE(" +
bt_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ");\n";
c += " ACCUM_FLT4 t1 = TO_ACCUM_TYPE(" +
bt_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ");\n";
c += " DST_Y *= 6;\n";
c += " bt_ar[0] = t0.x;\n";
c += " bt_ar[1] = t0.y;\n";
@ -92,15 +110,17 @@ std::string GetWinograd4x4To36Code(
c += " bt_ar[5] = t1.y;\n";
auto read_src = [&](const std::string& src, const std::string& xs) {
if (is_image_buffer) {
c += " FLT4 " + src + " = " +
src_tensor.Read("src_a_" + xs + " + offset") + ";\n";
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") +
";\n";
} else if (is_buffer) {
c += " FLT4 " + src + " = " +
src_tensor.Read("src_a_" + xs + " + offset") + " * m" + xs + "_x;\n";
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsType(accum_type, "src_a_" + xs + " + offset") +
" * m" + xs + "_x;\n";
} else {
c += " FLT4 " + src + " = " +
src_tensor.ReadWHSB("tile_x + padding.x + " + xs, "yc", "DST_Z",
batch_id) +
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsTypeWHSB(accum_type, "tile_x + padding.x + " + xs,
"yc", "DST_Z", batch_id) +
";\n";
}
};
@ -108,8 +128,8 @@ std::string GetWinograd4x4To36Code(
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
c += " int xc" + xs + " = tile_x + padding.x + " + xs + ";\n";
c += " FLT m" + xs + "_x = (FLT)(xc" + xs + " >= 0 && xc" + xs +
" < src_size.x);\n";
c += " ACCUM_FLT m" + xs + "_x = (ACCUM_FLT)(xc" + xs + " >= 0 && xc" +
xs + " < src_size.x);\n";
c += " bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
" < src_size.x);\n";
c += " xc" + xs + " = clamp(xc" + xs + ", 0, src_size.x - 1);\n";
@ -126,9 +146,9 @@ std::string GetWinograd4x4To36Code(
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
c += " int offset = select(0, yc * src_size.x, iny);\n";
c += " FLT bt = bt_ar[0] * (FLT)(iny);\n";
c += " ACCUM_FLT bt = bt_ar[0] * (ACCUM_FLT)(iny);\n";
} else {
c += " FLT bt = bt_ar[0];\n";
c += " ACCUM_FLT bt = bt_ar[0];\n";
}
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
@ -144,9 +164,9 @@ std::string GetWinograd4x4To36Code(
if (is_buffer || is_image_buffer) {
c += " bool iny = (yc >= 0 && yc < src_size.y);\n";
c += " int offset = select(0, yc * src_size.x, iny);\n";
c += " FLT bt = bt_ar[" + ys + "] * (FLT)(iny);\n";
c += " ACCUM_FLT bt = bt_ar[" + ys + "] * (ACCUM_FLT)(iny);\n";
} else {
c += " FLT bt = bt_ar[" + ys + "];\n";
c += " ACCUM_FLT bt = bt_ar[" + ys + "];\n";
}
for (int x = 0; x < 6; ++x) {
const std::string xs = std::to_string(x);
@ -158,42 +178,50 @@ std::string GetWinograd4x4To36Code(
}
const LinkingContext context{"r0", "DST_X", "DST_Y", "DST_Z"};
c += " {\n";
c += " FLT4 r0 = I0 + Bt[2] * I2 + Bt[4] * I4;\n";
c += " FLT4 r0 = TO_FLT4(I0 + Bt[2] * I2 + Bt[4] * I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * I4;\n";
c += " FLT4 r0 = TO_FLT4(Bt[7] * I1 + Bt[8] * I2 + Bt[9] * I3 + Bt[10] * "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = Bt[13] * I1 + Bt[14] * I2 + Bt[15] * I3 + Bt[16] * I4;\n";
c += " FLT4 r0 = TO_FLT4(Bt[13] * I1 + Bt[14] * I2 + Bt[15] * I3 + Bt[16] "
"* "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = Bt[19] * I1 + Bt[20] * I2 + Bt[21] * I3 + Bt[22] * I4;\n";
c += " FLT4 r0 = TO_FLT4(Bt[19] * I1 + Bt[20] * I2 + Bt[21] * I3 + Bt[22] "
"* "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = Bt[25] * I1 + Bt[26] * I2 + Bt[27] * I3 + Bt[28] * I4;\n";
c += " FLT4 r0 = TO_FLT4(Bt[25] * I1 + Bt[26] * I2 + Bt[27] * I3 + Bt[28] "
"* "
"I4);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " DST_Y++;\n";
c += " }\n";
c += " {\n";
c += " FLT4 r0 = Bt[31] * I1 + Bt[33] * I3 + I5;\n";
c += " FLT4 r0 = TO_FLT4(Bt[31] * I1 + Bt[33] * I3 + I5);\n";
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("r0", "DST_X", "DST_Y", "DST_Z", batch_id);
c += " DST_Y++;\n";
c += " }\n";
c += "}\n";
// std::cout << c << std::endl;
return c;
}
@ -213,8 +241,22 @@ std::string GetWinograd36To4x4Code(
const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
std::string c = GetCommonDefines(op_def.precision);
switch (op_def.precision) {
case CalculationsPrecision::F32:
case CalculationsPrecision::F32_F16:
c += "#define ACCUM_FLT float\n";
break;
case CalculationsPrecision::F16:
c += "#define ACCUM_FLT half\n";
break;
}
const DataType accum_type = op_def.precision == CalculationsPrecision::F16
? DataType::FLOAT16
: DataType::FLOAT32;
auto at_mat = AtMatrixForWinograd4x4To6x6();
c += "constant FLT At[24] = {\n";
c += "constant ACCUM_FLT At[24] = {\n";
for (int y = 0; y < 4; ++y) {
c += "\t";
for (int x = 0; x < 6; ++x) {
@ -243,10 +285,12 @@ std::string GetWinograd36To4x4Code(
"dst_size.z) {\n";
c += " return; \n";
c += " }\n";
c += " FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " FLT at_ar[6];\n";
c += " FLT4 t00 = " + at_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ";\n";
c += " FLT4 t01 = " + at_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ";\n";
c += " ACCUM_FLT4 I0, I1, I2, I3, I4, I5;\n";
c += " ACCUM_FLT at_ar[6];\n";
c += " ACCUM_FLT4 t00 = TO_ACCUM_TYPE(" +
at_arr.ReadLinearFLT4("DST_Y * 2 + 0") + ");\n";
c += " ACCUM_FLT4 t01 = TO_ACCUM_TYPE(" +
at_arr.ReadLinearFLT4("DST_Y * 2 + 1") + ");\n";
c += " at_ar[0] = t00.x;\n";
c += " at_ar[1] = t00.y;\n";
c += " at_ar[2] = t00.z;\n";
@ -254,56 +298,60 @@ std::string GetWinograd36To4x4Code(
c += " at_ar[4] = t01.x;\n";
c += " at_ar[5] = t01.y;\n";
c += " {\n";
c += " FLT at = at_ar[0];\n";
c += " ACCUM_FLT at = at_ar[0];\n";
for (int x = 0; x < 6; ++x) {
const std::string yc = std::to_string(x);
const std::string src = "src" + std::to_string(x);
c += " FLT4 " + src + " = " +
src_tensor.ReadWHSB("tile_id", yc, "DST_Z", batch_id) + ";\n";
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsTypeWHSB(accum_type, "tile_id", yc, "DST_Z",
batch_id) +
";\n";
c += " I" + std::to_string(x) + " = at * " + src + ";\n";
}
c += " }\n";
for (int y = 1; y < 6; ++y) {
c += " {\n";
c += " FLT at = at_ar[" + std::to_string(y) + "];\n";
c += " ACCUM_FLT at = at_ar[" + std::to_string(y) + "];\n";
for (int x = 0; x < 6; ++x) {
const std::string yc = std::to_string(y * 6 + x);
const std::string src = "src" + std::to_string(x);
c += " FLT4 " + src + " = " +
src_tensor.ReadWHSB("tile_id", yc, "DST_Z", batch_id) + ";\n";
c += " ACCUM_FLT4 " + src + " = " +
src_tensor.ReadAsTypeWHSB(accum_type, "tile_id", yc, "DST_Z",
batch_id) +
";\n";
c += " I" + std::to_string(x) + " += at * " + src + ";\n";
}
c += " }\n";
}
c += " FLT4 t0 = I1 + I2;\n";
c += " FLT4 t1 = I3 + I4;\n";
c += " ACCUM_FLT4 t0 = I1 + I2;\n";
c += " ACCUM_FLT4 t1 = I3 + I4;\n";
c += " FLT4 bias_val = " + biases.ReadLinearFLT4("DST_Z") + ";\n";
c += " {\n";
const LinkingContext context{"r0", "tile_x", "tile_y", "DST_Z"};
c += " FLT4 r0 = I0 + t0 + t1 + bias_val;\n";
c += " FLT4 r0 = TO_FLT4(I0 + t0 + t1) + bias_val;\n";
c += PostProcess(linked_operations, context);
c += " " +
dst_tensor.WriteWHSB("r0", "tile_x", "tile_y", "DST_Z", batch_id);
c += " tile_x++;\n";
c += " }\n";
c += " FLT4 t2 = I1 - I2;\n";
c += " FLT4 t3 = I3 - I4;\n";
c += " ACCUM_FLT4 t2 = I1 - I2;\n";
c += " ACCUM_FLT4 t3 = I3 - I4;\n";
c += " if (tile_x < dst_size.x) {\n";
c += " FLT4 r0 = t2 * At[7] + t3 * At[9] + bias_val;\n";
c += " FLT4 r0 = TO_FLT4(t2 * At[7] + t3 * At[9]) + bias_val;\n";
c += PostProcess(linked_operations, context);
c += " " +
dst_tensor.WriteWHSB("r0", "tile_x", "tile_y", "DST_Z", batch_id);
c += " tile_x++;\n";
c += " }\n";
c += " if (tile_x < dst_size.x) {\n";
c += " FLT4 r0 = t0 * At[13] + t1 * At[15] + bias_val;\n";
c += " FLT4 r0 = TO_FLT4(t0 * At[13] + t1 * At[15]) + bias_val;\n";
c += PostProcess(linked_operations, context);
c += " " +
dst_tensor.WriteWHSB("r0", "tile_x", "tile_y", "DST_Z", batch_id);
c += " tile_x++;\n";
c += " }\n";
c += " if (tile_x < dst_size.x) {\n";
c += " FLT4 r0 = t2 * At[19] + t3 * At[21] + I5 + bias_val;\n";
c += " FLT4 r0 = TO_FLT4(t2 * At[19] + t3 * At[21] + I5) + bias_val;\n";
c += PostProcess(linked_operations, context);
c += " " +
dst_tensor.WriteWHSB("r0", "tile_x", "tile_y", "DST_Z", batch_id);