ConverterToConvWeights converted to new style.
PiperOrigin-RevId: 317752091 Change-Id: I4e23ab1a9a943c197db7a7be25e5272114fd43c3
This commit is contained in:
parent
752ce6670f
commit
afeac170f0
|
@ -26,51 +26,51 @@ namespace cl {
|
|||
namespace {
|
||||
|
||||
std::string GetConverterToConvWeightsCode(
|
||||
const OperationDef& op_def,
|
||||
const ConvWeightsDescription& conv_weights_desc) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data",
|
||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data",
|
||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
const OperationDef& op_def, const ConvWeightsDescription& conv_weights_desc,
|
||||
Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
args->AddFloat("mask_x");
|
||||
args->AddFloat("mask_y");
|
||||
args->AddFloat("mask_z");
|
||||
args->AddFloat("mask_w");
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " float4 mask\n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
c += " int GROUP_SIZE = " +
|
||||
std::to_string(conv_weights_desc.output_group_size) + ";\n";
|
||||
c += " int O = get_global_id(0) * 4;\n";
|
||||
c += " int I = get_global_id(1);\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " int W = Z % src_size.x;\n";
|
||||
c += " int H = Z / src_size.x;\n";
|
||||
c += " if (O >= src_size.w || I >= src_size.z || H >= src_size.y) return;\n";
|
||||
c += " FLT4 v0 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 0") + ";\n";
|
||||
c += " int W = Z % args.src_tensor.Width();\n";
|
||||
c += " int H = Z / args.src_tensor.Width();\n";
|
||||
c += " if (O >= args.src_tensor.Batch() || I >= args.src_tensor.Slices() || "
|
||||
"H >= args.src_tensor.Height()) return;\n";
|
||||
c += " FLT4 v0 = args.src_tensor.Read(W, H, I, O + 0);\n";
|
||||
c += " FLT4 v1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
|
||||
c += " FLT4 v2 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
|
||||
c += " FLT4 v3 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
|
||||
c += " if (O + 1 < src_size.w) {\n";
|
||||
c += " v1 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 1") + ";\n";
|
||||
c += " if (O + 1 < args.src_tensor.Batch()) {\n";
|
||||
c += " v1 = args.src_tensor.Read(W, H, I, O + 1);\n";
|
||||
c += " }\n";
|
||||
c += " if (O + 2 < src_size.w) {\n";
|
||||
c += " v2 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 2") + ";\n";
|
||||
c += " if (O + 2 < args.src_tensor.Batch()) {\n";
|
||||
c += " v2 = args.src_tensor.Read(W, H, I, O + 2);\n";
|
||||
c += " }\n";
|
||||
c += " if (O + 3 < src_size.w) {\n";
|
||||
c += " v3 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 3") + ";\n";
|
||||
c += " if (O + 3 < args.src_tensor.Batch()) {\n";
|
||||
c += " v3 = args.src_tensor.Read(W, H, I, O + 3);\n";
|
||||
c += " }\n";
|
||||
c += " if (I == src_size.z - 1) {\n";
|
||||
c += " FLT4 mask_t = TO_FLT4(mask);\n";
|
||||
c += " v0 *= mask_t;\n";
|
||||
c += " v1 *= mask_t;\n";
|
||||
c += " v2 *= mask_t;\n";
|
||||
c += " v3 *= mask_t;\n";
|
||||
c += " if (I == args.src_tensor.Slices() - 1) {\n";
|
||||
c += " FLT4 mask = (FLT4)(args.mask_x, args.mask_y, args.mask_z, "
|
||||
"args.mask_w);\n";
|
||||
c += " v0 *= mask;\n";
|
||||
c += " v1 *= mask;\n";
|
||||
c += " v2 *= mask;\n";
|
||||
c += " v3 *= mask;\n";
|
||||
c += " }\n";
|
||||
c += " FLT4 r0 = (FLT4)(v0.x, v1.x, v2.x, v3.x);\n";
|
||||
c += " FLT4 r1 = (FLT4)(v0.y, v1.y, v2.y, v3.y);\n";
|
||||
|
@ -78,17 +78,18 @@ std::string GetConverterToConvWeightsCode(
|
|||
c += " FLT4 r3 = (FLT4)(v0.w, v1.w, v2.w, v3.w);\n";
|
||||
c += " int d_index = O / (GROUP_SIZE * 4);\n";
|
||||
c += " int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
|
||||
c += " int dst_offset = (((d_index * src_size.y + H) * src_size.x + W) * "
|
||||
"src_size.z + I) * GROUP_SIZE + "
|
||||
c += " int dst_offset = (((d_index * args.src_tensor.Height() + H) * "
|
||||
"args.src_tensor.Width() + W) * "
|
||||
"args.src_tensor.Slices() + I) * GROUP_SIZE + "
|
||||
"k_index;\n";
|
||||
c += " int address0 = dst_offset * 4 + 0;\n";
|
||||
c += " int address1 = dst_offset * 4 + 1;\n";
|
||||
c += " int address2 = dst_offset * 4 + 2;\n";
|
||||
c += " int address3 = dst_offset * 4 + 3;\n";
|
||||
c += " " + dst_tensor.Write("r0", "address0");
|
||||
c += " " + dst_tensor.Write("r1", "address1");
|
||||
c += " " + dst_tensor.Write("r2", "address2");
|
||||
c += " " + dst_tensor.Write("r3", "address3");
|
||||
c += " args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0)\n;";
|
||||
c += " args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1)\n;";
|
||||
c += " args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2)\n;";
|
||||
c += " args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3)\n;";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
|
@ -115,20 +116,24 @@ ConverterToConvWeights& ConverterToConvWeights::operator=(
|
|||
absl::Status ConverterToConvWeights::Compile(
|
||||
const CreationContext& creation_context) {
|
||||
std::string code =
|
||||
GetConverterToConvWeightsCode(definition_, conv_weights_desc_);
|
||||
GetConverterToConvWeightsCode(definition_, conv_weights_desc_, &args_);
|
||||
RETURN_IF_ERROR(
|
||||
args_.TransformToCLCode(creation_context.device->GetInfo(), {}, &code));
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
}
|
||||
|
||||
absl::Status ConverterToConvWeights::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(
|
||||
kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels())));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
float4 mask = GetMaskForLastPlane(src_[0]->Channels());
|
||||
RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x));
|
||||
RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y));
|
||||
RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z));
|
||||
RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 ConverterToConvWeights::GetGridSize() const {
|
||||
|
|
|
@ -168,6 +168,8 @@ absl::Status TensorDescriptor::PerformSelector(
|
|||
return PerformReadSelector(args, template_args, result);
|
||||
} else if (selector == "Write") {
|
||||
return PerformWriteSelector(args, result);
|
||||
} else if (selector == "WriteLinear") {
|
||||
return PerformWriteLinearSelector(args, result);
|
||||
} else if (selector == "GetAddress") {
|
||||
return PerformGetAddressSelector(args, result);
|
||||
} else {
|
||||
|
@ -253,6 +255,21 @@ absl::Status TensorDescriptor::PerformWriteSelector(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformWriteLinearSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
if (storage_type != TensorStorageType::BUFFER &&
|
||||
storage_type != TensorStorageType::IMAGE_BUFFER) {
|
||||
return absl::InvalidArgumentError(
|
||||
"WriteLinear selector can be used only with linear "
|
||||
"storages(BUFFER/IMAGE_BUFFER)");
|
||||
}
|
||||
if (args.size() != 2) {
|
||||
return absl::NotFoundError("Unrecognized WriteLinear selector");
|
||||
}
|
||||
*result = Write(args[0], "(" + args[1] + ")");
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::Read(DataType read_as_type,
|
||||
const std::string& global_address) const {
|
||||
const std::string read_as =
|
||||
|
|
|
@ -93,6 +93,9 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
|||
absl::Status PerformWriteSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
absl::Status PerformWriteLinearSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
std::string Read(DataType read_as_type,
|
||||
const std::string& global_address) const;
|
||||
std::string Write(const std::string& var_name,
|
||||
|
|
Loading…
Reference in New Issue