Mean converted to new style.
PiperOrigin-RevId: 317746588 Change-Id: I9bd2eb57a3a39db01636f0e12ab812d5c21bf251
This commit is contained in:
parent
188ba91f08
commit
13deeb095c
@ -27,41 +27,46 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GetMeanKernelCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations,
|
||||
const int3& work_group_size) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor("dst_data", WHSPoint{"1", "1", "src_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GetMeanKernelCode(const OperationDef& op_def,
|
||||
const int3& work_group_size, Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
args->AddFloat("inv_multiplier_1");
|
||||
args->AddFloat("inv_multiplier_2");
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
const std::string wg_x = std::to_string(work_group_size.x);
|
||||
const std::string wg_y = std::to_string(work_group_size.y);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ);
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " float2 inv_multipliers \n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
c += " __local float4 accum[" +
|
||||
std::to_string(work_group_size.x * work_group_size.y) + "];\n";
|
||||
c += " int local_x = get_local_id(0);\n";
|
||||
c += " int local_y = get_local_id(1);\n";
|
||||
c += " int local_id = local_y * " + wg_x + " + local_x;\n";
|
||||
c += " int S = get_global_id(2);\n";
|
||||
c += " if (S >= src_size.z) return;\n";
|
||||
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
|
||||
c += " int linear_id_2 = get_global_id(2);\n";
|
||||
c += " int S = linear_id_2 / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id_2 % args.dst_tensor.Batch();\n";
|
||||
c += " args.dst_tensor.SetBatchRef(B);\n";
|
||||
c += " args.src_tensor.SetBatchRef(B);\n";
|
||||
} else {
|
||||
c += " int S = get_global_id(2);\n";
|
||||
}
|
||||
c += " if (S >= args.dst_tensor.Slices()) return;\n";
|
||||
c += " accum[local_id] = (float4)(0.0f);\n";
|
||||
c += " for (int s_y = local_y; s_y < src_size.y; s_y += " + wg_y + ") {\n";
|
||||
c += " for (int s_x = local_x; s_x < src_size.x; s_x += " + wg_x + ") {\n";
|
||||
c += " accum[local_id] += " +
|
||||
src_tensor.ReadAsFloatWHS("s_x", "s_y", "S") + ";\n";
|
||||
c += " for (int s_y = local_y; s_y < args.src_tensor.Height(); s_y += " +
|
||||
wg_y + ") {\n";
|
||||
c += " for (int s_x = local_x; s_x < args.src_tensor.Width(); s_x += " +
|
||||
wg_x + ") {\n";
|
||||
c += " accum[local_id] += args.src_tensor.Read<float>(s_x, s_y, S);\n";
|
||||
c += " }\n";
|
||||
c += " }\n";
|
||||
c += " accum[local_id] *= inv_multipliers.x;\n";
|
||||
c += " accum[local_id] *= args.inv_multiplier_1;\n";
|
||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||
const int total_size = work_group_size.x * work_group_size.y;
|
||||
int offset = 1;
|
||||
@ -81,9 +86,8 @@ std::string GetMeanKernelCode(
|
||||
for (int i = 1; i < reminder; ++i) {
|
||||
c += " sum += accum[" + std::to_string(offset * i) + "];\n";
|
||||
}
|
||||
c += " FLT4 result = TO_FLT4(sum * inv_multipliers.y);\n";
|
||||
c += PostProcess(linked_operations, {"result", "0", "0", "S"});
|
||||
c += " " + dst_tensor.WriteWHS("result", "0", "0", "S");
|
||||
c += " FLT4 result = TO_FLT4(sum * args.inv_multiplier_2);\n";
|
||||
c += " args.dst_tensor.Write(result, 0, 0, S);\n";
|
||||
c += "}\n";
|
||||
return c;
|
||||
}
|
||||
@ -107,30 +111,34 @@ absl::Status Mean::Compile(const CreationContext& creation_context) {
|
||||
if (creation_context.device->IsAdreno3xx()) {
|
||||
work_group_size_ = int3(16, 8, 1);
|
||||
}
|
||||
const auto code =
|
||||
GetMeanKernelCode(definition_, linked_operations_, work_group_size_);
|
||||
std::string code = GetMeanKernelCode(definition_, work_group_size_, &args_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
}
|
||||
|
||||
absl::Status Mean::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
const double total_size = src_[0]->Width() * src_[0]->Height();
|
||||
const double size_0 = work_group_size_.x * work_group_size_.y;
|
||||
const double size_1 = total_size / size_0;
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(float2(1.0 / size_1, 1.0 / size_0)));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_1", 1.0 / size_1));
|
||||
RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_2", 1.0 / size_0));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 Mean::GetGridSize() const {
|
||||
const int grid_x = work_group_size_.x * dst_[0]->Batch();
|
||||
const int grid_x = work_group_size_.x;
|
||||
const int grid_y = work_group_size_.y;
|
||||
const int grid_z = dst_[0]->Slices();
|
||||
const int grid_z = dst_[0]->Slices() * dst_[0]->Batch();
|
||||
return int3(grid_x, grid_y, grid_z);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user