Mean converted to new style.

PiperOrigin-RevId: 317746588
Change-Id: I9bd2eb57a3a39db01636f0e12ab812d5c21bf251
This commit is contained in:
Raman Sarokin 2020-06-22 15:23:37 -07:00 committed by TensorFlower Gardener
parent 188ba91f08
commit 13deeb095c

View File

@ -27,41 +27,46 @@ namespace gpu {
namespace cl {
namespace {
std::string GetMeanKernelCode(
const OperationDef& op_def,
const std::vector<ElementwiseOperation*>& linked_operations,
const int3& work_group_size) {
TensorCodeGenerator src_tensor(
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor("dst_data", WHSPoint{"1", "1", "src_size.z"},
op_def.dst_tensors[0]);
std::string GetMeanKernelCode(const OperationDef& op_def,
const int3& work_group_size, Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
args->AddFloat("inv_multiplier_1");
args->AddFloat("inv_multiplier_2");
std::string c = GetCommonDefines(op_def.precision);
const std::string wg_x = std::to_string(work_group_size.x);
const std::string wg_y = std::to_string(work_group_size.y);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " float2 inv_multipliers \n";
c += ") {\n";
c += "$0) {\n";
c += " __local float4 accum[" +
std::to_string(work_group_size.x * work_group_size.y) + "];\n";
c += " int local_x = get_local_id(0);\n";
c += " int local_y = get_local_id(1);\n";
c += " int local_id = local_y * " + wg_x + " + local_x;\n";
c += " int S = get_global_id(2);\n";
c += " if (S >= src_size.z) return;\n";
if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
c += " int linear_id_2 = get_global_id(2);\n";
c += " int S = linear_id_2 / args.dst_tensor.Batch();\n";
c += " int B = linear_id_2 % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
c += " args.src_tensor.SetBatchRef(B);\n";
} else {
c += " int S = get_global_id(2);\n";
}
c += " if (S >= args.dst_tensor.Slices()) return;\n";
c += " accum[local_id] = (float4)(0.0f);\n";
c += " for (int s_y = local_y; s_y < src_size.y; s_y += " + wg_y + ") {\n";
c += " for (int s_x = local_x; s_x < src_size.x; s_x += " + wg_x + ") {\n";
c += " accum[local_id] += " +
src_tensor.ReadAsFloatWHS("s_x", "s_y", "S") + ";\n";
c += " for (int s_y = local_y; s_y < args.src_tensor.Height(); s_y += " +
wg_y + ") {\n";
c += " for (int s_x = local_x; s_x < args.src_tensor.Width(); s_x += " +
wg_x + ") {\n";
c += " accum[local_id] += args.src_tensor.Read<float>(s_x, s_y, S);\n";
c += " }\n";
c += " }\n";
c += " accum[local_id] *= inv_multipliers.x;\n";
c += " accum[local_id] *= args.inv_multiplier_1;\n";
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
const int total_size = work_group_size.x * work_group_size.y;
int offset = 1;
@ -81,9 +86,8 @@ std::string GetMeanKernelCode(
for (int i = 1; i < reminder; ++i) {
c += " sum += accum[" + std::to_string(offset * i) + "];\n";
}
c += " FLT4 result = TO_FLT4(sum * inv_multipliers.y);\n";
c += PostProcess(linked_operations, {"result", "0", "0", "S"});
c += " " + dst_tensor.WriteWHS("result", "0", "0", "S");
c += " FLT4 result = TO_FLT4(sum * args.inv_multiplier_2);\n";
c += " args.dst_tensor.Write(result, 0, 0, S);\n";
c += "}\n";
return c;
}
@ -107,30 +111,34 @@ absl::Status Mean::Compile(const CreationContext& creation_context) {
if (creation_context.device->IsAdreno3xx()) {
work_group_size_ = int3(16, 8, 1);
}
const auto code =
GetMeanKernelCode(definition_, linked_operations_, work_group_size_);
std::string code = GetMeanKernelCode(definition_, work_group_size_, &args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
{{"dst_tensor", element_wise_code}},
&code));
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
// Binds the runtime arguments of the Mean kernel through the named-argument
// table `args_`.
//
// The first source and destination tensors are attached under the names
// "src_tensor" and "dst_tensor". Two scalar factors are then set:
//   inv_multiplier_1 = size_0 / (Width * Height)
//   inv_multiplier_2 = 1 / size_0
// where size_0 is the number of work items in one work group
// (work_group_size_.x * work_group_size_.y). Their product is
// 1 / (Width * Height), so applying both turns a sum over the spatial plane
// into a mean.
//
// Returns the first non-OK status produced by any step, otherwise the status
// of args_.Bind().
absl::Status Mean::BindArguments() {
  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));

  // Computed in double; the values are handed to SetFloat below.
  const double total_size = src_[0]->Width() * src_[0]->Height();
  const double size_0 = work_group_size_.x * work_group_size_.y;
  const double size_1 = total_size / size_0;
  RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_1", 1.0 / size_1));
  RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_2", 1.0 / size_0));

  // Linked elementwise operations contribute their own entries to args_, so
  // they must be set before the table is bound to the kernel.
  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
  return args_.Bind(kernel_.kernel());
}
// Returns the global grid the Mean kernel is dispatched over.
//
// X and Y span exactly one work group (work_group_size_.x by
// work_group_size_.y). Z has one entry per (slice, batch) pair of the
// destination tensor: the generated kernel splits get_global_id(2) into a
// slice index and a batch index when the destination has a batch axis, so the
// batch count is folded into Z rather than X.
int3 Mean::GetGridSize() const {
  const int grid_x = work_group_size_.x;
  const int grid_y = work_group_size_.y;
  const int grid_z = dst_[0]->Slices() * dst_[0]->Batch();
  return int3(grid_x, grid_y, grid_z);
}