Softmax1x1 converted to new style.
PiperOrigin-RevId: 316782370 Change-Id: I1f7761c0520d72876f352c9f156341b349b90cbe
This commit is contained in:
parent
dac169cd7f
commit
897e3c0eca
@ -25,47 +25,45 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GetSoftmaxKernelCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor("src_data",
|
||||
WHSBPoint{"tensor_size.x", "tensor_size.y",
|
||||
"tensor_size.z", "tensor_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor("dst_data",
|
||||
WHSBPoint{"tensor_size.x", "tensor_size.y",
|
||||
"tensor_size.z", "tensor_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GetSoftmaxKernelCode(const OperationDef& op_def, Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
args->AddFloat("mask_x");
|
||||
args->AddFloat("mask_y");
|
||||
args->AddFloat("mask_z");
|
||||
args->AddFloat("mask_w");
|
||||
args->AddInt("slices_x32");
|
||||
|
||||
const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 tensor_size,\n";
|
||||
c += " int2 size,\n";
|
||||
c += " float4 mask\n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " int batch_id = get_global_id(1);\n";
|
||||
c += " if (batch_id >= tensor_size.w) return;\n";
|
||||
c += " if (batch_id >= args.dst_tensor.Batch()) return;\n";
|
||||
c += " args.dst_tensor.SetBatchRef(batch_id);\n";
|
||||
c += " args.src_tensor.SetBatchRef(batch_id);\n";
|
||||
}
|
||||
c += " float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
|
||||
"args.mask_w);\n";
|
||||
c += " int offset = 0;\n";
|
||||
c += " float sum = 0.0f;\n";
|
||||
c += " int s = 0;\n";
|
||||
c += " int tid = get_local_id(0);\n";
|
||||
c += " do {\n";
|
||||
c += " int z = offset + tid;\n";
|
||||
c += " if (z < size.x) {\n";
|
||||
c += " float4 mask_temp = z == size.x - 1 ? mask : (float4)(1.0f);\n";
|
||||
c += " float4 src = " +
|
||||
src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ";\n";
|
||||
c += " if (z < args.dst_tensor.Slices()) {\n";
|
||||
c += " float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : "
|
||||
"(float4)(1.0f);\n";
|
||||
c += " float4 src = args.src_tensor.Read<float>(0, 0, z);\n";
|
||||
c += " sum += dot(mask_temp, exp(src));\n";
|
||||
c += " offset += 32;\n";
|
||||
c += " }\n";
|
||||
c += " s++;\n";
|
||||
c += " } while (s < size.y);\n";
|
||||
c += " } while (s < args.slices_x32);\n";
|
||||
c += "\n";
|
||||
c += " __local float4 tmp[8];\n";
|
||||
c += " __local float* tmpx1 = (__local float*)tmp;\n";
|
||||
@ -89,16 +87,14 @@ std::string GetSoftmaxKernelCode(
|
||||
c += " s = 0;\n";
|
||||
c += " do {\n";
|
||||
c += " int z = offset + tid;\n";
|
||||
c += " if (z < size.x) {\n";
|
||||
c += " FLT4 res = TO_FLT4(exp(" +
|
||||
src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ")*sum);\n";
|
||||
const LinkingContext context{"res", "0", "0", "z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHSB("res", "0", "0", "z", batch_id);
|
||||
c += " if (z < args.dst_tensor.Slices()) {\n";
|
||||
c += " FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, "
|
||||
"z))*sum);\n";
|
||||
c += " args.dst_tensor.Write(res, 0, 0, z);\n";
|
||||
c += " offset += 32;\n";
|
||||
c += " }\n";
|
||||
c += " s++;\n";
|
||||
c += " } while (s < size.y);\n";
|
||||
c += " } while (s < args.slices_x32);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
@ -116,23 +112,30 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) {
|
||||
}
|
||||
|
||||
// Compiles the softmax-1x1 kernel.
//
// Generates the OpenCL source for this operation, merges any linked
// elementwise operations into args_, substitutes the "$0" argument
// placeholder emitted by GetSoftmaxKernelCode with the real kernel argument
// declarations, and builds (or fetches from cache) the CL kernel.
//
// Returns an error status if merging, code transformation, or kernel
// creation fails.
absl::Status Softmax1x1::Compile(const CreationContext& creation_context) {
  std::string code = GetSoftmaxKernelCode(definition_, &args_);
  // Fold linked elementwise operations into args_; element_wise_code is the
  // snippet spliced in at the dst_tensor write.
  std::string element_wise_code;
  RETURN_IF_ERROR(
      MergeOperations(linked_operations_, &args_, &element_wise_code));
  RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
                                          {{"dst_tensor", element_wise_code}},
                                          &code));
  return creation_context.cache->GetOrCreateCLKernel(
      code, "main_function", *creation_context.context,
      *creation_context.device, &kernel_);
}
|
||||
|
||||
// Binds per-dispatch arguments and enqueues the softmax-1x1 kernel.
//
// Sets the source/destination tensor object refs, the per-channel mask for
// the last slice, and "slices_x32" (the number of 32-slice strides each
// work-item performs — the kernel loops `while (s < slices_x32)` advancing
// its offset by 32). The work group is a fixed 32 threads in X; the Y grid
// dimension covers the batch.
//
// Returns an error status if any argument set/bind or the dispatch fails.
absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) {
  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
  // NOTE(review): GetMaskForLastPlane presumably zeroes the padded channels
  // of the final slice — it is applied as mask_temp only on the last slice
  // inside the generated kernel; confirm against its definition.
  float4 mask = GetMaskForLastPlane(src_[0]->Channels());
  RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x));
  RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y));
  RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z));
  RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w));
  RETURN_IF_ERROR(
      args_.SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32)));
  // Linked elementwise operations were merged into args_ at Compile time;
  // push their runtime argument values before binding.
  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
  RETURN_IF_ERROR(args_.Bind(kernel_.kernel()));
  return queue->DispatchImplicit(kernel_, {32, dst_[0]->Batch(), 1},
                                 {32, 1, 1});
}
|
||||
|
Loading…
Reference in New Issue
Block a user