Added support of runtime weights to ConvolutionMetal.

PiperOrigin-RevId: 357252868
Change-Id: Ic707011251e8a8bdac01f5db09c99f20f687b411
This commit is contained in:
Raman Sarokin 2021-02-12 13:06:33 -08:00 committed by TensorFlower Gardener
parent aa9bd19fe3
commit c8f15be4c1
2 changed files with 27 additions and 12 deletions

View File

@ -185,7 +185,14 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights(
const BHWC& dst_shape, const GpuInfo& gpu_info,
const OperationDef& op_def, ModelHints hints,
WeightsDescription* weights_desc) {
if (gpu_info.IsAdreno()) {
if (gpu_info.IsApiMetal() && IsConvolutionMetalSupported(op_def)) {
Convolution2DAttributes attr_copy = attr;
attr_copy.weights.shape = OHWI(weights_shape.b, weights_shape.h,
weights_shape.w, weights_shape.c);
ConvolutionMetal conv =
CreateConvolutionMetal(op_def, dst_shape, attr_copy, gpu_info);
return absl::make_unique<ConvolutionMetal>(std::move(conv));
} else if (gpu_info.IsAdreno()) {
return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape,
gpu_info, op_def, hints,
weights_desc);

View File

@ -1027,15 +1027,24 @@ ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
? MemoryType::CONSTANT
: MemoryType::GLOBAL;
BufferDescriptor weights_desc;
weights_desc.element_type = weights_type;
weights_desc.element_size = 4;
weights_desc.memory_type = mem_type;
weights_desc.data = ReorderWeightsForConv(
attr.weights, desc.GetWeightsDescription(), weights_type);
weights_desc.size = weights_desc.data.size();
desc.args_.AddObject(
"weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
if (definition.src_tensors.size() == 2) {
// dynamic weights
BufferDescriptor weights_desc;
weights_desc.element_type = definition.src_tensors[1].data_type;
weights_desc.element_size = 4;
weights_desc.memory_type = mem_type;
desc.AddSrcBuffer("weights", weights_desc);
} else {
BufferDescriptor weights_desc;
weights_desc.element_type = weights_type;
weights_desc.element_size = 4;
weights_desc.memory_type = mem_type;
weights_desc.data = ReorderWeightsForConv(
attr.weights, desc.GetWeightsDescription(), weights_type);
weights_desc.size = weights_desc.data.size();
desc.args_.AddObject("weights", absl::make_unique<BufferDescriptor>(
std::move(weights_desc)));
}
BufferDescriptor bias_desc;
bias_desc.element_type = weights_type;
@ -1179,8 +1188,7 @@ ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
}
bool IsConvolutionMetalSupported(const OperationDef& definition) {
return definition.src_tensors.size() == 1 &&
!definition.src_tensors[0].HasAxis(Axis::DEPTH);
return !definition.src_tensors[0].HasAxis(Axis::DEPTH);
}
} // namespace gpu