ConverterToConvWeights converted to new style.

PiperOrigin-RevId: 317752091
Change-Id: I4e23ab1a9a943c197db7a7be25e5272114fd43c3
This commit is contained in:
Raman Sarokin 2020-06-22 15:52:35 -07:00 committed by TensorFlower Gardener
parent 752ce6670f
commit afeac170f0
3 changed files with 70 additions and 45 deletions

View File

@@ -26,51 +26,51 @@ namespace cl {
namespace {
std::string GetConverterToConvWeightsCode(
const OperationDef& op_def,
const ConvWeightsDescription& conv_weights_desc) {
TensorCodeGenerator src_tensor(
"src_data",
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const OperationDef& op_def, const ConvWeightsDescription& conv_weights_desc,
Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
args->AddFloat("mask_x");
args->AddFloat("mask_y");
args->AddFloat("mask_z");
args->AddFloat("mask_w");
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " float4 mask\n";
c += ") {\n";
c += "$0) {\n";
c += " int GROUP_SIZE = " +
std::to_string(conv_weights_desc.output_group_size) + ";\n";
c += " int O = get_global_id(0) * 4;\n";
c += " int I = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
c += " int W = Z % src_size.x;\n";
c += " int H = Z / src_size.x;\n";
c += " if (O >= src_size.w || I >= src_size.z || H >= src_size.y) return;\n";
c += " FLT4 v0 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 0") + ";\n";
c += " int W = Z % args.src_tensor.Width();\n";
c += " int H = Z / args.src_tensor.Width();\n";
c += " if (O >= args.src_tensor.Batch() || I >= args.src_tensor.Slices() || "
"H >= args.src_tensor.Height()) return;\n";
c += " FLT4 v0 = args.src_tensor.Read(W, H, I, O + 0);\n";
c += " FLT4 v1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
c += " FLT4 v2 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
c += " FLT4 v3 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
c += " if (O + 1 < src_size.w) {\n";
c += " v1 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 1") + ";\n";
c += " if (O + 1 < args.src_tensor.Batch()) {\n";
c += " v1 = args.src_tensor.Read(W, H, I, O + 1);\n";
c += " }\n";
c += " if (O + 2 < src_size.w) {\n";
c += " v2 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 2") + ";\n";
c += " if (O + 2 < args.src_tensor.Batch()) {\n";
c += " v2 = args.src_tensor.Read(W, H, I, O + 2);\n";
c += " }\n";
c += " if (O + 3 < src_size.w) {\n";
c += " v3 =" + src_tensor.ReadWHSB("W", "H", "I", "O + 3") + ";\n";
c += " if (O + 3 < args.src_tensor.Batch()) {\n";
c += " v3 = args.src_tensor.Read(W, H, I, O + 3);\n";
c += " }\n";
c += " if (I == src_size.z - 1) {\n";
c += " FLT4 mask_t = TO_FLT4(mask);\n";
c += " v0 *= mask_t;\n";
c += " v1 *= mask_t;\n";
c += " v2 *= mask_t;\n";
c += " v3 *= mask_t;\n";
c += " if (I == args.src_tensor.Slices() - 1) {\n";
c += " FLT4 mask = (FLT4)(args.mask_x, args.mask_y, args.mask_z, "
"args.mask_w);\n";
c += " v0 *= mask;\n";
c += " v1 *= mask;\n";
c += " v2 *= mask;\n";
c += " v3 *= mask;\n";
c += " }\n";
c += " FLT4 r0 = (FLT4)(v0.x, v1.x, v2.x, v3.x);\n";
c += " FLT4 r1 = (FLT4)(v0.y, v1.y, v2.y, v3.y);\n";
@@ -78,17 +78,18 @@ std::string GetConverterToConvWeightsCode(
c += " FLT4 r3 = (FLT4)(v0.w, v1.w, v2.w, v3.w);\n";
c += " int d_index = O / (GROUP_SIZE * 4);\n";
c += " int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
c += " int dst_offset = (((d_index * src_size.y + H) * src_size.x + W) * "
"src_size.z + I) * GROUP_SIZE + "
c += " int dst_offset = (((d_index * args.src_tensor.Height() + H) * "
"args.src_tensor.Width() + W) * "
"args.src_tensor.Slices() + I) * GROUP_SIZE + "
"k_index;\n";
c += " int address0 = dst_offset * 4 + 0;\n";
c += " int address1 = dst_offset * 4 + 1;\n";
c += " int address2 = dst_offset * 4 + 2;\n";
c += " int address3 = dst_offset * 4 + 3;\n";
c += " " + dst_tensor.Write("r0", "address0");
c += " " + dst_tensor.Write("r1", "address1");
c += " " + dst_tensor.Write("r2", "address2");
c += " " + dst_tensor.Write("r3", "address3");
c += " args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0)\n;";
c += " args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1)\n;";
c += " args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2)\n;";
c += " args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3)\n;";
c += "}\n";
return c;
}
@@ -115,20 +116,24 @@ ConverterToConvWeights& ConverterToConvWeights::operator=(
absl::Status ConverterToConvWeights::Compile(
const CreationContext& creation_context) {
std::string code =
GetConverterToConvWeightsCode(definition_, conv_weights_desc_);
GetConverterToConvWeightsCode(definition_, conv_weights_desc_, &args_);
RETURN_IF_ERROR(
args_.TransformToCLCode(creation_context.device->GetInfo(), {}, &code));
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
absl::Status ConverterToConvWeights::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(
kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels())));
return absl::OkStatus();
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
float4 mask = GetMaskForLastPlane(src_[0]->Channels());
RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x));
RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y));
RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z));
RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w));
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
return args_.Bind(kernel_.kernel());
}
int3 ConverterToConvWeights::GetGridSize() const {

View File

@@ -168,6 +168,8 @@ absl::Status TensorDescriptor::PerformSelector(
return PerformReadSelector(args, template_args, result);
} else if (selector == "Write") {
return PerformWriteSelector(args, result);
} else if (selector == "WriteLinear") {
return PerformWriteLinearSelector(args, result);
} else if (selector == "GetAddress") {
return PerformGetAddressSelector(args, result);
} else {
@@ -253,6 +255,21 @@ absl::Status TensorDescriptor::PerformWriteSelector(
return absl::OkStatus();
}
// Emits code for the WriteLinear(value, offset) selector: a store at a flat
// element offset. Flat addressing is only meaningful for linearly addressed
// storages, so non-BUFFER storage types are rejected up front.
absl::Status TensorDescriptor::PerformWriteLinearSelector(
    const std::vector<std::string>& args, std::string* result) const {
  const bool linear_storage =
      storage_type == TensorStorageType::BUFFER ||
      storage_type == TensorStorageType::IMAGE_BUFFER;
  if (!linear_storage) {
    return absl::InvalidArgumentError(
        "WriteLinear selector can be used only with linear "
        "storages(BUFFER/IMAGE_BUFFER)");
  }
  // Exactly two arguments: args[0] is the value expression, args[1] the
  // flat offset expression.
  if (args.size() != 2) {
    return absl::NotFoundError("Unrecognized WriteLinear selector");
  }
  // Parenthesize the offset so compound expressions (e.g. "a * 4 + 1")
  // keep their precedence inside the generated address computation.
  *result = Write(args[0], "(" + args[1] + ")");
  return absl::OkStatus();
}
std::string TensorDescriptor::Read(DataType read_as_type,
const std::string& global_address) const {
const std::string read_as =

View File

@@ -93,6 +93,9 @@ struct TensorDescriptor : public GPUObjectDescriptor {
absl::Status PerformWriteSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformWriteLinearSelector(const std::vector<std::string>& args,
std::string* result) const;
std::string Read(DataType read_as_type,
const std::string& global_address) const;
std::string Write(const std::string& var_name,