Softmax1x1 converted to new style.

PiperOrigin-RevId: 316782370
Change-Id: I1f7761c0520d72876f352c9f156341b349b90cbe
This commit (897e3c0eca, parent dac169cd7f) was authored by Raman Sarokin on 2020-06-16 16:39:31 -07:00 and committed by TensorFlower Gardener.

View File

@ -25,47 +25,45 @@ namespace gpu {
namespace cl {
namespace {
// NOTE(review): This span is unified-diff residue rendered without +/-
// markers: lines from the OLD implementation (TensorCodeGenerator /
// linked_operations style) are interleaved with their NEW replacements
// (Arguments-based style), so the text below is not compilable as-is.
// The OLD function signature and tensor-codegen setup follow; all of it
// was removed by the commit:
std::string GetSoftmaxKernelCode(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor("src_data",
WHSBPoint{"tensor_size.x", "tensor_size.y",
"tensor_size.z", "tensor_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor("dst_data",
WHSBPoint{"tensor_size.x", "tensor_size.y",
"tensor_size.z", "tensor_size.w"},
op_def.dst_tensors[0]);
// NEW signature: tensor objects and scalar uniforms are registered on
// `args` and referenced from the generated kernel source as args.<name>.
std::string GetSoftmaxKernelCode(const OperationDef& op_def, Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
// Four scalars forming the channel mask for the last (possibly partial)
// slice, plus the do/while loop bound (see "while (s < args.slices_x32)"
// below). Set at dispatch time in AddToQueue.
args->AddFloat("mask_x");
args->AddFloat("mask_y");
args->AddFloat("mask_z");
args->AddFloat("mask_w");
args->AddInt("slices_x32");
// OLD: batch-id interpolation string for the WHSB read helpers (removed).
const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
// OLD: explicit kernel parameter list (removed):
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 tensor_size,\n";
c += " int2 size,\n";
c += " float4 mask\n";
c += ") {\n";
// NEW: "$0" is substituted with generated argument declarations later
// (args_.TransformToCLCode in Compile).
c += "$0) {\n";
if (op_def.IsBatchSupported()) {
c += " int batch_id = get_global_id(1);\n";
// OLD batch bound check (removed):
c += " if (batch_id >= tensor_size.w) return;\n";
// NEW batch bound check plus per-batch tensor refs:
c += " if (batch_id >= args.dst_tensor.Batch()) return;\n";
c += " args.dst_tensor.SetBatchRef(batch_id);\n";
c += " args.src_tensor.SetBatchRef(batch_id);\n";
}
// NEW: mask reassembled in-kernel from the four scalar uniforms.
c += " float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
"args.mask_w);\n";
c += " int offset = 0;\n";
c += " float sum = 0.0f;\n";
c += " int s = 0;\n";
c += " int tid = get_local_id(0);\n";
// First pass: each of the 32 work items accumulates exp() sums over a
// strided subset of slices.
c += " do {\n";
c += " int z = offset + tid;\n";
// OLD condition/read (size.x, WHSB helper — removed):
c += " if (z < size.x) {\n";
c += " float4 mask_temp = z == size.x - 1 ? mask : (float4)(1.0f);\n";
c += " float4 src = " +
src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ";\n";
// NEW condition/read (tensor metadata via args):
c += " if (z < args.dst_tensor.Slices()) {\n";
c += " float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : "
"(float4)(1.0f);\n";
c += " float4 src = args.src_tensor.Read<float>(0, 0, z);\n";
c += " sum += dot(mask_temp, exp(src));\n";
c += " offset += 32;\n";
c += " }\n";
c += " s++;\n";
// OLD loop bound (size.y, removed) followed by NEW (args.slices_x32):
c += " } while (s < size.y);\n";
c += " } while (s < args.slices_x32);\n";
c += "\n";
c += " __local float4 tmp[8];\n";
c += " __local float* tmpx1 = (__local float*)tmp;\n";
// NOTE(review): the hunk header below hides ~16 unchanged lines of this
// function (presumably the work-group reduction of `sum` into a single
// normalizer — not visible here; confirm against the full file).
@ -89,16 +87,14 @@ std::string GetSoftmaxKernelCode(
// Second pass: normalize each slice with the reduced sum and write out.
c += " s = 0;\n";
c += " do {\n";
c += " int z = offset + tid;\n";
// OLD write path (WHSB helpers + linked-op PostProcess — removed):
c += " if (z < size.x) {\n";
c += " FLT4 res = TO_FLT4(exp(" +
src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ")*sum);\n";
const LinkingContext context{"res", "0", "0", "z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHSB("res", "0", "0", "z", batch_id);
// NEW write path (linked ops are merged into the code elsewhere):
c += " if (z < args.dst_tensor.Slices()) {\n";
c += " FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, "
"z))*sum);\n";
c += " args.dst_tensor.Write(res, 0, 0, z);\n";
c += " offset += 32;\n";
c += " }\n";
c += " s++;\n";
// OLD loop bound followed by NEW loop bound, as above:
c += " } while (s < size.y);\n";
c += " } while (s < args.slices_x32);\n";
c += "}\n";
return c;
}
@ -116,23 +112,30 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) {
}
// Generates, transforms and compiles the OpenCL kernel for this op.
// Returns an error status if any stage fails.
absl::Status Softmax1x1::Compile(const CreationContext& creation_context) {
  // Generate the kernel source; tensor/scalar arguments are registered on
  // args_ as a side effect of GetSoftmaxKernelCode.
  // NOTE(review): dropped the leftover pre-conversion line
  //   const auto code = GetSoftmaxKernelCode(definition_, linked_operations_);
  // which redeclared `code` and called the old signature.
  std::string code = GetSoftmaxKernelCode(definition_, &args_);
  std::string element_wise_code;
  // Fold any linked elementwise operations into the generated code.
  RETURN_IF_ERROR(
      MergeOperations(linked_operations_, &args_, &element_wise_code));
  // Substitute the "$0" placeholder with real argument declarations and
  // splice the elementwise code at the dst_tensor write site.
  RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
                                          {{"dst_tensor", element_wise_code}},
                                          &code));
  return creation_context.cache->GetOrCreateCLKernel(
      code, "main_function", *creation_context.context,
      *creation_context.device, &kernel_);
}
// Binds the current src/dst tensors and scalar uniforms to the compiled
// kernel and enqueues it. Returns an error status on any binding failure.
// NOTE(review): dropped the leftover pre-conversion binding lines
// (ResetBindingCounter / SetMemoryAuto / SetBytesAuto) that duplicated the
// args_-based setup below.
absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) {
  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
  // Mask zeroing the padded channels of the last slice; passed as four
  // scalars because Arguments has no float4 setter here.
  float4 mask = GetMaskForLastPlane(src_[0]->Channels());
  RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x));
  RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y));
  RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z));
  RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w));
  // Kernel loop bound: number of 32-wide strided passes over the slices.
  // Fixed the transcription bug here: the call was missing its
  // RETURN_IF_ERROR wrapper and had unbalanced parentheses.
  RETURN_IF_ERROR(
      args_.SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32)));
  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
  RETURN_IF_ERROR(args_.Bind(kernel_.kernel()));
  // One work-group of 32 threads per batch element (see get_local_id(0)
  // and the fixed stride of 32 in the generated kernel).
  return queue->DispatchImplicit(kernel_, {32, dst_[0]->Batch(), 1},
                                 {32, 1, 1});
}