Added support for all storage types to ConvolutionMetal.
PiperOrigin-RevId: 357237233 Change-Id: I7d781b130f482cf7dc6edafed0229cc5648659a5
This commit is contained in:
parent
a320c3cd7d
commit
e9ce368a53
tensorflow/lite/delegates/gpu
@ -204,6 +204,15 @@ std::string GenerateConvolution(const ConvolutionMetal::ConvParams& params,
|
||||
!params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
|
||||
params.y_kernel_is_1;
|
||||
|
||||
const auto src_storage_type = definition.src_tensors[0].storage_type;
|
||||
const auto dst_storage_type = definition.dst_tensors[0].storage_type;
|
||||
const bool src_is_linear =
|
||||
src_storage_type == TensorStorageType::BUFFER ||
|
||||
src_storage_type == TensorStorageType::IMAGE_BUFFER;
|
||||
const bool dst_is_linear =
|
||||
dst_storage_type == TensorStorageType::BUFFER ||
|
||||
dst_storage_type == TensorStorageType::IMAGE_BUFFER;
|
||||
|
||||
std::string channels[4] = {"x", "y", "z", "w"};
|
||||
std::string c;
|
||||
c.reserve(16 * 1024); // Reserve large enough buffer.
|
||||
@ -302,10 +311,12 @@ kernel void ComputeFunction(
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
const std::string s_y = std::to_string(y);
|
||||
c += " int c_y" + s_y + " = y * args.dilation_y + y" + s_y + ";\n";
|
||||
c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
|
||||
" >= args.src_tensor.Height();\n";
|
||||
c += " c_y" + s_y + " = clamp(c_y" + s_y +
|
||||
", 0, args.src_tensor.Height() - 1);\n";
|
||||
if (src_is_linear) {
|
||||
c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
|
||||
" >= args.src_tensor.Height();\n";
|
||||
c += " c_y" + s_y + " = clamp(c_y" + s_y +
|
||||
", 0, args.src_tensor.Height() - 1);\n";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
@ -320,10 +331,12 @@ kernel void ComputeFunction(
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_x = std::to_string(x);
|
||||
c += " int c_x" + s_x + " = x * args.dilation_x + x" + s_x + ";\n";
|
||||
c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
|
||||
" >= args.src_tensor.Width();\n";
|
||||
c += " c_x" + s_x + " = clamp(c_x" + s_x +
|
||||
", 0, args.src_tensor.Width() - 1);\n";
|
||||
if (src_is_linear) {
|
||||
c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
|
||||
" >= args.src_tensor.Width();\n";
|
||||
c += " c_x" + s_x + " = clamp(c_x" + s_x +
|
||||
", 0, args.src_tensor.Width() - 1);\n";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
@ -332,34 +345,38 @@ kernel void ComputeFunction(
|
||||
", 0, args.src_tensor.Width() - 1);\n";
|
||||
}
|
||||
}
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
const std::string s_y = std::to_string(y);
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_x = std::to_string(x);
|
||||
const std::string s_yx = s_y + s_x;
|
||||
if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
|
||||
c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x + "_out);\n";
|
||||
} else if (!params.y_kernel_is_1) {
|
||||
c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
|
||||
} else if (!params.x_kernel_is_1) {
|
||||
c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
|
||||
if (src_is_linear) {
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
const std::string s_y = std::to_string(y);
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_x = std::to_string(x);
|
||||
const std::string s_yx = s_y + s_x;
|
||||
if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
|
||||
c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x +
|
||||
"_out);\n";
|
||||
} else if (!params.y_kernel_is_1) {
|
||||
c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
|
||||
} else if (!params.x_kernel_is_1) {
|
||||
c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
const std::string s_y = std::to_string(y);
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_x = std::to_string(x);
|
||||
const std::string s_yx = s_y + s_x;
|
||||
if (definition.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
|
||||
c +=
|
||||
" device FLT4* src_loc_" + s_yx +
|
||||
" = args.src_tensor.GetHandle() + args.src_tensor.GetWHOffset(c_x" +
|
||||
s_x + ", c_y" + s_y + ");\n";
|
||||
} else if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::IMAGE_BUFFER) {
|
||||
c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
|
||||
s_x + ", c_y" + s_y + ");\n";
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
const std::string s_y = std::to_string(y);
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_x = std::to_string(x);
|
||||
const std::string s_yx = s_y + s_x;
|
||||
if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::BUFFER) {
|
||||
c += " device FLT4* src_loc_" + s_yx +
|
||||
" = args.src_tensor.GetHandle() + "
|
||||
"args.src_tensor.GetWHOffset(c_x" +
|
||||
s_x + ", c_y" + s_y + ");\n";
|
||||
} else if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::IMAGE_BUFFER) {
|
||||
c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
|
||||
s_x + ", c_y" + s_y + ");\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -403,30 +420,37 @@ kernel void ComputeFunction(
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_yx = std::to_string(y) + std::to_string(x);
|
||||
if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::BUFFER) {
|
||||
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
|
||||
c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
|
||||
";\n";
|
||||
} else {
|
||||
c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
|
||||
}
|
||||
} else if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::IMAGE_BUFFER) {
|
||||
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
|
||||
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx +
|
||||
") * m" + s_yx + ";\n";
|
||||
} else {
|
||||
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx +
|
||||
");\n";
|
||||
if (src_is_linear) {
|
||||
if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::BUFFER) {
|
||||
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
|
||||
c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
|
||||
";\n";
|
||||
} else {
|
||||
c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
|
||||
}
|
||||
} else if (definition.src_tensors[0].storage_type ==
|
||||
TensorStorageType::IMAGE_BUFFER) {
|
||||
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
|
||||
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
|
||||
s_yx + ") * m" + s_yx + ";\n";
|
||||
} else {
|
||||
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
|
||||
s_yx + ");\n";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
c += " src" + s_yx + " = args.src_tensor.Read(c_x" +
|
||||
std::to_string(x) + ", c_y" + std::to_string(y) + ", s);\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_yx = std::to_string(y) + std::to_string(x);
|
||||
c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
|
||||
if (src_is_linear) {
|
||||
for (int y = 0; y < params.block_size.y; ++y) {
|
||||
for (int x = 0; x < params.block_size.x; ++x) {
|
||||
const std::string s_yx = std::to_string(y) + std::to_string(x);
|
||||
c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -498,11 +522,13 @@ kernel void ComputeFunction(
|
||||
"return;\n";
|
||||
}
|
||||
|
||||
for_every_yx([](const std::string& s_yx, const std::string& s_x,
|
||||
const std::string& s_y, int x, int y) {
|
||||
return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
|
||||
", Y + " + s_y + ", Z);";
|
||||
});
|
||||
if (dst_is_linear) {
|
||||
for_every_yx([](const std::string& s_yx, const std::string& s_x,
|
||||
const std::string& s_y, int x, int y) {
|
||||
return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
|
||||
", Y + " + s_y + ", Z);";
|
||||
});
|
||||
}
|
||||
|
||||
std::string bias_name = "args.biases.GetPtr()";
|
||||
if (params.need_dst_loop) {
|
||||
@ -544,11 +570,16 @@ kernel void ComputeFunction(
|
||||
c += " {\n";
|
||||
}
|
||||
c += " FLT4 value = FLT4(r" + s_zyx + ");\n";
|
||||
c += " int linear_index = offset_" + s_yx +
|
||||
" + args.dst_tensor.SliceStride() * " + s_z + ";\n";
|
||||
c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
|
||||
s_y + ", Z + " + s_z + ");\n";
|
||||
c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
|
||||
if (dst_is_linear) {
|
||||
c += " int linear_index = offset_" + s_yx +
|
||||
" + args.dst_tensor.SliceStride() * " + s_z + ";\n";
|
||||
c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
|
||||
s_y + ", Z + " + s_z + ");\n";
|
||||
c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
|
||||
} else {
|
||||
c += " args.dst_tensor.Write(value, X + " + s_x + ", Y + " +
|
||||
s_y + ", Z + " + s_z + ");\n";
|
||||
}
|
||||
c += " }\n";
|
||||
}
|
||||
}
|
||||
@ -1148,14 +1179,7 @@ ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
|
||||
}
|
||||
|
||||
bool IsConvolutionMetalSupported(const OperationDef& definition) {
|
||||
const auto src_storage_type = definition.src_tensors[0].storage_type;
|
||||
const auto dst_storage_type = definition.dst_tensors[0].storage_type;
|
||||
const bool storages_are_buffers =
|
||||
(src_storage_type == TensorStorageType::BUFFER ||
|
||||
src_storage_type == TensorStorageType::IMAGE_BUFFER) &&
|
||||
(dst_storage_type == TensorStorageType::BUFFER ||
|
||||
dst_storage_type == TensorStorageType::IMAGE_BUFFER);
|
||||
return storages_are_buffers && definition.src_tensors.size() == 1 &&
|
||||
return definition.src_tensors.size() == 1 &&
|
||||
!definition.src_tensors[0].HasAxis(Axis::DEPTH);
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ absl::Status ConvolutionO2H2W1I1Stride1x1Dilation1x1Test(TestExecutionEnvironmen
|
||||
attr.padding.appended = HW(1, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
@ -91,7 +91,7 @@ absl::Status ConvolutionO1H2W2I1Stride1x1Dilation2x2Test(TestExecutionEnvironmen
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
@ -127,7 +127,7 @@ absl::Status ConvolutionO1H3W3I1Stride1x1Dilation1x1Test(TestExecutionEnvironmen
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
@ -163,7 +163,7 @@ absl::Status ConvolutionO2H1W1I2Stride1x1Dilation1x1Test(TestExecutionEnvironmen
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(1, 1);
|
||||
|
||||
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
@ -199,7 +199,7 @@ absl::Status ConvolutionO1H1W1I1Stride2x2Dilation1x1Test(TestExecutionEnvironmen
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.strides = HW(2, 2);
|
||||
|
||||
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
|
||||
for (auto storage : env->GetSupportedStorages()) {
|
||||
for (auto precision : env->GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
|
Loading…
Reference in New Issue
Block a user