Added support of runtime weights to ConvolutionMetal.
PiperOrigin-RevId: 357252868 Change-Id: Ic707011251e8a8bdac01f5db09c99f20f687b411
This commit is contained in:
parent
aa9bd19fe3
commit
c8f15be4c1
@ -185,7 +185,14 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights(
|
||||
const BHWC& dst_shape, const GpuInfo& gpu_info,
|
||||
const OperationDef& op_def, ModelHints hints,
|
||||
WeightsDescription* weights_desc) {
|
||||
if (gpu_info.IsAdreno()) {
|
||||
if (gpu_info.IsApiMetal() && IsConvolutionMetalSupported(op_def)) {
|
||||
Convolution2DAttributes attr_copy = attr;
|
||||
attr_copy.weights.shape = OHWI(weights_shape.b, weights_shape.h,
|
||||
weights_shape.w, weights_shape.c);
|
||||
ConvolutionMetal conv =
|
||||
CreateConvolutionMetal(op_def, dst_shape, attr_copy, gpu_info);
|
||||
return absl::make_unique<ConvolutionMetal>(std::move(conv));
|
||||
} else if (gpu_info.IsAdreno()) {
|
||||
return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape,
|
||||
gpu_info, op_def, hints,
|
||||
weights_desc);
|
||||
|
@ -1027,15 +1027,24 @@ ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
|
||||
? MemoryType::CONSTANT
|
||||
: MemoryType::GLOBAL;
|
||||
|
||||
BufferDescriptor weights_desc;
|
||||
weights_desc.element_type = weights_type;
|
||||
weights_desc.element_size = 4;
|
||||
weights_desc.memory_type = mem_type;
|
||||
weights_desc.data = ReorderWeightsForConv(
|
||||
attr.weights, desc.GetWeightsDescription(), weights_type);
|
||||
weights_desc.size = weights_desc.data.size();
|
||||
desc.args_.AddObject(
|
||||
"weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
|
||||
if (definition.src_tensors.size() == 2) {
|
||||
// dynamic weights
|
||||
BufferDescriptor weights_desc;
|
||||
weights_desc.element_type = definition.src_tensors[1].data_type;
|
||||
weights_desc.element_size = 4;
|
||||
weights_desc.memory_type = mem_type;
|
||||
desc.AddSrcBuffer("weights", weights_desc);
|
||||
} else {
|
||||
BufferDescriptor weights_desc;
|
||||
weights_desc.element_type = weights_type;
|
||||
weights_desc.element_size = 4;
|
||||
weights_desc.memory_type = mem_type;
|
||||
weights_desc.data = ReorderWeightsForConv(
|
||||
attr.weights, desc.GetWeightsDescription(), weights_type);
|
||||
weights_desc.size = weights_desc.data.size();
|
||||
desc.args_.AddObject("weights", absl::make_unique<BufferDescriptor>(
|
||||
std::move(weights_desc)));
|
||||
}
|
||||
|
||||
BufferDescriptor bias_desc;
|
||||
bias_desc.element_type = weights_type;
|
||||
@ -1179,8 +1188,7 @@ ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
|
||||
}
|
||||
|
||||
bool IsConvolutionMetalSupported(const OperationDef& definition) {
|
||||
return definition.src_tensors.size() == 1 &&
|
||||
!definition.src_tensors[0].HasAxis(Axis::DEPTH);
|
||||
return !definition.src_tensors[0].HasAxis(Axis::DEPTH);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
|
Loading…
Reference in New Issue
Block a user