TFLite GPU Metal: Removed the restriction that channels_multiplier must equal 1 in depthwise convolution.
PiperOrigin-RevId: 247638516
This commit is contained in:
parent
6d40b8e722
commit
403417fdc8
@ -468,6 +468,7 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
|
||||
int id, ValueId input_id, ValueId output_id,
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
const RuntimeOptions& options) {
|
||||
int channels_multiplier = attr.weights.shape.o;
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
desc->is_linkable = false;
|
||||
@ -503,10 +504,44 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
|
||||
const bool outside = coords.x < 0 || coords.y < 0 ||
|
||||
coords.x >= params.size.x || coords.y >= params.size.y;
|
||||
if (outside) continue;
|
||||
|
||||
const int src_layer = gid.z;
|
||||
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
|
||||
sum0 += float4(src_buffer[src_index]) * float4(temp[ky * kernel_x + kx]);
|
||||
)";
|
||||
if (channels_multiplier == 1) {
|
||||
shader_source += R"(
|
||||
const int src_layer = gid.z;
|
||||
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
|
||||
const FLT4 src_modified = src_buffer[src_index];
|
||||
)";
|
||||
} else if (channels_multiplier == 2) {
|
||||
shader_source += R"(
|
||||
const int src_layer = gid.z / 2;
|
||||
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
|
||||
const FLT4 src = src_buffer[src_index];
|
||||
const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw;
|
||||
const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
|
||||
)";
|
||||
} else if (channels_multiplier == 4) {
|
||||
shader_source += R"(
|
||||
const int src_layer = gid.z / 4;
|
||||
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
|
||||
const FLT4 src = src_buffer[src_index];
|
||||
const FLT t0 = src[gid.z % 4];
|
||||
const FLT4 src_modified = FLT4(t0, t0, t0, t0);
|
||||
)";
|
||||
} else {
|
||||
shader_source += R"(
|
||||
const int src_layer = gid.z / params.channel_multiplier.x;
|
||||
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
|
||||
const FLT4 src = src_buffer[src_index];
|
||||
FLT4 src_modified;
|
||||
const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4;
|
||||
src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x];
|
||||
src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x];
|
||||
src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x];
|
||||
src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x];
|
||||
)";
|
||||
}
|
||||
shader_source += R"(
|
||||
sum0 += float4(src_modified * temp[ky * kernel_x + kx]);
|
||||
}
|
||||
}
|
||||
FLT4 res = FLT4(sum0 + float4(biases[gid.z]));
|
||||
@ -531,19 +566,7 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
|
||||
return out_shape;
|
||||
}};
|
||||
|
||||
const int num_output_channels = attr.weights.shape.i * attr.weights.shape.o;
|
||||
BHWC reordered_dims{1, attr.weights.shape.h, attr.weights.shape.w,
|
||||
num_output_channels};
|
||||
std::vector<float> filters_reordered(GetElementsSizeForPHWC4(reordered_dims),
|
||||
0.0f);
|
||||
if (!ConvertToPHWC4(
|
||||
absl::MakeConstSpan(attr.weights.data.data(),
|
||||
attr.weights.data.size()),
|
||||
reordered_dims,
|
||||
absl::MakeSpan(filters_reordered.data(), filters_reordered.size()))
|
||||
.ok()) {
|
||||
return {};
|
||||
}
|
||||
std::vector<float> filters_reordered = ConvertToPIOHW4(attr.weights);
|
||||
auto filters = options.storage_precision == RuntimeOptions::Precision::FP32
|
||||
? VectorToUint8Vector(filters_reordered)
|
||||
: VectorFloatToHalf(filters_reordered);
|
||||
|
Loading…
Reference in New Issue
Block a user