Reusing optimized ConvolutionTransposed4x4 for AMD.

PiperOrigin-RevId: 293205441
Change-Id: I1c5e7ec9891b3bee8eb83a7bde582fb4602a46b2
This commit is contained in:
Raman Sarokin 2020-02-04 12:38:03 -08:00 committed by TensorFlower Gardener
parent 867081522a
commit f6768dc78d
3 changed files with 14 additions and 16 deletions

View File

@@ -69,14 +69,18 @@ std::string GenerateConvolutionTransposedCode(
break;
}
const std::string weights_space =
weights_upload_type ==
ConvolutionTransposed4x4::WeightsUploadType::CONSTANT_MEM
? "__constant"
: "__global";
const std::string pixel_stride =
op_def.IsBatchSupported() ? "dst_size.w" : "1";
if (need_local_mem) { // we use fixed workgroup size when use local mem
c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
}
c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
c += " __global FLT4* filters,\n";
c += " " + weights_space + " FLT4* filters,\n";
c += " __read_only image2d_t biases";
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
@@ -181,7 +185,8 @@ std::string GenerateConvolutionTransposedCode(
c += " weights_cache[local_id + 32] = filters[f_offset + local_id + "
"32];\n";
} else { // GLOBAL_MEM
c += " __global FLT4* weights_cache = filters + f_offset;\n";
c +=
" " + weights_space + " FLT4* weights_cache = filters + f_offset;\n";
}
c += " FLT4 src0 = " + read_src(0, 0) + ";\n";
c += " FLT4 src1 = " + read_src(1, 0) + ";\n";
@@ -267,6 +272,8 @@ ConvolutionTransposed4x4::ConvolutionTransposed4x4(
weights_upload_type_ = WeightsUploadType::LOCAL_MEM_ASYNC;
} else if (device.IsNvidia()) {
weights_upload_type_ = WeightsUploadType::LOCAL_MEM_BY_THREADS;
} else if (device.IsAMD()) {
weights_upload_type_ = WeightsUploadType::CONSTANT_MEM;
} else {
weights_upload_type_ = WeightsUploadType::GLOBAL_MEM;
}
@@ -334,16 +341,6 @@ int3 ConvolutionTransposed4x4::GetGridSize() const {
return int3(grid_x, grid_y, grid_z);
}
Status ConvolutionTransposed4x4::Tune(const TuningParameters& params) {
// Kernels built for the local-memory upload paths are generated with a
// fixed reqd_work_group_size(8, 4, 1) (see the need_local_mem branch in
// GenerateConvolutionTransposedCode), so there is nothing to tune for them.
if (weights_upload_type_ == WeightsUploadType::LOCAL_MEM_ASYNC ||
weights_upload_type_ == WeightsUploadType::LOCAL_MEM_BY_THREADS) {
return OkStatus();
}
// Otherwise bind the kernel arguments and search for the best-performing
// work group size for this grid, storing the result in work_group_size_.
RETURN_IF_ERROR(BindArguments());
return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
&work_group_size_);
}
Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) {
RETURN_IF_ERROR(BindArguments());
return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);

View File

@@ -38,7 +38,6 @@ class ConvolutionTransposed4x4 : public GPUOperation {
public:
ConvolutionTransposed4x4() = default;
Status AddToQueue(CLCommandQueue* queue) override;
Status Tune(const TuningParameters& params) override;
Status Compile(const CreationContext& creation_context) override;
// Move only
@@ -51,6 +50,7 @@ class ConvolutionTransposed4x4 : public GPUOperation {
LOCAL_MEM_ASYNC,
LOCAL_MEM_BY_THREADS,
GLOBAL_MEM,
CONSTANT_MEM,
};
private:

View File

@@ -103,6 +103,7 @@ Status SelectConvolutionTransposed(const ConvolutionTransposedAttributes& attr,
ptr);
case Vendor::POWERVR:
case Vendor::NVIDIA:
case Vendor::AMD:
return SelectConvolutionTransposedPowerVR(attr, creation_context, op_def,
ptr);
case Vendor::MALI: