// Patch overview. These comment lines sit before the first "diff --git" header;
// every hunk below is unchanged. NOTE(review): this assumes the patch tool skips
// leading text before the first header - confirm for the tool in use.
//
// Purpose: switch the Transpose kernel's destination tensor from a hand-built
// TensorCodeGenerator plus five int arguments to an object ref named "dst_tensor",
// and add a "Write" selector to TensorDescriptor so generated kernel text can say
// args.dst_tensor.Write(value, coords...).
//
// kernels/transpose.cc
//   * GetTransposeCode(): removes the TensorCodeGenerator and the AddInt() calls for
//     dst_width / dst_height / dst_slices / dst_batch / dst_channels; registers
//     "dst_tensor" with AccessType::WRITE. The emitted kernel text now reads
//     args.dst_tensor.Width()/Height()/Slices()/Batch()/Channels(), calls
//     args.dst_tensor.SetBatchRef(B) on the batched path, and ends with
//     args.dst_tensor.Write(result, X, Y, Z).
//   * The dst_tensor.GetDeclaration() line is removed, and a leading ',' in the
//     GetArgsDeclaration() result is replaced by ' ' (presumably because nothing
//     precedes the linked arguments in the parameter list any more - confirm).
//     NOTE(review): linked_args[0] is read without checking linked_args.empty().
//   * NOTE(review): batch_id is still declared in a context line, but its only use
//     visible in this patch (the WriteWHSB call) is removed - check whether it is
//     still referenced in the parts of the function not shown here.
//   * Transpose::BindArguments(): the five SetInt() calls and the explicit
//     kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()) are replaced by
//     args_.SetObjectRef("dst_tensor", dst_[0]).
//
// tensor_type.cc
//   * GetWriteImageFromDataType(): returns "write_imagef" for FLOAT32,
//     "write_imageh" for FLOAT16 and the literal "error" for anything else.
//   * PerformSelector(): dispatches selector "Write" to PerformWriteSelector() and
//     corrects the class name in the not-found message (TensorLinearDescriptor ->
//     TensorDescriptor). NOTE(review): the message still reads "don't have".
//   * PerformWriteSelector(): args[0] is the value expression; coordinates are
//     parsed from args starting at offset 1. Returns NotFoundError when fewer than
//     two args are given or parsing fails, and for layouts other than
//     HWC / BHWC / HWDC / BHWDC.
//     NOTE(review): ParseCoordsFromArgs() runs before the args.size() < 2 check -
//     confirm it tolerates short argument lists.
//   * Write(): emits "buffer[<addr>] = <var>;\n" for BUFFER / IMAGE_BUFFER, a
//     write_image call on image2d / image3d / image2d_array for the texture storage
//     types, and an empty string for UNKNOWN.
//     NOTE(review): there is no return statement after the switch.
//
// tensor_type.h
//   * Declares PerformWriteSelector() and Write() on TensorDescriptor.
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc index 4c46a88d3c4..6f1d49a1494 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc @@ -31,37 +31,34 @@ std::string GetTransposeCode( const OperationDef& op_def, const TransposeAttributes& attr, const std::vector& linked_operations, Arguments* args) { - TensorCodeGenerator dst_tensor("dst_data", - WHSBPoint{"args.dst_width", "args.dst_height", - "args.dst_slices", "args.dst_batch"}, - op_def.dst_tensors[0]); - args->AddObjectRef( "src_tensor", AccessType::READ, absl::make_unique(op_def.src_tensors[0])); - args->AddInt("dst_width"); - args->AddInt("dst_height"); - args->AddInt("dst_slices"); - args->AddInt("dst_batch"); - args->AddInt("dst_channels"); + args->AddObjectRef( + "dst_tensor", AccessType::WRITE, + absl::make_unique(op_def.dst_tensors[0])); const std::string batch_id = op_def.IsBatchSupported() ? 
"B" : ""; std::string c = GetCommonDefines(op_def.precision); + std::string linked_args = GetArgsDeclaration(linked_operations); + if (linked_args[0] == ',') { + linked_args[0] = ' '; + } c += "__kernel void main_function(\n"; - c += dst_tensor.GetDeclaration(AccessType::WRITE); - c += GetArgsDeclaration(linked_operations); + c += linked_args; c += "$0) {\n"; if (op_def.IsBatchSupported()) { c += " int linear_id = get_global_id(0);\n"; - c += " int X = linear_id / args.dst_batch;\n"; - c += " int B = linear_id % args.dst_batch;\n"; + c += " int X = linear_id / args.dst_tensor.Batch();\n"; + c += " int B = linear_id % args.dst_tensor.Batch();\n"; + c += " args.dst_tensor.SetBatchRef(B);\n"; } else { c += " int X = get_global_id(0);\n"; } c += " int Y = get_global_id(1);\n"; c += " int Z = get_global_id(2);\n"; - c += " if (X >= args.dst_width || Y >= args.dst_height || Z >= " - "args.dst_slices) { \n"; + c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " + "Z >= args.dst_tensor.Slices()) { \n"; c += " return; \n"; c += " } \n"; c += " FLT temps[4];\n"; @@ -89,7 +86,7 @@ std::string GetTransposeCode( } else { c += " for (int i = 0; i < 4; ++i) {\n"; c += " int dst_channel = Z * 4 + i;\n"; - c += " if (dst_channel < args.dst_channels) {\n"; + c += " if (dst_channel < args.dst_tensor.Channels()) {\n"; const std::string bhwc[] = {"B", "Y", "X", "dst_channel"}; if (op_def.src_tensors[0].HasAxis(Axis::BATCH)) { c += " args.src_tensor.SetBatchRef(" + bhwc[remap[0]] + ");\n"; @@ -107,10 +104,10 @@ std::string GetTransposeCode( } c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n"; std::string x_3dcoord = - op_def.IsBatchSupported() ? "X * args.dst_batch + B" : "X"; + op_def.IsBatchSupported() ? 
"X * args.dst_tensor.Batch() + B" : "X"; const LinkingContext context{"result", x_3dcoord, "Y", "Z"}; c += PostProcess(linked_operations, context); - c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id); + c += " args.dst_tensor.Write(result, X, Y, Z);\n"; c += "}\n"; return c; } @@ -146,13 +143,8 @@ absl::Status Transpose::Compile(const CreationContext& creation_context) { absl::Status Transpose::BindArguments() { RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0])); - RETURN_IF_ERROR(args_.SetInt("dst_width", dst_[0]->Width())); - RETURN_IF_ERROR(args_.SetInt("dst_height", dst_[0]->Height())); - RETURN_IF_ERROR(args_.SetInt("dst_slices", dst_[0]->Slices())); - RETURN_IF_ERROR(args_.SetInt("dst_batch", dst_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("dst_channels", dst_[0]->Channels())); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0])); kernel_.ResetBindingCounter(); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter())); return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc index 47caa7fa123..11e1ca2ca07 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_type.cc +++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.cc @@ -136,6 +136,17 @@ std::string GetReadImageFromDataType(DataType data_type) { return "error"; } } + +std::string GetWriteImageFromDataType(DataType data_type) { + if (data_type == DataType::FLOAT32) { + return "write_imagef"; + } else if (data_type == DataType::FLOAT16) { + return "write_imageh"; + } else { + return "error"; + } +} + } // namespace std::string ToString(TensorStorageType type) { @@ -245,9 +256,11 @@ absl::Status TensorDescriptor::PerformSelector( return absl::OkStatus(); } else if (selector == "Read") { return PerformReadSelector(args, result); + } else if (selector 
== "Write") { + return PerformWriteSelector(args, result); } else { return absl::NotFoundError(absl::StrCat( - "TensorLinearDescriptor don't have selector with name - ", selector)); + "TensorDescriptor don't have selector with name - ", selector)); } } @@ -283,6 +296,39 @@ absl::Status TensorDescriptor::PerformReadSelector( } } +absl::Status TensorDescriptor::PerformWriteSelector( + const std::vector& args, std::string* result) const { + std::string xc; + std::string yc; + std::string zc; + std::string sc; + std::string bc; + bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc); + if (args.size() < 2 || !parsed) { + return absl::NotFoundError("Unrecognized Write selector"); + } + + if (layout == Layout::HWC) { + *result = Write(args[0], + GetGlobalAddressNoDeclarationWHS(xc, yc, sc, storage_type)); + return absl::OkStatus(); + } else if (layout == Layout::BHWC) { + *result = Write(args[0], GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc, + storage_type)); + return absl::OkStatus(); + } else if (layout == Layout::HWDC) { + *result = Write(args[0], GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc, + storage_type)); + return absl::OkStatus(); + } else if (layout == Layout::BHWDC) { + *result = Write(args[0], GetGlobalAddressNoDeclarationWHDSB( + xc, yc, zc, sc, bc, storage_type)); + return absl::OkStatus(); + } else { + return absl::NotFoundError("Unsupported layout"); + } +} + std::string TensorDescriptor::Read(const std::string& global_address) const { std::string image_type; if (storage_type == TensorStorageType::TEXTURE_2D || @@ -310,6 +356,32 @@ std::string TensorDescriptor::Read(const std::string& global_address) const { } } +std::string TensorDescriptor::Write(const std::string& var_name, + const std::string& global_address) const { + std::string image_type; + if (storage_type == TensorStorageType::TEXTURE_2D || + storage_type == TensorStorageType::SINGLE_TEXTURE_2D) { + image_type = "image2d"; + } else if (storage_type == 
TensorStorageType::TEXTURE_3D) { + image_type = "image3d"; + } else if (storage_type == TensorStorageType::TEXTURE_ARRAY) { + image_type = "image2d_array"; + } + switch (storage_type) { + case TensorStorageType::BUFFER: + case TensorStorageType::IMAGE_BUFFER: + return absl::StrCat("buffer[", global_address, "] = ", var_name, ";\n"); + case TensorStorageType::TEXTURE_2D: + case TensorStorageType::TEXTURE_3D: + case TensorStorageType::SINGLE_TEXTURE_2D: + case TensorStorageType::TEXTURE_ARRAY: + return absl::StrCat(GetWriteImageFromDataType(data_type), "(", image_type, + ", ", global_address, ", ", var_name, ");\n"); + case TensorStorageType::UNKNOWN: + return ""; + } +} + bool TensorDescriptor::HasAxis(Axis axis) const { if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS) { return true; diff --git a/tensorflow/lite/delegates/gpu/cl/tensor_type.h b/tensorflow/lite/delegates/gpu/cl/tensor_type.h index 42f4f9b98e5..7d5ff888a85 100644 --- a/tensorflow/lite/delegates/gpu/cl/tensor_type.h +++ b/tensorflow/lite/delegates/gpu/cl/tensor_type.h @@ -82,7 +82,12 @@ struct TensorDescriptor : public GPUObjectDescriptor { absl::Status PerformReadSelector(const std::vector& args, std::string* result) const; + absl::Status PerformWriteSelector(const std::vector& args, + std::string* result) const; + std::string Read(const std::string& global_address) const; + std::string Write(const std::string& var_name, + const std::string& global_address) const; bool ParseCoordsFromArgs(const std::vector& args, int offset, std::string* xc, std::string* yc, std::string* zc,