Added Write method for TensorDescriptor.
Showed usage in Transpose kernel. PiperOrigin-RevId: 314339905 Change-Id: Ic8aa401956f15ec952ddb880962598544fdb4815
This commit is contained in:
parent
2c59f45f8f
commit
d1611593b0
|
@ -31,37 +31,34 @@ std::string GetTransposeCode(
|
|||
const OperationDef& op_def, const TransposeAttributes& attr,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations,
|
||||
Arguments* args) {
|
||||
TensorCodeGenerator dst_tensor("dst_data",
|
||||
WHSBPoint{"args.dst_width", "args.dst_height",
|
||||
"args.dst_slices", "args.dst_batch"},
|
||||
op_def.dst_tensors[0]);
|
||||
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddInt("dst_width");
|
||||
args->AddInt("dst_height");
|
||||
args->AddInt("dst_slices");
|
||||
args->AddInt("dst_batch");
|
||||
args->AddInt("dst_channels");
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
std::string linked_args = GetArgsDeclaration(linked_operations);
|
||||
if (linked_args[0] == ',') {
|
||||
linked_args[0] = ' ';
|
||||
}
|
||||
c += "__kernel void main_function(\n";
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += linked_args;
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / args.dst_batch;\n";
|
||||
c += " int B = linear_id % args.dst_batch;\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
c += " args.dst_tensor.SetBatchRef(B);\n";
|
||||
} else {
|
||||
c += " int X = get_global_id(0);\n";
|
||||
}
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " if (X >= args.dst_width || Y >= args.dst_height || Z >= "
|
||||
"args.dst_slices) { \n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
|
||||
"Z >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
c += " FLT temps[4];\n";
|
||||
|
@ -89,7 +86,7 @@ std::string GetTransposeCode(
|
|||
} else {
|
||||
c += " for (int i = 0; i < 4; ++i) {\n";
|
||||
c += " int dst_channel = Z * 4 + i;\n";
|
||||
c += " if (dst_channel < args.dst_channels) {\n";
|
||||
c += " if (dst_channel < args.dst_tensor.Channels()) {\n";
|
||||
const std::string bhwc[] = {"B", "Y", "X", "dst_channel"};
|
||||
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " args.src_tensor.SetBatchRef(" + bhwc[remap[0]] + ");\n";
|
||||
|
@ -107,10 +104,10 @@ std::string GetTransposeCode(
|
|||
}
|
||||
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
|
||||
std::string x_3dcoord =
|
||||
op_def.IsBatchSupported() ? "X * args.dst_batch + B" : "X";
|
||||
op_def.IsBatchSupported() ? "X * args.dst_tensor.Batch() + B" : "X";
|
||||
const LinkingContext context{"result", x_3dcoord, "Y", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id);
|
||||
c += " args.dst_tensor.Write(result, X, Y, Z);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
@ -146,13 +143,8 @@ absl::Status Transpose::Compile(const CreationContext& creation_context) {
|
|||
|
||||
absl::Status Transpose::BindArguments() {
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetInt("dst_width", dst_[0]->Width()));
|
||||
RETURN_IF_ERROR(args_.SetInt("dst_height", dst_[0]->Height()));
|
||||
RETURN_IF_ERROR(args_.SetInt("dst_slices", dst_[0]->Slices()));
|
||||
RETURN_IF_ERROR(args_.SetInt("dst_batch", dst_[0]->Batch()));
|
||||
RETURN_IF_ERROR(args_.SetInt("dst_channels", dst_[0]->Channels()));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -136,6 +136,17 @@ std::string GetReadImageFromDataType(DataType data_type) {
|
|||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetWriteImageFromDataType(DataType data_type) {
|
||||
if (data_type == DataType::FLOAT32) {
|
||||
return "write_imagef";
|
||||
} else if (data_type == DataType::FLOAT16) {
|
||||
return "write_imageh";
|
||||
} else {
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string ToString(TensorStorageType type) {
|
||||
|
@ -245,9 +256,11 @@ absl::Status TensorDescriptor::PerformSelector(
|
|||
return absl::OkStatus();
|
||||
} else if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
} else if (selector == "Write") {
|
||||
return PerformWriteSelector(args, result);
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"TensorLinearDescriptor don't have selector with name - ", selector));
|
||||
"TensorDescriptor don't have selector with name - ", selector));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -283,6 +296,39 @@ absl::Status TensorDescriptor::PerformReadSelector(
|
|||
}
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformWriteSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
std::string xc;
|
||||
std::string yc;
|
||||
std::string zc;
|
||||
std::string sc;
|
||||
std::string bc;
|
||||
bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
|
||||
if (args.size() < 2 || !parsed) {
|
||||
return absl::NotFoundError("Unrecognized Write selector");
|
||||
}
|
||||
|
||||
if (layout == Layout::HWC) {
|
||||
*result = Write(args[0],
|
||||
GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::BHWC) {
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc,
|
||||
storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::HWDC) {
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc,
|
||||
storage_type));
|
||||
return absl::OkStatus();
|
||||
} else if (layout == Layout::BHWDC) {
|
||||
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB(
|
||||
xc, yc, zc, sc, bc, storage_type));
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::NotFoundError("Unsupported layout");
|
||||
}
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::Read(const std::string& global_address) const {
|
||||
std::string image_type;
|
||||
if (storage_type == TensorStorageType::TEXTURE_2D ||
|
||||
|
@ -310,6 +356,32 @@ std::string TensorDescriptor::Read(const std::string& global_address) const {
|
|||
}
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::Write(const std::string& var_name,
|
||||
const std::string& global_address) const {
|
||||
std::string image_type;
|
||||
if (storage_type == TensorStorageType::TEXTURE_2D ||
|
||||
storage_type == TensorStorageType::SINGLE_TEXTURE_2D) {
|
||||
image_type = "image2d";
|
||||
} else if (storage_type == TensorStorageType::TEXTURE_3D) {
|
||||
image_type = "image3d";
|
||||
} else if (storage_type == TensorStorageType::TEXTURE_ARRAY) {
|
||||
image_type = "image2d_array";
|
||||
}
|
||||
switch (storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::StrCat("buffer[", global_address, "] = ", var_name, ";\n");
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
case TensorStorageType::TEXTURE_3D:
|
||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
return absl::StrCat(GetWriteImageFromDataType(data_type), "(", image_type,
|
||||
", ", global_address, ", ", var_name, ");\n");
|
||||
case TensorStorageType::UNKNOWN:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorDescriptor::HasAxis(Axis axis) const {
|
||||
if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
|
||||
return true;
|
||||
|
|
|
@ -82,7 +82,12 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
|||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
absl::Status PerformWriteSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
std::string Read(const std::string& global_address) const;
|
||||
std::string Write(const std::string& var_name,
|
||||
const std::string& global_address) const;
|
||||
|
||||
bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
|
||||
std::string* xc, std::string* yc, std::string* zc,
|
||||
|
|
Loading…
Reference in New Issue