Batch support for ConvBuffer.
PiperOrigin-RevId: 273381243
This commit is contained in:
parent
2ca39c62f1
commit
05a5da9097
@ -29,11 +29,16 @@ namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GenerateConvBuffer(
|
||||
const OperationDef& op_def, int x_elements, int y_elements,
|
||||
const OperationDef& op_def, bool stride_correction, int x_elements,
|
||||
int y_elements,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
TensorCodeGenerator src_tensor("src_data", "src_size", op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor("dst_data", "dst_size", op_def.dst_tensors[0]);
|
||||
TensorCodeGenerator src_tensor("src_data",
|
||||
{"src_size.x", "src_size.y", "src_size.z"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor("dst_data",
|
||||
{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
|
||||
switch (op_def.precision) {
|
||||
case CalculationsPrecision::F32:
|
||||
@ -77,8 +82,8 @@ std::string GenerateConvBuffer(
|
||||
c += " int X = get_global_id(0) * " + std::to_string(x_elements) + ";\n";
|
||||
c += " int Y = get_global_id(1) * " + std::to_string(y_elements) + ";\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) return;\n";
|
||||
c += " __global FLT16* temp = filters_buffer + Z * src_size.w * "
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
|
||||
c += " __global FLT16* temp = filters_buffer + Z * src_size.z * "
|
||||
"kernel_size.x * kernel_size.y;\n";
|
||||
c += " ACCUM_FLT4 bias_val = TO_ACCUM_TYPE(biases[Z]);\n";
|
||||
for (int i = 0; i < x_elements * y_elements; ++i) {
|
||||
@ -86,7 +91,14 @@ std::string GenerateConvBuffer(
|
||||
}
|
||||
for (int x = 0; x < x_elements; ++x) {
|
||||
std::string x_s = std::to_string(x);
|
||||
c += " int xc" + x_s + " = (X + " + x_s + ") * stride.x + padding.x;\n";
|
||||
if (stride_correction) {
|
||||
c += " int xc" + x_s + " = " +
|
||||
GetXStrideCorrected("X + " + x_s, "src_size.w", "stride.x",
|
||||
"padding.x") +
|
||||
";\n";
|
||||
} else {
|
||||
c += " int xc" + x_s + " = (X + " + x_s + ") * stride.x + padding.x;\n";
|
||||
}
|
||||
}
|
||||
for (int y = 0; y < y_elements; ++y) {
|
||||
std::string y_s = std::to_string(y);
|
||||
@ -117,7 +129,7 @@ std::string GenerateConvBuffer(
|
||||
"x;\n";
|
||||
}
|
||||
}
|
||||
c += " for (int s = 0; s < src_size.w; ++s) {\n";
|
||||
c += " for (int s = 0; s < src_size.z; ++s) {\n";
|
||||
for (int x = 0; x < x_elements; ++x) {
|
||||
std::string x_s = std::to_string(x);
|
||||
for (int y = 0; y < y_elements; ++y) {
|
||||
@ -134,10 +146,10 @@ std::string GenerateConvBuffer(
|
||||
}
|
||||
for (int i = 0; i < x_elements * y_elements; ++i) {
|
||||
std::string i_s = std::to_string(i);
|
||||
c += " src_addr_" + i_s + " += src_size.z;\n";
|
||||
c += " src_addr_" + i_s + " += src_size.x * src_size.y;\n";
|
||||
}
|
||||
c += " temp += 1;\n";
|
||||
c += " }\n"; // src_size.w - SRC_DEPTH
|
||||
c += " }\n"; // src_size.z - SRC_DEPTH
|
||||
c += " }\n"; // kernel_size.x
|
||||
c += " }\n"; // kernel_size.y
|
||||
|
||||
@ -204,8 +216,10 @@ ConvBuffer& ConvBuffer::operator=(ConvBuffer&& operation) {
|
||||
}
|
||||
|
||||
Status ConvBuffer::Compile(const CreationContext& creation_context) {
|
||||
std::string code = GenerateConvBuffer(definition_, x_elements_, y_elements_,
|
||||
linked_operations_);
|
||||
const bool stride_correction = definition_.batch_support && stride_.x != 1;
|
||||
const std::string code =
|
||||
GenerateConvBuffer(definition_, stride_correction, x_elements_,
|
||||
y_elements_, linked_operations_);
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
@ -218,21 +232,20 @@ Status ConvBuffer::BindArguments() {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
int4 src_size = int4(src_[0]->Width(), src_[0]->Height(),
|
||||
src_[0]->Width() * src_[0]->Height(), src_[0]->Depth());
|
||||
int4 dst_size = int4(dst_[0]->Width(), dst_[0]->Height(),
|
||||
dst_[0]->Width() * dst_[0]->Height(), dst_[0]->Depth());
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_size));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_size));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dilation_));
|
||||
RETURN_IF_ERROR(
|
||||
kernel_.SetBytesAuto(int2(dilation_.x * src_[0]->Batch(), dilation_.y)));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
|
||||
RETURN_IF_ERROR(
|
||||
kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y)));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
int3 ConvBuffer::GetGridSize() const {
|
||||
const int grid_x = IntegralDivideRoundUp(dst_[0]->Width(), x_elements_);
|
||||
const int grid_x =
|
||||
IntegralDivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), x_elements_);
|
||||
const int grid_y = IntegralDivideRoundUp(dst_[0]->Height(), y_elements_);
|
||||
const int grid_z = dst_[0]->Depth();
|
||||
return int3(grid_x, grid_y, grid_z);
|
||||
|
Loading…
x
Reference in New Issue
Block a user