TFLite GPU Metal: Removed the restriction that channels_multiplier must equal 1 in depthwise convolution.

PiperOrigin-RevId: 247638516
This commit is contained in:
A. Unique TensorFlower 2019-05-10 10:51:47 -07:00 committed by TensorFlower Gardener
parent 6d40b8e722
commit 403417fdc8

View File

@@ -468,6 +468,7 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
int id, ValueId input_id, ValueId output_id,
const DepthwiseConvolution2DAttributes& attr,
const RuntimeOptions& options) {
int channels_multiplier = attr.weights.shape.o;
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
@@ -503,10 +504,44 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
const bool outside = coords.x < 0 || coords.y < 0 ||
coords.x >= params.size.x || coords.y >= params.size.y;
if (outside) continue;
const int src_layer = gid.z;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
sum0 += float4(src_buffer[src_index]) * float4(temp[ky * kernel_x + kx]);
)";
if (channels_multiplier == 1) {
shader_source += R"(
const int src_layer = gid.z;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src_modified = src_buffer[src_index];
)";
} else if (channels_multiplier == 2) {
shader_source += R"(
const int src_layer = gid.z / 2;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src = src_buffer[src_index];
const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw;
const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
)";
} else if (channels_multiplier == 4) {
shader_source += R"(
const int src_layer = gid.z / 4;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src = src_buffer[src_index];
const FLT t0 = src[gid.z % 4];
const FLT4 src_modified = FLT4(t0, t0, t0, t0);
)";
} else {
shader_source += R"(
const int src_layer = gid.z / params.channel_multiplier.x;
const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
const FLT4 src = src_buffer[src_index];
FLT4 src_modified;
const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4;
src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x];
src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x];
src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x];
src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x];
)";
}
shader_source += R"(
sum0 += float4(src_modified * temp[ky * kernel_x + kx]);
}
}
FLT4 res = FLT4(sum0 + float4(biases[gid.z]));
@@ -531,19 +566,7 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
return out_shape;
}};
const int num_output_channels = attr.weights.shape.i * attr.weights.shape.o;
BHWC reordered_dims{1, attr.weights.shape.h, attr.weights.shape.w,
num_output_channels};
std::vector<float> filters_reordered(GetElementsSizeForPHWC4(reordered_dims),
0.0f);
if (!ConvertToPHWC4(
absl::MakeConstSpan(attr.weights.data.data(),
attr.weights.data.size()),
reordered_dims,
absl::MakeSpan(filters_reordered.data(), filters_reordered.size()))
.ok()) {
return {};
}
std::vector<float> filters_reordered = ConvertToPIOHW4(attr.weights);
auto filters = options.storage_precision == RuntimeOptions::Precision::FP32
? VectorToUint8Vector(filters_reordered)
: VectorFloatToHalf(filters_reordered);