Added support for runtime weights in ConvolutionMetal.

PiperOrigin-RevId: 357252868
Change-Id: Ic707011251e8a8bdac01f5db09c99f20f687b411
This commit is contained in:
Raman Sarokin 2021-02-12 13:06:33 -08:00 committed by TensorFlower Gardener
parent aa9bd19fe3
commit c8f15be4c1
2 changed files with 27 additions and 12 deletions

View File

@ -185,7 +185,14 @@ std::unique_ptr<GPUOperation> SelectConvolutionWithDynamicWeights(
const BHWC& dst_shape, const GpuInfo& gpu_info, const BHWC& dst_shape, const GpuInfo& gpu_info,
const OperationDef& op_def, ModelHints hints, const OperationDef& op_def, ModelHints hints,
WeightsDescription* weights_desc) { WeightsDescription* weights_desc) {
if (gpu_info.IsAdreno()) { if (gpu_info.IsApiMetal() && IsConvolutionMetalSupported(op_def)) {
Convolution2DAttributes attr_copy = attr;
attr_copy.weights.shape = OHWI(weights_shape.b, weights_shape.h,
weights_shape.w, weights_shape.c);
ConvolutionMetal conv =
CreateConvolutionMetal(op_def, dst_shape, attr_copy, gpu_info);
return absl::make_unique<ConvolutionMetal>(std::move(conv));
} else if (gpu_info.IsAdreno()) {
return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape, return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape,
gpu_info, op_def, hints, gpu_info, op_def, hints,
weights_desc); weights_desc);

View File

@ -1027,15 +1027,24 @@ ConvolutionMetal CreateConvolutionMetal(const OperationDef& definition,
? MemoryType::CONSTANT ? MemoryType::CONSTANT
: MemoryType::GLOBAL; : MemoryType::GLOBAL;
BufferDescriptor weights_desc; if (definition.src_tensors.size() == 2) {
weights_desc.element_type = weights_type; // dynamic weights
weights_desc.element_size = 4; BufferDescriptor weights_desc;
weights_desc.memory_type = mem_type; weights_desc.element_type = definition.src_tensors[1].data_type;
weights_desc.data = ReorderWeightsForConv( weights_desc.element_size = 4;
attr.weights, desc.GetWeightsDescription(), weights_type); weights_desc.memory_type = mem_type;
weights_desc.size = weights_desc.data.size(); desc.AddSrcBuffer("weights", weights_desc);
desc.args_.AddObject( } else {
"weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc))); BufferDescriptor weights_desc;
weights_desc.element_type = weights_type;
weights_desc.element_size = 4;
weights_desc.memory_type = mem_type;
weights_desc.data = ReorderWeightsForConv(
attr.weights, desc.GetWeightsDescription(), weights_type);
weights_desc.size = weights_desc.data.size();
desc.args_.AddObject("weights", absl::make_unique<BufferDescriptor>(
std::move(weights_desc)));
}
BufferDescriptor bias_desc; BufferDescriptor bias_desc;
bias_desc.element_type = weights_type; bias_desc.element_type = weights_type;
@ -1179,8 +1188,7 @@ ConvolutionMetal CreateConvolutionMetalWino4x4To6x6(
} }
bool IsConvolutionMetalSupported(const OperationDef& definition) { bool IsConvolutionMetalSupported(const OperationDef& definition) {
return definition.src_tensors.size() == 1 && return !definition.src_tensors[0].HasAxis(Axis::DEPTH);
!definition.src_tensors[0].HasAxis(Axis::DEPTH);
} }
} // namespace gpu } // namespace gpu