Resize converted to new style.

PiperOrigin-RevId: 316787130
Change-Id: I67db63fa6eaec2bccc87031f2e202da65a2ce439
This commit is contained in:
Raman Sarokin 2020-06-16 17:07:23 -07:00 committed by TensorFlower Gardener
parent de0d8ddc27
commit 01cfc8a8a3

View File

@ -25,168 +25,166 @@ namespace gpu {
namespace cl {
namespace {
// NOTE(review): this span is rendered diff output of a refactoring commit
// ("Resize converted to new style") — the pre-change lines (TensorCodeGenerator /
// linked_operations API, marked OLD below) and the post-change lines
// (Arguments*-based API, marked NEW) are interleaved with the +/- diff markers
// stripped, so this text is NOT compilable as-is.
//
// Purpose (both variants): build the OpenCL C source for a 2D resize kernel
// supporting NEAREST and bilinear sampling, optional half-pixel centers, and
// optionally batched tensors.
// OLD signature:
std::string GetResizeCode(
const OperationDef& op_def, SamplingType sampling_type,
bool half_pixel_centers,
const std::vector<ElementwiseOperation*>& linked_operations) {
// OLD: tensors addressed through TensorCodeGenerator bound to explicit
// src_size/dst_size kernel uniforms.
TensorCodeGenerator src_tensor(
"src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[0]);
// NEW signature: kernel arguments are registered on an Arguments object.
std::string GetResizeCode(const OperationDef& op_def,
SamplingType sampling_type, bool half_pixel_centers,
Arguments* args) {
// NEW: register tensor object refs plus scalar border/scale arguments.
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
if (op_def.IsBatchSupported()) {
src_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
if (op_def.IsBatchSupported()) {
dst_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
args->AddInt("border_x");
args->AddInt("border_y");
args->AddFloat("scale_factor_x");
args->AddFloat("scale_factor_y");
// Common to both variants: precision defines and kernel prologue.
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
// OLD: explicit kernel parameter list for tensors and scalars.
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
c += " int2 border, \n";
c += " float2 scale_factor \n";
c += ") {\n";
// NEW: "$0" placeholder is expanded later (see args_.TransformToCLCode in
// Resize::Compile) into the generated argument declarations.
c += "$0) {\n";
c += " int Y = get_global_id(1);\n";
c += " int Z = get_global_id(2);\n";
if (op_def.IsBatchSupported()) {
c += " int linear_id = get_global_id(0);\n";
// OLD: batch/extents from dst_size uniform; NEW: from args.dst_tensor methods.
c += " int X = linear_id / dst_size.w;\n";
c += " int B = linear_id % dst_size.w;\n";
c += " if (get_global_id(0) >= dst_size.x || Y >= dst_size.y || Z >= "
"dst_size.z) return;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " if (linear_id >= args.dst_tensor.Width() || Y >= "
"args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
} else {
c += " int X = get_global_id(0);\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) "
"return;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
"|| Z >= args.dst_tensor.Slices()) return;\n";
}
if (sampling_type == SamplingType::NEAREST) {
// Nearest-neighbor: single read at the scaled integer coordinate.
c += " int2 coord = (int2)(X * scale_factor.x, Y * scale_factor.y);\n";
c += " int2 coord = (int2)(X * args.scale_factor_x, Y * "
"args.scale_factor_y);\n";
if (op_def.IsBatchSupported()) {
c += " coord.x = coord.x * src_size.w + B;\n";
c += " X = X * src_size.w + B;\n";
c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
c += " X = X * args.src_tensor.Batch() + B;\n";
}
c += " FLT4 r0 = " + src_tensor.ReadWHS("coord.x", "coord.y", "Z") + ";\n";
c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
} else {
// Bilinear: four reads blended with mix(); optional half-pixel-center shift.
if (half_pixel_centers) {
// OLD first line of the expression (its "0.5f" continuation is shared):
c += " float2 f_coords = ((float2)(X, Y) + 0.5f) * scale_factor - "
c += " float2 f_coords = ((float2)(X, Y) + 0.5f) * "
"(float2)(args.scale_factor_x, args.scale_factor_y) - "
"0.5f;\n";
} else {
c += " float2 f_coords = (float2)(X, Y) * scale_factor;\n";
c += " float2 f_coords = (float2)(X, Y) * (float2)(args.scale_factor_x, "
"args.scale_factor_y);\n";
}
c += " float2 f_coords_floor = floor(f_coords);\n";
c += " int2 coords_floor = (int2)(f_coords_floor.x, f_coords_floor.y);\n";
// st.xy = top-left corner clamped to 0; st.zw = bottom-right clamped to border.
c += " int4 st;\n";
c += " st.xy = max(coords_floor, (int2)(0, 0));\n";
c += " st.zw = min(coords_floor + (int2)(1, 1), border);\n";
c += " st.zw = min(coords_floor + (int2)(1, 1), (int2)(args.border_x, "
"args.border_y));\n";
c += " float2 t = f_coords - f_coords_floor;\n";
if (op_def.IsBatchSupported()) {
c += " st.x = st.x * src_size.w + B;\n";
c += " st.z = st.z * src_size.w + B;\n";
c += " X = X * src_size.w + B;\n";
c += " st.x = st.x * args.src_tensor.Batch() + B;\n";
c += " st.z = st.z * args.src_tensor.Batch() + B;\n";
c += " X = X * args.src_tensor.Batch() + B;\n";
}
// OLD reads:
c += " float4 src0 = " + src_tensor.ReadAsFloatWHS("st.x", "st.y", "Z") +
";\n";
c += " float4 src1 = " + src_tensor.ReadAsFloatWHS("st.z", "st.y", "Z") +
";\n";
c += " float4 src2 = " + src_tensor.ReadAsFloatWHS("st.x", "st.w", "Z") +
";\n";
c += " float4 src3 = " + src_tensor.ReadAsFloatWHS("st.z", "st.w", "Z") +
";\n";
// NEW reads:
c += " float4 src0 = args.src_tensor.Read<float>(st.x, st.y, Z);\n";
c += " float4 src1 = args.src_tensor.Read<float>(st.z, st.y, Z);\n";
c += " float4 src2 = args.src_tensor.Read<float>(st.x, st.w, Z);\n";
c += " float4 src3 = args.src_tensor.Read<float>(st.z, st.w, Z);\n";
c += " FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
"t.y));\n";
}
// OLD: linked elementwise ops appended via PostProcess before the write.
// NEW: elementwise code is merged into the write by the Arguments framework
// (see MergeOperations in Resize::Compile).
const LinkingContext context{"r0", "X", "Y", "Z"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHS("r0", "X", "Y", "Z");
c += " args.dst_tensor.Write(r0, X, Y, Z);\n";
c += "}\n";
return c;
}
// NOTE(review): like GetResizeCode above, this span is diff residue — OLD
// (TensorCodeGenerator / linked_operations) and NEW (Arguments*) lines are
// interleaved without +/- markers and the text is NOT compilable as-is.
//
// Purpose (both variants): build the OpenCL C source for a 3D resize kernel
// (width/height/depth) with NEAREST or trilinear sampling and optional batch.
// OLD signature:
std::string GetResize3DCode(
const OperationDef& op_def, SamplingType sampling_type,
const std::vector<ElementwiseOperation*>& linked_operations) {
// OLD: tensors bound to explicit src_size/dst_size uniforms (WHDS layout).
TensorCodeGenerator src_tensor(
"src_data",
WHDSPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
// NEW signature:
std::string GetResize3DCode(const OperationDef& op_def,
SamplingType sampling_type, Arguments* args) {
// NEW: register tensor object refs plus per-axis border/scale scalars.
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
if (op_def.IsBatchSupported()) {
src_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]);
if (op_def.IsBatchSupported()) {
dst_desc->SetStateVar("BatchedWidth", "true");
}
args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
args->AddInt("border_x");
args->AddInt("border_y");
args->AddInt("border_z");
args->AddFloat("scale_factor_x");
args->AddFloat("scale_factor_y");
args->AddFloat("scale_factor_z");
// Common: precision defines and kernel prologue.
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
// OLD: explicit kernel parameter list (incl. separate batch_size uniform).
c += src_tensor.GetDeclaration(AccessType::READ);
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
if (op_def.IsBatchSupported()) {
c += " int batch_size, \n";
}
c += " int4 border, \n";
c += " float4 scale_factor \n";
c += ") {\n";
// NEW: "$0" placeholder expanded by args_.TransformToCLCode in Compile.
c += "$0) {\n";
c += " int Y = get_global_id(1);\n";
// Global id 2 packs slice and depth; unpack into S (slice) and Z (depth).
c += " int linear_id_z = get_global_id(2);\n";
c += " int S = linear_id_z % dst_size.w;\n";
c += " int Z = linear_id_z / dst_size.w;\n";
c += " int S = linear_id_z % args.dst_tensor.Slices();\n";
c += " int Z = linear_id_z / args.dst_tensor.Slices();\n";
if (op_def.IsBatchSupported()) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / batch_size;\n";
c += " int B = linear_id % batch_size;\n";
c += " if (linear_id >= dst_size.x || Y >= dst_size.y || Z >= "
"dst_size.z) return;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " if (linear_id >= args.dst_tensor.Width() || Y >= "
"args.dst_tensor.Height() || Z >= args.dst_tensor.Depth()) return;\n";
} else {
c += " int X = get_global_id(0);\n";
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) "
"return;\n";
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
"|| Z >= args.dst_tensor.Depth()) return;\n";
}
if (sampling_type == SamplingType::NEAREST) {
// Nearest-neighbor: single read at the scaled integer coordinate.
c += " int4 coord = (int4)(X * scale_factor.x, Y * scale_factor.y, Z * "
"scale_factor.z, 0);\n";
c += " int4 coord = (int4)(X * args.scale_factor_x, Y * "
"args.scale_factor_y, Z * "
"args.scale_factor_z, 0);\n";
if (op_def.IsBatchSupported()) {
c += " coord.x = coord.x * batch_size + B;\n";
c += " X = X * batch_size + B;\n";
c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
c += " X = X * args.src_tensor.Batch() + B;\n";
}
c += " FLT4 r0 = " +
src_tensor.ReadWHDS("coord.x", "coord.y", "coord.z", "S") + ";\n";
c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, coord.z, S);\n";
} else {
// Trilinear: eight reads blended with nested mix() calls.
c += " float4 f_coords = (float4)(X, Y, Z, 0) * scale_factor;\n";
c += " float4 f_coords;\n";
c += " f_coords.x = (float)(X) * args.scale_factor_x;\n";
c += " f_coords.y = (float)(Y) * args.scale_factor_y;\n";
c += " f_coords.z = (float)(Z) * args.scale_factor_z;\n";
c += " int4 start = (int4)(f_coords.x, f_coords.y, f_coords.z, 0);\n";
c += " int4 end = min(start + (int4)(1, 1, 1, 0), border);\n";
c += " int4 end;\n";
c += " end.x = min(start.x + 1, args.border_x);\n";
c += " end.y = min(start.y + 1, args.border_y);\n";
c += " end.z = min(start.z + 1, args.border_z);\n";
c += " float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
if (op_def.IsBatchSupported()) {
c += " start.x = start.x * batch_size + B;\n";
c += " end.x = end.x * batch_size + B;\n";
c += " X = X * batch_size + B;\n";
c += " start.x = start.x * args.src_tensor.Batch() + B;\n";
c += " end.x = end.x * args.src_tensor.Batch() + B;\n";
c += " X = X * args.src_tensor.Batch() + B;\n";
}
// OLD reads (8 corners of the sampling cube):
c += " float4 src0 = " +
src_tensor.ReadAsFloatWHDS("start.x", "start.y", "start.z", "S") +
";\n";
c += " float4 src1 = " +
src_tensor.ReadAsFloatWHDS("end.x", "start.y", "start.z", "S") + ";\n";
c += " float4 src2 = " +
src_tensor.ReadAsFloatWHDS("start.x", "end.y", "start.z", "S") + ";\n";
c += " float4 src3 = " +
src_tensor.ReadAsFloatWHDS("end.x", "end.y", "start.z", "S") + ";\n";
c += " float4 src4 = " +
src_tensor.ReadAsFloatWHDS("start.x", "start.y", "end.z", "S") + ";\n";
c += " float4 src5 = " +
src_tensor.ReadAsFloatWHDS("end.x", "start.y", "end.z", "S") + ";\n";
c += " float4 src6 = " +
src_tensor.ReadAsFloatWHDS("start.x", "end.y", "end.z", "S") + ";\n";
c += " float4 src7 = " +
src_tensor.ReadAsFloatWHDS("end.x", "end.y", "end.z", "S") + ";\n";
// NEW reads:
c += " float4 src0 = args.src_tensor.Read<float>(start.x, start.y, "
"start.z, S);\n";
c += " float4 src1 = args.src_tensor.Read<float>(end.x, start.y, start.z, "
"S);\n";
c += " float4 src2 = args.src_tensor.Read<float>(start.x, end.y, start.z, "
"S);\n";
c += " float4 src3 = args.src_tensor.Read<float>(end.x, end.y, start.z, "
"S);\n";
c += " float4 src4 = args.src_tensor.Read<float>(start.x, start.y, end.z, "
"S);\n";
c += " float4 src5 = args.src_tensor.Read<float>(end.x, start.y, end.z, "
"S);\n";
c += " float4 src6 = args.src_tensor.Read<float>(start.x, end.y, end.z, "
"S);\n";
c += " float4 src7 = args.src_tensor.Read<float>(end.x, end.y, end.z, "
"S);\n";
c +=
" float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
c +=
" float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
c += " FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
}
// OLD: linked ops via PostProcess; NEW: merged into the write (MergeOperations).
const LinkingContext context{"r0", "X", "Y", "S"};
c += PostProcess(linked_operations, context);
c += " " + dst_tensor.WriteWHDS("r0", "X", "Y", "Z", "S");
c += " args.dst_tensor.Write(r0, X, Y, Z, S);\n";
c += "}\n";
return c;
}
@ -210,27 +208,32 @@ Resize& Resize::operator=(Resize&& operation) {
}
// NOTE(review): diff residue — OLD and NEW bodies interleaved; not compilable
// as-is. Both variants generate the kernel source and compile it to kernel_.
absl::Status Resize::Compile(const CreationContext& creation_context) {
// OLD: code generated directly from linked_operations.
const auto code = GetResizeCode(definition_, attr_.type,
attr_.half_pixel_centers, linked_operations_);
// NEW: code generated against args_; linked elementwise ops are merged in,
// then the "$0" placeholder and args.* accessors are lowered to CL code.
std::string code =
GetResizeCode(definition_, attr_.type, attr_.half_pixel_centers, &args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
{{"dst_tensor", element_wise_code}},
&code));
// Common: compile (or fetch cached) program for entry point "main_function".
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
// NOTE(review): diff residue — OLD and NEW bodies interleaved; not compilable
// as-is. Both variants bind tensors, border (src extent - 1, the clamp limit
// for bilinear neighbors) and resize scale factors before dispatch.
absl::Status Resize::BindArguments() {
// OLD: positional binding via SetMemoryAuto/SetBytesAuto in kernel-param order.
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
RETURN_IF_ERROR(
kernel_.SetBytesAuto(int2(src_[0]->Width() - 1, src_[0]->Height() - 1)));
float2 scale_factor =
float2(CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_),
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_));
RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor));
return absl::OkStatus();
// NEW: named binding via args_; values match the names registered in
// GetResizeCode (border_x/y, scale_factor_x/y).
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1));
RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1));
RETURN_IF_ERROR(args_.SetFloat(
"scale_factor_x",
CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
RETURN_IF_ERROR(args_.SetFloat(
"scale_factor_y",
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
return args_.Bind(kernel_.kernel());
}
int3 Resize::GetGridSize() const {
@ -272,31 +275,35 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) {
}
// NOTE(review): diff residue — OLD and NEW bodies interleaved; not compilable
// as-is. Mirrors Resize::Compile for the 3D variant.
absl::Status Resize3D::Compile(const CreationContext& creation_context) {
// OLD: code generated directly from linked_operations.
const auto code =
GetResize3DCode(definition_, attr_.type, linked_operations_);
// NEW: code generated against args_; elementwise ops merged, then args
// lowered to CL code ("$0" placeholder expansion).
std::string code = GetResize3DCode(definition_, attr_.type, &args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
{{"dst_tensor", element_wise_code}},
&code));
// Common: compile (or fetch cached) program for entry point "main_function".
return creation_context.cache->GetOrCreateCLKernel(
code, "main_function", *creation_context.context,
*creation_context.device, &kernel_);
}
// NOTE(review): diff residue — OLD and NEW bodies interleaved; not compilable
// as-is. Both variants bind tensors, per-axis borders (src extent - 1) and
// per-axis resize scale factors for the 3D kernel.
absl::Status Resize3D::BindArguments() {
// OLD: positional binding in kernel-parameter order, incl. the separate
// batch_size uniform when batching is enabled.
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS()));
if (definition_.IsBatchSupported()) {
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Batch()));
}
RETURN_IF_ERROR(kernel_.SetBytesAuto(int4(
src_[0]->Width() - 1, src_[0]->Height() - 1, src_[0]->Depth() - 1, 0)));
float4 scale_factor = float4(
CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_),
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_),
CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_), 1.0f);
RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor));
return absl::OkStatus();
// NEW: named binding via args_; names match those registered in
// GetResize3DCode (border_x/y/z, scale_factor_x/y/z).
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1));
RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1));
RETURN_IF_ERROR(args_.SetInt("border_z", src_[0]->Depth() - 1));
RETURN_IF_ERROR(args_.SetFloat(
"scale_factor_x",
CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
RETURN_IF_ERROR(args_.SetFloat(
"scale_factor_y",
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
RETURN_IF_ERROR(args_.SetFloat(
"scale_factor_z",
CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
return args_.Bind(kernel_.kernel());
}
int3 Resize3D::GetGridSize() const {