Added Write method for TensorDescriptor.

Showed usage in Transpose kernel.

PiperOrigin-RevId: 314339905
Change-Id: Ic8aa401956f15ec952ddb880962598544fdb4815
Raman Sarokin 2020-06-02 08:58:32 -07:00 committed by TensorFlower Gardener
parent 2c59f45f8f
commit d1611593b0
3 changed files with 95 additions and 26 deletions
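In short, the destination tensor is now registered as an object reference on the Arguments block and written through the new selector, instead of going through TensorCodeGenerator. A condensed sketch of the pattern the diffs below introduce (the surrounding Arguments plumbing and coordinate setup are elided):

  // Declare the output tensor as an object reference (from the Transpose diff).
  args->AddObjectRef(
      "dst_tensor", AccessType::WRITE,
      absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
  // In the generated kernel source, write through the selector; at codegen
  // time the descriptor expands it into a buffer store or an image write.
  c += "  args.dst_tensor.Write(result, X, Y, Z);\n";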

View File

@@ -31,37 +31,34 @@ std::string GetTransposeCode(
const OperationDef& op_def, const TransposeAttributes& attr,
const std::vector<ElementwiseOperation*>& linked_operations,
Arguments* args) {
TensorCodeGenerator dst_tensor("dst_data",
WHSBPoint{"args.dst_width", "args.dst_height",
"args.dst_slices", "args.dst_batch"},
op_def.dst_tensors[0]);
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddInt("dst_width");
args->AddInt("dst_height");
args->AddInt("dst_slices");
args->AddInt("dst_batch");
args->AddInt("dst_channels");
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
std::string c = GetCommonDefines(op_def.precision);
std::string linked_args = GetArgsDeclaration(linked_operations);
if (linked_args[0] == ',') {
linked_args[0] = ' ';
}
c += "__kernel void main_function(\n";
c += dst_tensor.GetDeclaration(AccessType::WRITE);
c += GetArgsDeclaration(linked_operations);
c += linked_args;
c += "$0) {\n";
if (op_def.IsBatchSupported()) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / args.dst_batch;\n";
c += " int B = linear_id % args.dst_batch;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
} else {
c += " int X = get_global_id(0);\n";
}
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " if (X >= args.dst_width || Y >= args.dst_height || Z >= "
"args.dst_slices) { \n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
"Z >= args.dst_tensor.Slices()) { \n";
c += " return; \n";
c += " } \n";
c += " FLT temps[4];\n";
@@ -89,7 +86,7 @@ std::string GetTransposeCode(
} else {
c += " for (int i = 0; i < 4; ++i) {\n";
c += " int dst_channel = Z * 4 + i;\n";
c += " if (dst_channel < args.dst_channels) {\n";
c += " if (dst_channel < args.dst_tensor.Channels()) {\n";
const std::string bhwc[] = {"B", "Y", "X", "dst_channel"};
if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) {
c += " args.src_tensor.SetBatchRef(" + bhwc[remap[0]] + ");\n";
@@ -107,10 +104,10 @@ std::string GetTransposeCode(
}
c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
std::string x_3dcoord =
op_def.IsBatchSupported() ? "X * args.dst_batch + B" : "X";
op_def.IsBatchSupported() ? "X * args.dst_tensor.Batch() + B" : "X";
const LinkingContext context{"result", x_3dcoord, "Y", "Z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id);
c += " args.dst_tensor.Write(result, X, Y, Z);\n";
c += "}\n";
return c;
}
@@ -146,13 +143,8 @@ absl::Status Transpose::Compile(const CreationContext& creation_context) {
absl::Status Transpose::BindArguments() {
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetInt("dst_width", dst_[0]->Width()));
RETURN_IF_ERROR(args_.SetInt("dst_height", dst_[0]->Height()));
RETURN_IF_ERROR(args_.SetInt("dst_slices", dst_[0]->Slices()));
RETURN_IF_ERROR(args_.SetInt("dst_batch", dst_[0]->Batch()));
RETURN_IF_ERROR(args_.SetInt("dst_channels", dst_[0]->Channels()));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
return absl::OkStatus();
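What the emitted Write line expands to depends on the tensor's storage type. A minimal sketch calling the new TensorDescriptor::Write directly (assuming the descriptor fields are plain public members as declared in the header below; the "addr" literal stands in for the address string that GetGlobalAddressNoDeclaration* would produce in real codegen):

  TensorDescriptor desc;
  desc.data_type = DataType::FLOAT32;
  desc.layout = Layout::HWC;

  desc.storage_type = TensorStorageType::BUFFER;
  std::string buffer_write = desc.Write("result", "addr");
  // buffer_write == "buffer[addr] = result;\n"

  desc.storage_type = TensorStorageType::TEXTURE_2D;
  std::string image_write = desc.Write("result", "addr");
  // image_write == "write_imagef(image2d, addr, result);\n"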

View File

@@ -136,6 +136,17 @@ std::string GetReadImageFromDataType(DataType data_type) {
return "error";
}
}
std::string GetWriteImageFromDataType(DataType data_type) {
if (data_type == DataType::FLOAT32) {
return "write_imagef";
} else if (data_type == DataType::FLOAT16) {
return "write_imageh";
} else {
return "error";
}
}
} // namespace
std::string ToString(TensorStorageType type) {
@@ -245,9 +256,11 @@ absl::Status TensorDescriptor::PerformSelector(
return absl::OkStatus();
} else if (selector == "Read") {
return PerformReadSelector(args, result);
} else if (selector == "Write") {
return PerformWriteSelector(args, result);
} else {
return absl::NotFoundError(absl::StrCat(
"TensorLinearDescriptor don't have selector with name - ", selector));
"TensorDescriptor don't have selector with name - ", selector));
}
}
@@ -283,6 +296,39 @@ absl::Status TensorDescriptor::PerformReadSelector(
}
}
absl::Status TensorDescriptor::PerformWriteSelector(
const std::vector<std::string>& args, std::string* result) const {
std::string xc;
std::string yc;
std::string zc;
std::string sc;
std::string bc;
bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
if (args.size() < 2 || !parsed) {
return absl::NotFoundError("Unrecognized Write selector");
}
if (layout == Layout::HWC) {
*result = Write(args[0],
GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc,
storage_type));
return absl::OkStatus();
} else if (layout == Layout::HWDC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc,
storage_type));
return absl::OkStatus();
} else if (layout == Layout::BHWDC) {
*result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB(
xc, yc, zc, sc, bc, storage_type));
return absl::OkStatus();
} else {
return absl::NotFoundError("Unsupported layout");
}
}
std::string TensorDescriptor::Read(const std::string& global_address) const {
std::string image_type;
if (storage_type == TensorStorageType::TEXTURE_2D ||
@@ -310,6 +356,32 @@ std::string TensorDescriptor::Read(const std::string& global_address) const {
}
}
std::string TensorDescriptor::Write(const std::string& var_name,
const std::string& global_address) const {
std::string image_type;
if (storage_type == TensorStorageType::TEXTURE_2D ||
storage_type == TensorStorageType::SINGLE_TEXTURE_2D) {
image_type = "image2d";
} else if (storage_type == TensorStorageType::TEXTURE_3D) {
image_type = "image3d";
} else if (storage_type == TensorStorageType::TEXTURE_ARRAY) {
image_type = "image2d_array";
}
switch (storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
return absl::StrCat("buffer[", global_address, "] = ", var_name, ";\n");
case TensorStorageType::TEXTURE_2D:
case TensorStorageType::TEXTURE_3D:
case TensorStorageType::SINGLE_TEXTURE_2D:
case TensorStorageType::TEXTURE_ARRAY:
return absl::StrCat(GetWriteImageFromDataType(data_type), "(", image_type,
", ", global_address, ", ", var_name, ");\n");
case TensorStorageType::UNKNOWN:
return "";
}
}
bool TensorDescriptor::HasAxis(Axis axis) const {
if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) {
return true;
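The selector path ties the pieces together: the argument resolver hands "Write" and its argument list to PerformSelector, which defers to PerformWriteSelector; that parses the coordinates starting at offset 1 (args[0] is the value being written), builds the layout-appropriate address, and calls Write. A rough usage sketch for an HWC-layout tensor, mirroring the Transpose kernel above (the exact PerformSelector signature is assumed from the dispatch code in this diff):

  TensorDescriptor desc;  // configured with data_type, storage_type, Layout::HWC
  std::vector<std::string> write_args = {"result", "X", "Y", "Z"};
  std::string code;
  absl::Status status = desc.PerformSelector("Write", write_args, &code);
  // For Layout::HWC this goes through GetGlobalAddressNoDeclarationWHS and
  // then Write("result", <address>), yielding a buffer store or an image
  // write as sketched earlier.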

View File

@@ -82,7 +82,12 @@ struct TensorDescriptor : public GPUObjectDescriptor {
absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformWriteSelector(const std::vector<std::string>& args,
std::string* result) const;
std::string Read(const std::string& global_address) const;
std::string Write(const std::string& var_name,
const std::string& global_address) const;
bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
std::string* xc, std::string* yc, std::string* zc,