Batch support for Padding.

PiperOrigin-RevId: 272992513
This commit is contained in:
A. Unique TensorFlower 2019-10-04 17:44:22 -07:00 committed by TensorFlower Gardener
parent 8c650bc738
commit c97f235ee8

View File

@ -28,54 +28,57 @@ namespace {
std::string GetPaddingCode(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor("src_data", "src_size", op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor("dst_data", "dst_size", op_def.dst_tensors[0]);
TensorCodeGenerator src_tensor("src_data",
{"src_size.x", "src_size.y", "src_size.z"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor("dst_data",
{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[0]);
std::string code = GetCommonDefines(op_def.precision);
std::string c = GetCommonDefines(op_def.precision);
const std::string channels[] = {".x", ".y", ".z", ".w"};
code += "__kernel void main_function(\n";
code += src_tensor.GetDeclaration(AccessType::READ);
code += GetArgsDeclaration(linked_operations);
code += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
code += " int4 src_size, \n";
code += " int4 dst_size, \n";
code += " int4 prepended \n";
code += ") {\n";
code += " int X = get_global_id(0);\n";
code += " int Y = get_global_id(1);\n";
code += " int Z = get_global_id(2);\n";
code +=
" if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) return; \n";
code += " FLT4 result = (FLT4)(0.0);\n";
code += " int s_x = X - prepended.x;\n";
code += " int s_y = Y - prepended.y;\n";
code += " bool inside_x = s_x >= 0 && s_x < src_size.x;\n";
code += " bool inside_y = s_y >= 0 && s_y < src_size.y;\n";
code += " if (inside_x && inside_y) {\n";
code += " int start_channel = Z * 4;\n";
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int src_channels, \n";
c += " int4 dst_size, \n";
c += " int4 prepended \n";
c += ") {\n";
c += " int X = get_global_id(0);\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
c += " FLT4 result = (FLT4)(0.0);\n";
c += " int s_x = X - prepended.x;\n";
c += " int s_y = Y - prepended.y;\n";
c += " bool inside_x = s_x >= 0 && s_x < src_size.x;\n";
c += " bool inside_y = s_y >= 0 && s_y < src_size.y;\n";
c += " if (inside_x && inside_y) {\n";
c += " int start_channel = Z * 4;\n";
for (int i = 0; i < 4; ++i) {
const auto& s = channels[i];
code += " {\n";
code += " int channel = start_channel + " + std::to_string(i) + ";\n";
code += " int s_z = channel - prepended.z;\n";
code += " if (s_z >= 0 && s_z < src_size.z) {\n";
code += " FLT4 t = " +
src_tensor.Read3D("s_x", "s_y", "s_z / 4",
TextureAddressMode::DONT_CARE) +
";\n";
code += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
code += " result" + s + " = t_ar[s_z % 4];\n";
code += " }\n";
code += " }\n";
c += " {\n";
c += " int channel = start_channel + " + std::to_string(i) + ";\n";
c += " int s_z = channel - prepended.z;\n";
c += " if (s_z >= 0 && s_z < src_channels) {\n";
c += " FLT4 t = " +
src_tensor.Read3D("s_x", "s_y", "s_z / 4",
TextureAddressMode::DONT_CARE) +
";\n";
c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
c += " result" + s + " = t_ar[s_z % 4];\n";
c += " }\n";
c += " }\n";
}
code += " }\n";
const LinkingContext context{"result", "X", "Y", "Z"};
code += PostProcess(linked_operations, context);
code += " " + dst_tensor.Write3D("result", "X", "Y", "Z");
code += "}\n";
c += " }\n";
c += PostProcess(linked_operations, {"result", "X", "Y", "Z"});
c += " " + dst_tensor.Write3D("result", "X", "Y", "Z");
c += "}\n";
return code;
return c;
}
} // namespace
@ -119,14 +122,17 @@ Status Padding::BindArguments() {
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(prepended_));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
RETURN_IF_ERROR(
kernel_.SetBytesAuto(int4(prepended_.x * src_[0]->Batch(), prepended_.y,
prepended_.z, prepended_.w)));
return OkStatus();
}
int3 Padding::GetGridSize() const {
const int grid_x = dst_[0]->Width();
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
const int grid_y = dst_[0]->Height();
const int grid_z = dst_[0]->Depth();
return int3(grid_x, grid_y, grid_z);