Added support for all storage types to ConvolutionMetal.

PiperOrigin-RevId: 357237233
Change-Id: I7d781b130f482cf7dc6edafed0229cc5648659a5
Raman Sarokin 2021-02-12 11:50:42 -08:00 committed by TensorFlower Gardener
parent a320c3cd7d
commit e9ce368a53
2 changed files with 101 additions and 77 deletions
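
The gist of the change, as a minimal sketch (not the actual TensorFlow code; IsLinear and EmitSrcRead are illustrative names): the Metal code generator now treats BUFFER and IMAGE_BUFFER as "linear" storages and keeps the precomputed pointer/offset read path only for them, while all other storage types (textures) fall back to coordinate-based Read/Write calls. The real generator further distinguishes BUFFER (pointer dereference) from IMAGE_BUFFER (Read by linear index); this sketch collapses the two.

#include <iostream>
#include <string>

// Storage kinds mirrored from the TFLite GPU enum; values trimmed for brevity.
enum class TensorStorageType { BUFFER, IMAGE_BUFFER, TEXTURE_2D, TEXTURE_ARRAY };

// Assumption: the same predicate as the new src_is_linear/dst_is_linear checks.
bool IsLinear(TensorStorageType type) {
  return type == TensorStorageType::BUFFER ||
         type == TensorStorageType::IMAGE_BUFFER;
}

// Emits the per-element source read for block element (x, y). Linear storages
// read through the precomputed src_loc_* offset; textures read by coordinates.
std::string EmitSrcRead(TensorStorageType src, const std::string& s_yx,
                        const std::string& s_x, const std::string& s_y) {
  if (IsLinear(src)) {
    return "    src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx + ");\n";
  }
  return "    src" + s_yx + " = args.src_tensor.Read(c_x" + s_x + ", c_y" + s_y +
         ", s);\n";
}

int main() {
  std::cout << EmitSrcRead(TensorStorageType::BUFFER, "00", "0", "0");
  std::cout << EmitSrcRead(TensorStorageType::TEXTURE_2D, "00", "0", "0");
  return 0;
}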


@@ -204,6 +204,15 @@ std::string GenerateConvolution(const ConvolutionMetal::ConvParams& params,
!params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
params.y_kernel_is_1;
const auto src_storage_type = definition.src_tensors[0].storage_type;
const auto dst_storage_type = definition.dst_tensors[0].storage_type;
const bool src_is_linear =
src_storage_type == TensorStorageType::BUFFER ||
src_storage_type == TensorStorageType::IMAGE_BUFFER;
const bool dst_is_linear =
dst_storage_type == TensorStorageType::BUFFER ||
dst_storage_type == TensorStorageType::IMAGE_BUFFER;
std::string channels[4] = {"x", "y", "z", "w"};
std::string c;
c.reserve(16 * 1024); // Reserve large enough buffer.
@@ -302,10 +311,12 @@ kernel void ComputeFunction(
for (int y = 0; y < params.block_size.y; ++y) {
const std::string s_y = std::to_string(y);
c += " int c_y" + s_y + " = y * args.dilation_y + y" + s_y + ";\n";
c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
" >= args.src_tensor.Height();\n";
c += " c_y" + s_y + " = clamp(c_y" + s_y +
", 0, args.src_tensor.Height() - 1);\n";
if (src_is_linear) {
c += " bool y" + s_y + "_out = c_y" + s_y + " < 0 || c_y" + s_y +
" >= args.src_tensor.Height();\n";
c += " c_y" + s_y + " = clamp(c_y" + s_y +
", 0, args.src_tensor.Height() - 1);\n";
}
}
} else {
for (int y = 0; y < params.block_size.y; ++y) {
@@ -320,10 +331,12 @@ kernel void ComputeFunction(
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
c += " int c_x" + s_x + " = x * args.dilation_x + x" + s_x + ";\n";
c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
" >= args.src_tensor.Width();\n";
c += " c_x" + s_x + " = clamp(c_x" + s_x +
", 0, args.src_tensor.Width() - 1);\n";
if (src_is_linear) {
c += " bool x" + s_x + "_out = c_x" + s_x + " < 0 || c_x" + s_x +
" >= args.src_tensor.Width();\n";
c += " c_x" + s_x + " = clamp(c_x" + s_x +
", 0, args.src_tensor.Width() - 1);\n";
}
}
} else {
for (int x = 0; x < params.block_size.x; ++x) {
@@ -332,34 +345,38 @@ kernel void ComputeFunction(
", 0, args.src_tensor.Width() - 1);\n";
}
}
for (int y = 0; y < params.block_size.y; ++y) {
const std::string s_y = std::to_string(y);
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
const std::string s_yx = s_y + s_x;
if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x + "_out);\n";
} else if (!params.y_kernel_is_1) {
c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
} else if (!params.x_kernel_is_1) {
c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
if (src_is_linear) {
for (int y = 0; y < params.block_size.y; ++y) {
const std::string s_y = std::to_string(y);
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
const std::string s_yx = s_y + s_x;
if (!params.y_kernel_is_1 && !params.x_kernel_is_1) {
c += " FLT m" + s_yx + " = !(y" + s_y + "_out || x" + s_x +
"_out);\n";
} else if (!params.y_kernel_is_1) {
c += " FLT m" + s_yx + " = !y" + s_y + "_out;\n";
} else if (!params.x_kernel_is_1) {
c += " FLT m" + s_yx + " = !x" + s_x + "_out;\n";
}
}
}
}
for (int y = 0; y < params.block_size.y; ++y) {
const std::string s_y = std::to_string(y);
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
const std::string s_yx = s_y + s_x;
if (definition.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
c +=
" device FLT4* src_loc_" + s_yx +
" = args.src_tensor.GetHandle() + args.src_tensor.GetWHOffset(c_x" +
s_x + ", c_y" + s_y + ");\n";
} else if (definition.src_tensors[0].storage_type ==
TensorStorageType::IMAGE_BUFFER) {
c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
s_x + ", c_y" + s_y + ");\n";
for (int y = 0; y < params.block_size.y; ++y) {
const std::string s_y = std::to_string(y);
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_x = std::to_string(x);
const std::string s_yx = s_y + s_x;
if (definition.src_tensors[0].storage_type ==
TensorStorageType::BUFFER) {
c += " device FLT4* src_loc_" + s_yx +
" = args.src_tensor.GetHandle() + "
"args.src_tensor.GetWHOffset(c_x" +
s_x + ", c_y" + s_y + ");\n";
} else if (definition.src_tensors[0].storage_type ==
TensorStorageType::IMAGE_BUFFER) {
c += " int src_loc_" + s_yx + " = args.src_tensor.GetWHOffset(c_x" +
s_x + ", c_y" + s_y + ");\n";
}
}
}
}
@@ -403,30 +420,37 @@ kernel void ComputeFunction(
for (int y = 0; y < params.block_size.y; ++y) {
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_yx = std::to_string(y) + std::to_string(x);
if (definition.src_tensors[0].storage_type ==
TensorStorageType::BUFFER) {
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
";\n";
} else {
c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
}
} else if (definition.src_tensors[0].storage_type ==
TensorStorageType::IMAGE_BUFFER) {
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx +
") * m" + s_yx + ";\n";
} else {
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" + s_yx +
");\n";
if (src_is_linear) {
if (definition.src_tensors[0].storage_type ==
TensorStorageType::BUFFER) {
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
c += " src" + s_yx + " = *src_loc_" + s_yx + " * m" + s_yx +
";\n";
} else {
c += " src" + s_yx + " = *src_loc_" + s_yx + ";\n";
}
} else if (definition.src_tensors[0].storage_type ==
TensorStorageType::IMAGE_BUFFER) {
if (!params.y_kernel_is_1 || !params.x_kernel_is_1) {
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
s_yx + ") * m" + s_yx + ";\n";
} else {
c += " src" + s_yx + " = args.src_tensor.Read(src_loc_" +
s_yx + ");\n";
}
}
} else {
c += " src" + s_yx + " = args.src_tensor.Read(c_x" +
std::to_string(x) + ", c_y" + std::to_string(y) + ", s);\n";
}
}
}
for (int y = 0; y < params.block_size.y; ++y) {
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_yx = std::to_string(y) + std::to_string(x);
c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
if (src_is_linear) {
for (int y = 0; y < params.block_size.y; ++y) {
for (int x = 0; x < params.block_size.x; ++x) {
const std::string s_yx = std::to_string(y) + std::to_string(x);
c += " src_loc_" + s_yx + " += args.src_tensor.SliceStride();\n";
}
}
}
};
@@ -498,11 +522,13 @@ kernel void ComputeFunction(
"return;\n";
}
for_every_yx([](const std::string& s_yx, const std::string& s_x,
const std::string& s_y, int x, int y) {
return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
", Y + " + s_y + ", Z);";
});
if (dst_is_linear) {
for_every_yx([](const std::string& s_yx, const std::string& s_x,
const std::string& s_y, int x, int y) {
return " args.dst_tensor.GetAddress(offset_" + s_yx + ", X + " + s_x +
", Y + " + s_y + ", Z);";
});
}
std::string bias_name = "args.biases.GetPtr()";
if (params.need_dst_loop) {
@@ -544,11 +570,16 @@ kernel void ComputeFunction(
c += " {\n";
}
c += " FLT4 value = FLT4(r" + s_zyx + ");\n";
c += " int linear_index = offset_" + s_yx +
" + args.dst_tensor.SliceStride() * " + s_z + ";\n";
c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
s_y + ", Z + " + s_z + ");\n";
c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
if (dst_is_linear) {
c += " int linear_index = offset_" + s_yx +
" + args.dst_tensor.SliceStride() * " + s_z + ";\n";
c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
s_y + ", Z + " + s_z + ");\n";
c += " args.dst_tensor.WriteLinear(value, linear_index);\n";
} else {
c += " args.dst_tensor.Write(value, X + " + s_x + ", Y + " +
s_y + ", Z + " + s_z + ");\n";
}
c += " }\n";
}
}
@@ -1148,14 +1179,7 @@ ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
}
bool IsConvolutionMetalSupported(const OperationDef& definition) {
const auto src_storage_type = definition.src_tensors[0].storage_type;
const auto dst_storage_type = definition.dst_tensors[0].storage_type;
const bool storages_are_buffers =
(src_storage_type == TensorStorageType::BUFFER ||
src_storage_type == TensorStorageType::IMAGE_BUFFER) &&
(dst_storage_type == TensorStorageType::BUFFER ||
dst_storage_type == TensorStorageType::IMAGE_BUFFER);
return storages_are_buffers && definition.src_tensors.size() == 1 &&
return definition.src_tensors.size() == 1 &&
!definition.src_tensors[0].HasAxis(Axis::DEPTH);
}
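
With the buffer-only restriction dropped from IsConvolutionMetalSupported, a texture-backed OperationDef is now reported as supported as long as there is a single source tensor without a DEPTH axis. A hedged usage sketch follows; the construction of OperationDef/TensorDescriptor is assumed from the surrounding TFLite GPU headers and may differ in detail.

// Illustrative fragment, not from the commit; assumes the TFLite GPU task headers.
OperationDef def;
def.precision = CalculationsPrecision::F32;
def.src_tensors.push_back(
    {DataType::FLOAT32, TensorStorageType::TEXTURE_2D, Layout::HWC});
def.dst_tensors.push_back(
    {DataType::FLOAT32, TensorStorageType::TEXTURE_2D, Layout::HWC});
// Before this commit texture storages were rejected; now only the tensor count
// and the absence of a DEPTH axis are checked.
const bool supported = IsConvolutionMetalSupported(def);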


@@ -55,7 +55,7 @@ absl::Status ConvolutionO2H2W1I1Stride1x1Dilation1x1Test(TestExecutionEnvironmen
attr.padding.appended = HW(1, 0);
attr.strides = HW(1, 1);
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -91,7 +91,7 @@ absl::Status ConvolutionO1H2W2I1Stride1x1Dilation2x2Test(TestExecutionEnvironmen
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -127,7 +127,7 @@ absl::Status ConvolutionO1H3W3I1Stride1x1Dilation1x1Test(TestExecutionEnvironmen
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -163,7 +163,7 @@ absl::Status ConvolutionO2H1W1I2Stride1x1Dilation1x1Test(TestExecutionEnvironmen
attr.padding.appended = HW(0, 0);
attr.strides = HW(1, 1);
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
@@ -199,7 +199,7 @@ absl::Status ConvolutionO1H1W1I1Stride2x2Dilation1x1Test(TestExecutionEnvironmen
attr.padding.appended = HW(0, 0);
attr.strides = HW(2, 2);
for (auto storage : {TensorStorageType::BUFFER, TensorStorageType::IMAGE_BUFFER}) {
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;