Batch support for ConvBuffer.

PiperOrigin-RevId: 273381243
Author: A. Unique TensorFlower
Committed by: TensorFlower Gardener
Date: 2019-10-07 14:48:28 -07:00
Commit: 05a5da9097
Parent: 2ca39c62f1

@@ -29,11 +29,16 @@ namespace cl {
 namespace {
 
 std::string GenerateConvBuffer(
-    const OperationDef& op_def, int x_elements, int y_elements,
+    const OperationDef& op_def, bool stride_correction, int x_elements,
+    int y_elements,
     const std::vector<ElementwiseOperation*>& linked_operations) {
   std::string c = GetCommonDefines(op_def.precision);
-  TensorCodeGenerator src_tensor("src_data", "src_size", op_def.src_tensors[0]);
-  TensorCodeGenerator dst_tensor("dst_data", "dst_size", op_def.dst_tensors[0]);
+  TensorCodeGenerator src_tensor("src_data",
+                                 {"src_size.x", "src_size.y", "src_size.z"},
+                                 op_def.src_tensors[0]);
+  TensorCodeGenerator dst_tensor("dst_data",
+                                 {"dst_size.x", "dst_size.y", "dst_size.z"},
+                                 op_def.dst_tensors[0]);
 
   switch (op_def.precision) {
     case CalculationsPrecision::F32:
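
Note on the assumed size packing: the generator now names the three size components explicitly so the fourth component of the int4 size uniform is free to carry the batch. The standalone sketch below only illustrates that assumed layout (width already multiplied by batch in .x, height in .y, depth slices in .z, batch in .w); Int4 and PackWBatchedHDB are hypothetical stand-ins, not the real Tensor API.

    #include <cstdio>

    struct Int4 {
      int x, y, z, w;
    };

    // Hypothetical helper mirroring what a "width-batched HDB" packing could
    // look like: x = width * batch, y = height, z = depth slices, w = batch.
    Int4 PackWBatchedHDB(int width, int height, int depth, int batch) {
      return {width * batch, height, depth, batch};
    }

    int main() {
      const Int4 src_size = PackWBatchedHDB(/*width=*/5, /*height=*/3,
                                            /*depth=*/8, /*batch=*/2);
      // The generated kernel can now read the depth from src_size.z
      // (previously .w) and use src_size.w as the batch count.
      std::printf("src_size = (%d, %d, %d, %d)\n", src_size.x, src_size.y,
                  src_size.z, src_size.w);
      return 0;
    }
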
@@ -77,8 +82,8 @@ std::string GenerateConvBuffer(
   c += " int X = get_global_id(0) * " + std::to_string(x_elements) + ";\n";
   c += " int Y = get_global_id(1) * " + std::to_string(y_elements) + ";\n";
   c += " int Z = get_global_id(2);\n";
-  c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) return;\n";
-  c += " __global FLT16* temp = filters_buffer + Z * src_size.w * "
+  c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
+  c += " __global FLT16* temp = filters_buffer + Z * src_size.z * "
        "kernel_size.x * kernel_size.y;\n";
   c += " ACCUM_FLT4 bias_val = TO_ACCUM_TYPE(biases[Z]);\n";
   for (int i = 0; i < x_elements * y_elements; ++i) {
@@ -86,7 +91,14 @@ std::string GenerateConvBuffer(
   }
   for (int x = 0; x < x_elements; ++x) {
     std::string x_s = std::to_string(x);
-    c += " int xc" + x_s + " = (X + " + x_s + ") * stride.x + padding.x;\n";
+    if (stride_correction) {
+      c += " int xc" + x_s + " = " +
+           GetXStrideCorrected("X + " + x_s, "src_size.w", "stride.x",
+                               "padding.x") +
+           ";\n";
+    } else {
+      c += " int xc" + x_s + " = (X + " + x_s + ") * stride.x + padding.x;\n";
+    }
   }
   for (int y = 0; y < y_elements; ++y) {
     std::string y_s = std::to_string(y);
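
The corrected X index exists because, with batch support, the global X id is assumed to interleave the batch along the width (X = x * batch + b), and padding.x arrives pre-multiplied by the batch (see BindArguments further down). The sketch below reproduces only the arithmetic GetXStrideCorrected is expected to generate; it is not the helper itself, and the interleaving is an assumption for illustration.

    #include <cstdio>

    // Recover the spatial coordinate and batch index from X, stride only the
    // spatial part, then re-interleave the batch index.
    int CorrectedX(int X, int batch, int stride_x, int batched_padding_x) {
      const int x = X / batch;  // spatial column
      const int b = X % batch;  // batch index
      return x * stride_x * batch + b + batched_padding_x;
    }

    int main() {
      const int batch = 2, stride_x = 2, padding_x = -1;
      for (int X = 0; X < 6; ++X) {
        // Naive formula: also scales the batch index by the stride, which is
        // only harmless when stride_x == 1.
        const int naive = X * stride_x + padding_x * batch;
        const int fixed = CorrectedX(X, batch, stride_x, padding_x * batch);
        std::printf("X=%d naive=%d corrected=%d\n", X, naive, fixed);
      }
      return 0;
    }
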
@@ -117,7 +129,7 @@ std::string GenerateConvBuffer(
            "x;\n";
     }
   }
-  c += " for (int s = 0; s < src_size.w; ++s) {\n";
+  c += " for (int s = 0; s < src_size.z; ++s) {\n";
   for (int x = 0; x < x_elements; ++x) {
     std::string x_s = std::to_string(x);
     for (int y = 0; y < y_elements; ++y) {
@@ -134,10 +146,10 @@ std::string GenerateConvBuffer(
     }
   }
   for (int i = 0; i < x_elements * y_elements; ++i) {
     std::string i_s = std::to_string(i);
-    c += " src_addr_" + i_s + " += src_size.z;\n";
+    c += " src_addr_" + i_s + " += src_size.x * src_size.y;\n";
   }
   c += " temp += 1;\n";
-  c += " }\n"; // src_size.w - SRC_DEPTH
+  c += " }\n"; // src_size.z - SRC_DEPTH
   c += " }\n"; // kernel_size.x
   c += " }\n"; // kernel_size.y
@@ -204,8 +216,10 @@ ConvBuffer& ConvBuffer::operator=(ConvBuffer&& operation) {
 }
 
 Status ConvBuffer::Compile(const CreationContext& creation_context) {
-  std::string code = GenerateConvBuffer(definition_, x_elements_, y_elements_,
-                                        linked_operations_);
+  const bool stride_correction = definition_.batch_support && stride_.x != 1;
+  const std::string code =
+      GenerateConvBuffer(definition_, stride_correction, x_elements_,
+                         y_elements_, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
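
Compile only asks for the corrected index when the op was created with batch support and the X stride is not 1; with a unit stride the plain expression already yields the same address, so the cheaper form is kept. A tiny sketch of that gating, with Definition as an illustrative stand-in for the relevant OperationDef fields:

    #include <cstdio>

    // Illustrative stand-in for the bits of OperationDef used by the condition.
    struct Definition {
      bool batch_support;
    };

    bool NeedsStrideCorrection(const Definition& def, int stride_x) {
      return def.batch_support && stride_x != 1;
    }

    int main() {
      const Definition with_batch{true};
      const Definition without_batch{false};
      std::printf("%d\n", NeedsStrideCorrection(with_batch, 2));     // 1: correct X
      std::printf("%d\n", NeedsStrideCorrection(with_batch, 1));     // 0: unit stride
      std::printf("%d\n", NeedsStrideCorrection(without_batch, 2));  // 0: no batch
      return 0;
    }
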
@@ -218,21 +232,20 @@ Status ConvBuffer::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
-  int4 src_size = int4(src_[0]->Width(), src_[0]->Height(),
-                       src_[0]->Width() * src_[0]->Height(), src_[0]->Depth());
-  int4 dst_size = int4(dst_[0]->Width(), dst_[0]->Height(),
-                       dst_[0]->Width() * dst_[0]->Height(), dst_[0]->Depth());
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_size));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_size));
+  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB()));
+  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(dilation_));
+  RETURN_IF_ERROR(
+      kernel_.SetBytesAuto(int2(dilation_.x * src_[0]->Batch(), dilation_.y)));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
+  RETURN_IF_ERROR(
+      kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y)));
   return OkStatus();
 }
 
 int3 ConvBuffer::GetGridSize() const {
-  const int grid_x = IntegralDivideRoundUp(dst_[0]->Width(), x_elements_);
+  const int grid_x =
+      IntegralDivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), x_elements_);
   const int grid_y = IntegralDivideRoundUp(dst_[0]->Height(), y_elements_);
   const int grid_z = dst_[0]->Depth();
   return int3(grid_x, grid_y, grid_z);
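
With batch folded into X, the work grid's first dimension covers width * batch output columns, rounded up to the x_elements each work item produces. A small worked example with a local stand-in for IntegralDivideRoundUp:

    #include <cstdio>

    // Stand-in for IntegralDivideRoundUp, used only for this worked example.
    int DivideRoundUp(int n, int divisor) { return (n + divisor - 1) / divisor; }

    int main() {
      const int width = 5, batch = 2, height = 3, depth = 8;
      const int x_elements = 4, y_elements = 1;
      // X now spans width * batch output columns: ceil(10 / 4) = 3 groups.
      const int grid_x = DivideRoundUp(width * batch, x_elements);
      const int grid_y = DivideRoundUp(height, y_elements);  // 3
      const int grid_z = depth;                              // one group per slice
      std::printf("grid = (%d, %d, %d)\n", grid_x, grid_y, grid_z);
      return 0;
    }
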