ConcatXY converted to new style.
Added support for concatenation along the Batch and Depth axes.
PiperOrigin-RevId: 316999295
Change-Id: I94f2168f2861790b3a30c79b2b3476aa44c55748
This commit is contained in:
parent
b1933d67e5
commit
274a0f944e
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
||||
@ -27,51 +29,93 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GetConcatKernelCode(
|
||||
const OperationDef& op_def, int tensors_count,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
std::vector<TensorCodeGenerator> srcs(tensors_count);
|
||||
for (int i = 0; i < tensors_count; ++i) {
|
||||
const std::string tensor_name = "src_data_" + std::to_string(i);
|
||||
const std::string width = "src_size_" + std::to_string(i) + ".x";
|
||||
const std::string height = "src_size_" + std::to_string(i) + ".y";
|
||||
srcs[i] =
|
||||
TensorCodeGenerator(tensor_name, WHSPoint{width, height, "dst_size.z"},
|
||||
op_def.src_tensors[i]);
|
||||
std::string GetConcatKernelCode(const OperationDef& op_def,
|
||||
const ConcatAttributes& attr, Arguments* args) {
|
||||
std::vector<std::string> tensor_names(op_def.src_tensors.size());
|
||||
for (int i = 0; i < op_def.src_tensors.size(); ++i) {
|
||||
tensor_names[i] = "src_tensor_" + std::to_string(i);
|
||||
args->AddObjectRef(
|
||||
tensor_names[i], AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
}
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
std::map<Axis, std::string> axis_to_selector = {
|
||||
{Axis::WIDTH, "Width"}, {Axis::HEIGHT, "Height"},
|
||||
{Axis::DEPTH, "Depth"}, {Axis::CHANNELS, "Channels"},
|
||||
{Axis::BATCH, "Batch"},
|
||||
};
|
||||
std::map<Axis, std::string> axis_to_coord = {
|
||||
{Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
|
||||
{Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
|
||||
};
|
||||
|
||||
std::vector<std::string> src_coords;
|
||||
std::vector<std::string> dst_coords;
|
||||
for (auto axis :
|
||||
{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS, Axis::BATCH}) {
|
||||
if (op_def.src_tensors[0].HasAxis(axis) && axis != Axis::BATCH) {
|
||||
if (axis == attr.axis) {
|
||||
src_coords.push_back("coord");
|
||||
} else {
|
||||
src_coords.push_back(axis_to_coord[axis]);
|
||||
}
|
||||
}
|
||||
if (op_def.dst_tensors[0].HasAxis(axis)) {
|
||||
dst_coords.push_back(axis_to_coord[axis]);
|
||||
}
|
||||
}
|
||||
std::string src_coord = src_coords[0];
|
||||
for (int i = 1; i < src_coords.size(); ++i) {
|
||||
src_coord += ", " + src_coords[i];
|
||||
}
|
||||
std::string dst_coord = dst_coords[0];
|
||||
for (int i = 1; i < dst_coords.size(); ++i) {
|
||||
dst_coord += ", " + dst_coords[i];
|
||||
}
|
||||
TensorCodeGenerator dst("dst_data",
|
||||
WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
|
||||
c += "__kernel void main_function(\n";
|
||||
for (const auto& src : srcs) {
|
||||
c += src.GetDeclaration(AccessType::READ) + ",\n";
|
||||
c += "$0) {\n";
|
||||
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id_0 = get_global_id(0);\n";
|
||||
c += " int X = linear_id_0 / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id_0 % args.dst_tensor.Batch();\n";
|
||||
} else {
|
||||
c += " int X = get_global_id(0);\n";
|
||||
}
|
||||
c += dst.GetDeclaration(AccessType::WRITE);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
for (int i = 0; i < tensors_count; ++i) {
|
||||
const std::string uniform_name = "src_size_" + std::to_string(i);
|
||||
c += " int4 " + uniform_name + ",\n";
|
||||
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
|
||||
c += " int linear_id_1 = get_global_id(1);\n";
|
||||
c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
|
||||
c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
|
||||
} else {
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
}
|
||||
c += " int4 dst_size \n";
|
||||
c += ") {\n";
|
||||
c += " int X = get_global_id(0);\n";
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (Z >= dst_size.z) return;\n";
|
||||
for (int i = 0; i < tensors_count; ++i) {
|
||||
const std::string size_name = "src_size_" + std::to_string(i);
|
||||
c += " if (X < " + size_name + ".x && Y < " + size_name + ".y) { \n";
|
||||
c += " FLT4 result = " + srcs[i].ReadWHS("X", "Y", "Z") + ";\n";
|
||||
c += " int dst_x = X + " + size_name + ".z;\n";
|
||||
c += " int dst_y = Y + " + size_name + ".w;\n";
|
||||
const LinkingContext context{"result", "dst_x", "dst_y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst.WriteWHS("result", "dst_x", "dst_y", "Z");
|
||||
c += " int S = get_global_id(2);\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"S >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " FLT4 result = (FLT4)(0.0f);\n";
|
||||
c += " int coord = " + axis_to_coord[attr.axis] + ";\n";
|
||||
for (int i = 0; i < op_def.src_tensors.size(); ++i) {
|
||||
const std::string field =
|
||||
"args." + tensor_names[i] + "." + axis_to_selector[attr.axis] + "()";
|
||||
c += " if (coord >= 0 && coord < " + field + ") { \n";
|
||||
if (op_def.src_tensors[i].HasAxis(Axis::BATCH)) {
|
||||
if (attr.axis == Axis::BATCH) {
|
||||
c += " args." + tensor_names[i] + ".SetBatchRef(coord);\n";
|
||||
} else {
|
||||
c += " args." + tensor_names[i] + ".SetBatchRef(B);\n";
|
||||
}
|
||||
}
|
||||
c += " result = args." + tensor_names[i] + ".Read(" + src_coord + ");\n";
|
||||
c += " } \n";
|
||||
c += " coord -= " + field + ";\n";
|
||||
}
|
||||
c += " args.dst_tensor.Write(result, " + dst_coord + ");\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
@ -97,46 +141,32 @@ ConcatXY& ConcatXY::operator=(ConcatXY&& operation) {
|
||||
}
|
||||
|
||||
absl::Status ConcatXY::Compile(const CreationContext& creation_context) {
|
||||
const auto code =
|
||||
GetConcatKernelCode(definition_, tensors_count_, linked_operations_);
|
||||
std::string code = GetConcatKernelCode(definition_, attr_, &args_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
}
|
||||
|
||||
absl::Status ConcatXY::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
for (int i = 0; i < tensors_count_; ++i) {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr()));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
int x_offset = 0;
|
||||
int y_offset = 0;
|
||||
for (int i = 0; i < tensors_count_; ++i) {
|
||||
const int width = src_[i]->Width() * src_[i]->Batch();
|
||||
const int height = src_[i]->Height();
|
||||
for (int i = 0; i < definition_.src_tensors.size(); ++i) {
|
||||
RETURN_IF_ERROR(
|
||||
kernel_.SetBytesAuto(int4(width, height, x_offset, y_offset)));
|
||||
x_offset += attr_.axis == Axis::WIDTH ? width : 0;
|
||||
y_offset += attr_.axis == Axis::HEIGHT ? height : 0;
|
||||
args_.SetObjectRef("src_tensor_" + std::to_string(i), src_[i]));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 ConcatXY::GetGridSize() const {
|
||||
int max_src_width = 0;
|
||||
int max_src_height = 0;
|
||||
for (int i = 0; i < tensors_count_; ++i) {
|
||||
max_src_width = std::max(max_src_width, src_[i]->Width());
|
||||
max_src_height = std::max(max_src_height, src_[i]->Height());
|
||||
}
|
||||
|
||||
const int grid_x = max_src_width * dst_[0]->Batch();
|
||||
const int grid_y = max_src_height;
|
||||
const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
|
||||
const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
|
||||
const int grid_z = dst_[0]->Slices();
|
||||
|
||||
return int3(grid_x, grid_y, grid_z);
|
||||
}
|
||||
|
||||
|
@ -105,8 +105,10 @@ absl::Status SelectConcat(const ConcatAttributes& attr,
|
||||
*ptr = absl::make_unique<ConcatZ>(std::move(operation));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case Axis::WIDTH:
|
||||
case Axis::HEIGHT: {
|
||||
case Axis::BATCH:
|
||||
case Axis::DEPTH:
|
||||
case Axis::HEIGHT:
|
||||
case Axis::WIDTH: {
|
||||
ConcatXY operation = CreateConcatXY(op_def, attr, channels.size());
|
||||
*ptr = absl::make_unique<ConcatXY>(std::move(operation));
|
||||
return absl::OkStatus();
|
||||
|
Loading…
x
Reference in New Issue
Block a user