Cleaning, removed old style linking variables.

PiperOrigin-RevId: 351392707
Change-Id: Ibe28063fdddda49502dea9d27afdd565374ae9b8
This commit is contained in:
Raman Sarokin 2021-01-12 10:00:15 -08:00 committed by TensorFlower Gardener
parent 61a7ab7203
commit 242e0a5815
16 changed files with 6 additions and 93 deletions

View File

@ -45,7 +45,6 @@ kernel void ComputeFunction($1
return; return;
} }
FLT4 value = args.src_tensor.Read(X, Y, Z); FLT4 value = args.src_tensor.Read(X, Y, Z);
$2
args.dst_tensor.Write(value, X, Y, Z); args.dst_tensor.Write(value, X, Y, Z);
} }
)"; )";
@ -148,7 +147,6 @@ void ComputeTaskDescriptor::AssembleCode() {
return std::make_pair(groups_size, groups_count); return std::make_pair(groups_size, groups_count);
}; };
} }
shader_source = absl::Substitute(shader_source, "$0", "$1", "");
} }
} // namespace metal } // namespace metal

View File

@ -84,23 +84,14 @@ std::string GetConcatChannelsCode(const OperationDef& op_def,
c += " for (int i = 0; i < " + t_name + ".Slices(); i += 2) {\n"; c += " for (int i = 0; i < " + t_name + ".Slices(); i += 2) {\n";
c += " FLT4 result0 = " + t_name + ".Read(" + coords + ", i);\n"; c += " FLT4 result0 = " + t_name + ".Read(" + coords + ", i);\n";
c += " FLT4 result1 = " + t_name + ".Read(" + coords + ", i + 1);\n"; c += " FLT4 result1 = " + t_name + ".Read(" + coords + ", i + 1);\n";
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(S));\n"; c += " args.dst_tensor.Write(result0, " + coords + ", S);\n";
c += " $2\n"; c += " args.dst_tensor.Write(result1, " + coords + ", S + 1);\n";
c += " FLT4 value = result0;\n";
c += " args.dst_tensor.Write(value, " + coords + ", S);\n";
c += " gid = uint3(ugid.x, ugid.y, uint(S + 1));\n";
c += " $2\n";
c += " value = result1;\n";
c += " args.dst_tensor.Write(value, " + coords + ", S + 1);\n";
c += " S += 2;\n"; c += " S += 2;\n";
c += " }\n"; c += " }\n";
} else { } else {
c += " for (int i = 0; i < " + t_name + ".Slices(); ++i) {\n"; c += " for (int i = 0; i < " + t_name + ".Slices(); ++i) {\n";
c += " FLT4 result = " + t_name + ".Read(" + coords + ", i);\n"; c += " FLT4 result = " + t_name + ".Read(" + coords + ", i);\n";
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(S));\n"; c += " args.dst_tensor.Write(result, " + coords + ", S);\n";
c += " $2\n";
c += " FLT4 value = result;\n";
c += " args.dst_tensor.Write(value, " + coords + ", S);\n";
c += " S++;\n"; c += " S++;\n";
c += " }\n"; c += " }\n";
} }
@ -125,11 +116,6 @@ std::string GetConcatChannelsCode(const OperationDef& op_def,
out_channel++; out_channel++;
if (out_channel == 4) { if (out_channel == 4) {
out_channel = 0; out_channel = 0;
c += " {\n";
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(" +
std::to_string(z) + "));\n";
c += " $2\n";
c += " }\n";
c += " args.dst_tensor.Write(value, " + coords + ", " + c += " args.dst_tensor.Write(value, " + coords + ", " +
std::to_string(z) + ");\n"; std::to_string(z) + ");\n";
z++; z++;
@ -139,11 +125,6 @@ std::string GetConcatChannelsCode(const OperationDef& op_def,
} }
} }
if (out_channel != 0) { if (out_channel != 0) {
c += " {\n";
c += " uint3 gid = uint3(ugid.x, ugid.y, uint(" + std::to_string(z) +
"));\n";
c += " $2\n";
c += " }\n";
c += " args.dst_tensor.Write(value, " + coords + ", " + c += " args.dst_tensor.Write(value, " + coords + ", " +
std::to_string(z) + ");\n"; std::to_string(z) + ");\n";
} }
@ -273,10 +254,6 @@ std::string GetConcatKernelCode(const OperationDef& op_def,
c += " } \n"; c += " } \n";
c += " coord -= " + field + ";\n"; c += " coord -= " + field + ";\n";
} }
c += " {\n";
c += " uint3 gid = uint3(ugid.x, ugid.y, ugid.z);\n";
c += " $2\n";
c += " }\n";
c += " args.dst_tensor.Write(value, " + dst_coord + ");\n"; c += " args.dst_tensor.Write(value, " + dst_coord + ");\n";
c += "}\n"; c += "}\n";
return c; return c;

View File

@ -549,9 +549,6 @@ kernel void ComputeFunction(
c += " FLT4 value = FLT4(r" + s_zyx + ");\n"; c += " FLT4 value = FLT4(r" + s_zyx + ");\n";
c += " int linear_index = offset_" + s_yx + c += " int linear_index = offset_" + s_yx +
" + args.dst_tensor.SliceStride() * " + s_z + ";\n"; " + args.dst_tensor.SliceStride() * " + s_z + ";\n";
c += " uint3 gid = uint3(X + " + s_x + ", Y + " + s_y + ", Z + " +
s_z + ");\n";
c += " $2\n";
c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " + c += " args.dst_tensor.Linking(value, X + " + s_x + ", Y + " +
s_y + ", Z + " + s_z + ");\n"; s_y + ", Z + " + s_z + ");\n";
c += " args.dst_tensor.WriteLinear(value, linear_index);\n"; c += " args.dst_tensor.WriteLinear(value, linear_index);\n";

View File

@ -160,26 +160,18 @@ kernel void ComputeFunction(
if (y0_in && x0_in) { if (y0_in && x0_in) {
FLT4 value = FLT4(r0); FLT4 value = FLT4(r0);
uint3 gid = uint3(gid_x, gid_y, gid_z);
$2
args.dst_tensor.Write(value, gid_x, gid_y, gid_z); args.dst_tensor.Write(value, gid_x, gid_y, gid_z);
} }
if (y1_in && x0_in) { if (y1_in && x0_in) {
FLT4 value = FLT4(l0); FLT4 value = FLT4(l0);
uint3 gid = uint3(gid_x, gid_y + 1, gid_z);
$2
args.dst_tensor.Write(value, gid_x, gid_y + 1, gid_z); args.dst_tensor.Write(value, gid_x, gid_y + 1, gid_z);
} }
if (y0_in && x1_in) { if (y0_in && x1_in) {
FLT4 value = FLT4(t0); FLT4 value = FLT4(t0);
uint3 gid = uint3(gid_x + 1, gid_y, gid_z);
$2
args.dst_tensor.Write(value, gid_x + 1, gid_y, gid_z); args.dst_tensor.Write(value, gid_x + 1, gid_y, gid_z);
} }
if (y1_in && x1_in) { if (y1_in && x1_in) {
FLT4 value = FLT4(b0); FLT4 value = FLT4(b0);
uint3 gid = uint3(gid_x + 1, gid_y + 1, gid_z);
$2
args.dst_tensor.Write(value, gid_x + 1, gid_y + 1, gid_z); args.dst_tensor.Write(value, gid_x + 1, gid_y + 1, gid_z);
} }
} }
@ -330,14 +322,10 @@ kernel void ComputeFunction(
if (y0_in) { if (y0_in) {
FLT4 value = FLT4(r0); FLT4 value = FLT4(r0);
uint3 gid = uint3(gid_x, gid_y, gid_z);
$2
args.dst_tensor.Write(value, gid_x, gid_y, gid_z); args.dst_tensor.Write(value, gid_x, gid_y, gid_z);
} }
if (y1_in) { if (y1_in) {
FLT4 value = FLT4(l0); FLT4 value = FLT4(l0);
uint3 gid = uint3(gid_x, gid_y + 1, gid_z);
$2
args.dst_tensor.Write(value, gid_x, gid_y + 1, gid_z); args.dst_tensor.Write(value, gid_x, gid_y + 1, gid_z);
} }
} }
@ -454,9 +442,7 @@ kernel void ComputeFunction(
} }
} }
FLT4 res = FLT4(sum0) + args.biases.Read(dst_z); FLT4 res = FLT4(sum0) + args.biases.Read(dst_z);
FLT4 value = res; args.dst_tensor.Write(res, dst_x, dst_y, dst_z);
$2
args.dst_tensor.Write(value, dst_x, dst_y, dst_z);
} }
)"; )";
ComputeTaskDescriptor desc(definition); ComputeTaskDescriptor desc(definition);

View File

@ -104,8 +104,6 @@ std::string GetFullyConnectedCode(const GpuInfo& gpu_info, int src_channels,
if (tid.y == 0 && tid.x % 4 == 0 && dst_s < args.dst_tensor.Slices()) { if (tid.y == 0 && tid.x % 4 == 0 && dst_s < args.dst_tensor.Slices()) {
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) + FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
args.bias.Read(dst_s); args.bias.Read(dst_s);
uint3 gid = uint3(0u, 0u, uint(dst_s));
$$2
args.dst_tensor.Write(value, 0, 0, dst_s); args.dst_tensor.Write(value, 0, 0, dst_s);
} }
} }

View File

@ -67,7 +67,6 @@ kernel void ComputeFunction(
value.z = t_index == indexes.z ? src_color.z : 0.0; value.z = t_index == indexes.z ? src_color.z : 0.0;
value.w = t_index == indexes.w ? src_color.w : 0.0; value.w = t_index == indexes.w ? src_color.w : 0.0;
$2
args.dst_tensor.Write(value, X, Y, gid.z); args.dst_tensor.Write(value, X, Y, gid.z);
} }
)"; )";

View File

@ -86,7 +86,6 @@ std::string GetMeanCode(const int3& work_group_size) {
} }
c += " FLT4 value = FLT4(sum * args.inv_multiplier_y);\n"; c += " FLT4 value = FLT4(sum * args.inv_multiplier_y);\n";
c += R"( c += R"(
$2
args.dst_tensor.Write(value, 0, 0, gid.z); args.dst_tensor.Write(value, 0, 0, gid.z);
} }
)"; )";

View File

@ -119,7 +119,6 @@ std::string GetPaddingCode(const PadAttributes& attr) {
} }
code += " }\n"; code += " }\n";
} }
code += " $2\n";
code += " args.dst_tensor.Write(value, gid.x, gid.y, gid.z);\n"; code += " args.dst_tensor.Write(value, gid.x, gid.y, gid.z);\n";
code += "}\n"; code += "}\n";
return code; return code;

View File

@ -59,9 +59,7 @@ kernel void ComputeFunction(
maximum = max(maximum, src_color); maximum = max(maximum, src_color);
} }
} }
FLT4 value = maximum; args.dst_tensor.Write(maximum, gid.x, gid.y, gid.z);
$2
args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
} }
)"; )";
return shader_source; return shader_source;
@ -110,9 +108,7 @@ kernel void ComputeFunction(
index_counter++; index_counter++;
} }
} }
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
FLT4 value = static_cast<FLT4>(indexes); FLT4 value = static_cast<FLT4>(indexes);
$2
args.dst_tensor.Write(value, gid.x, gid.y, gid.z); args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
} }
)"; )";
@ -147,11 +143,9 @@ kernel void ComputeFunction(
sum += src_color; sum += src_color;
} }
} }
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
// If window_size==0, window covered nothing. This situation is a sign of // If window_size==0, window covered nothing. This situation is a sign of
// incorrectly constructed operation. NaNs are expected as output. // incorrectly constructed operation. NaNs are expected as output.
FLT4 value = FLT4(sum / window_size); FLT4 value = FLT4(sum / window_size);
$2
args.dst_tensor.Write(value, gid.x, gid.y, gid.z); args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
} }
)"; )";

View File

@ -62,8 +62,6 @@ kernel void ComputeFunction(
value[i] = args.src_tensor.Read(src_x, src_y, src_layer)[src_channel]; value[i] = args.src_tensor.Read(src_x, src_y, src_layer)[src_channel];
} }
} }
$2
args.dst_tensor.Write(value, igid.x, igid.y, igid.z); args.dst_tensor.Write(value, igid.x, igid.y, igid.z);
})"; })";
return code; return code;
@ -91,7 +89,6 @@ kernel void ComputeFunction(
int src_z = t0 - src_x * args.src_tensor.Slices(); // t0 % args.src_tensor.Slices(); int src_z = t0 - src_x * args.src_tensor.Slices(); // t0 % args.src_tensor.Slices();
FLT4 value = args.src_tensor.Read(src_x, src_y, src_z); FLT4 value = args.src_tensor.Read(src_x, src_y, src_z);
$2
args.dst_tensor.Write(value, X, Y, Z); args.dst_tensor.Write(value, X, Y, Z);
})"; })";
return code; return code;

View File

@ -65,7 +65,6 @@ kernel void ComputeFunction(
// bilinear interpolation // bilinear interpolation
FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)), FLT4 value = mix(mix(tex11, tex21, static_cast<FLT>(t.x)),
mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y)); mix(tex12, tex22, static_cast<FLT>(t.x)), static_cast<FLT>(t.y));
$2
args.dst_tensor.Write(value, gid.x, gid.y, gid.z); args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
} }
)"; )";
@ -106,8 +105,6 @@ kernel void ComputeFunction(
c += " coord.y = min(coord.y, args.src_tensor.Height() - 1);\n"; c += " coord.y = min(coord.y, args.src_tensor.Height() - 1);\n";
c += R"( c += R"(
FLT4 value = args.src_tensor.Read(coord.x, coord.y, gid.z); FLT4 value = args.src_tensor.Read(coord.x, coord.y, gid.z);
args.dst_tensor.GetAddress(linear_index, gid.x, gid.y, gid.z);
$2
args.dst_tensor.Write(value, gid.x, gid.y, gid.z); args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
} }
)"; )";

View File

@ -134,9 +134,7 @@ kernel void ComputeFunction($1
c += " }\n"; c += " }\n";
} }
} }
c += " FLT4 value = result;\n"; c += " args.dst_tensor.Write(result, X, Y, Z);\n";
c += " $2\n";
c += " args.dst_tensor.Write(value, X, Y, Z);\n";
c += "}\n"; c += "}\n";
return c; return c;
} }

View File

@ -117,8 +117,6 @@ kernel void ComputeFunction($1
if (dst_s < args.src_tensor.Slices()) { if (dst_s < args.src_tensor.Slices()) {
float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum); float4 src = float4(args.src_tensor.Read(0, 0, dst_s)) - float4(maximum);
FLT4 value = FLT4(exp(src) * sum); FLT4 value = FLT4(exp(src) * sum);
uint3 gid = uint3(0, 0, dst_s);
$2
args.dst_tensor.Write(value, 0, 0, dst_s); args.dst_tensor.Write(value, 0, 0, dst_s);
} }
})"; })";
@ -161,7 +159,6 @@ kernel void ComputeFunction(
for (int d = 0; d < args.dst_tensor.Slices(); ++d) { for (int d = 0; d < args.dst_tensor.Slices(); ++d) {
float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum); float4 src = float4(args.src_tensor.Read(gid.x, gid.y, d)) - float4(maximum);
FLT4 value = FLT4(exp(src) / sum); FLT4 value = FLT4(exp(src) / sum);
$2
args.dst_tensor.Write(value, gid.x, gid.y, d); args.dst_tensor.Write(value, gid.x, gid.y, d);
} }
} }

View File

@ -50,7 +50,6 @@ kernel void ComputeFunction($1 uint3 gid[[thread_position_in_grid]]) {
uint src_c = dst_c % args.src_tensor.Channels(); uint src_c = dst_c % args.src_tensor.Channels();
value[i] = args.src_tensor.Read(src_x, src_y, src_c / 4)[src_c % 4]; value[i] = args.src_tensor.Read(src_x, src_y, src_c / 4)[src_c % 4];
} }
$2
args.dst_tensor.Write(value, gid.x, gid.y, gid.z); args.dst_tensor.Write(value, gid.x, gid.y, gid.z);
})"; })";

View File

@ -103,8 +103,6 @@ std::string GetDeconvolution(const ConvolutionTransposedAttributes& attr) {
for (short l = 0; l < dst_depth; ++l) { for (short l = 0; l < dst_depth; ++l) {
FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + args.biases.Read(l); FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + args.biases.Read(l);
uint3 gid = uint3(ugid.x, ugid.y, uint(l));
$$2
args.dst_tensor.Write(value, ugid.x, ugid.y, l); args.dst_tensor.Write(value, ugid.x, ugid.y, l);
} }
} }
@ -225,8 +223,6 @@ std::string GetDeconvolutionShared(const ConvolutionTransposedAttributes& attr,
for (short l = 0; l < dst_depth; ++l) { for (short l = 0; l < dst_depth; ++l) {
FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + args.biases.Read(l); FLT4 value = FLT4(out[l * 4], out[l * 4 + 1], out[l * 4 + 2], out[l * 4 + 3]) + args.biases.Read(l);
uint3 gid = uint3(ugid.x, ugid.y, uint(l));
$$2
args.dst_tensor.Write(value, ugid.x, ugid.y, l); args.dst_tensor.Write(value, ugid.x, ugid.y, l);
} }
} }
@ -400,8 +396,6 @@ kernel void ComputeFunction(
c += " if (" + x_check + " && " + y_check + ") {\n"; c += " if (" + x_check + " && " + y_check + ") {\n";
c += " FLT4 value = FLT4(" + R + ") + bias_val;\n"; c += " FLT4 value = FLT4(" + R + ") + bias_val;\n";
std::string dst_coords = dst_x + ", " + dst_y + ", Z"; std::string dst_coords = dst_x + ", " + dst_y + ", Z";
c += " uint3 gid = uint3(" + dst_coords + ");\n";
c += " $2\n";
c += " args.dst_tensor.Write(value, " + dst_coords + ");\n"; c += " args.dst_tensor.Write(value, " + dst_coords + ");\n";
c += " }\n"; c += " }\n";
} }

View File

@ -307,28 +307,20 @@ kernel void ComputeFunction($1
FLT4 t1 = I[y][3] + I[y][4]; FLT4 t1 = I[y][3] + I[y][4];
if (tile_x < args.dst_tensor.Width()) { if (tile_x < args.dst_tensor.Width()) {
FLT4 value = I[y][0] + t0 + t1 + bias_val; FLT4 value = I[y][0] + t0 + t1 + bias_val;
uint3 gid = uint3(tile_x, tile_y + y, global_ids.z);
$2
args.dst_tensor.Write(value, tile_x, tile_y + y, global_ids.z); args.dst_tensor.Write(value, tile_x, tile_y + y, global_ids.z);
} }
FLT4 t2 = I[y][1] - I[y][2]; FLT4 t2 = I[y][1] - I[y][2];
FLT4 t3 = I[y][3] - I[y][4]; FLT4 t3 = I[y][3] - I[y][4];
if (tile_x + 1 < args.dst_tensor.Width()) { if (tile_x + 1 < args.dst_tensor.Width()) {
FLT4 value = t2 * At[7] + t3 * At[9] + bias_val; FLT4 value = t2 * At[7] + t3 * At[9] + bias_val;
uint3 gid = uint3(tile_x + 1, tile_y + y, global_ids.z);
$2
args.dst_tensor.Write(value, tile_x + 1, tile_y + y, global_ids.z); args.dst_tensor.Write(value, tile_x + 1, tile_y + y, global_ids.z);
} }
if (tile_x + 2 < args.dst_tensor.Width()) { if (tile_x + 2 < args.dst_tensor.Width()) {
FLT4 value = t0 * At[13] + t1 * At[15] + bias_val; FLT4 value = t0 * At[13] + t1 * At[15] + bias_val;
uint3 gid = uint3(tile_x + 2, tile_y + y, global_ids.z);
$2
args.dst_tensor.Write(value, tile_x + 2, tile_y + y, global_ids.z); args.dst_tensor.Write(value, tile_x + 2, tile_y + y, global_ids.z);
} }
if (tile_x + 3 < args.dst_tensor.Width()) { if (tile_x + 3 < args.dst_tensor.Width()) {
FLT4 value = t2 * At[19] + t3 * At[21] + I[y][5] + bias_val; FLT4 value = t2 * At[19] + t3 * At[21] + I[y][5] + bias_val;
uint3 gid = uint3(tile_x + 3, tile_y + y, global_ids.z);
$2
args.dst_tensor.Write(value, tile_x + 3, tile_y + y, global_ids.z); args.dst_tensor.Write(value, tile_x + 3, tile_y + y, global_ids.z);
} }
} }
@ -411,28 +403,20 @@ kernel void ComputeFunction($1
FLT4 bias_val = args.biases.Read(DST_Z); FLT4 bias_val = args.biases.Read(DST_Z);
if (tile_x < args.dst_tensor.Width()) { if (tile_x < args.dst_tensor.Width()) {
FLT4 value = I0 + t0 + t1 + bias_val; FLT4 value = I0 + t0 + t1 + bias_val;
uint3 gid = uint3(tile_x, tile_y, global_ids.z);
$2;
args.dst_tensor.Write(value, tile_x, tile_y, global_ids.z); args.dst_tensor.Write(value, tile_x, tile_y, global_ids.z);
} }
FLT4 t2 = I1 - I2; FLT4 t2 = I1 - I2;
FLT4 t3 = I3 - I4; FLT4 t3 = I3 - I4;
if (tile_x + 1 < args.dst_tensor.Width()) { if (tile_x + 1 < args.dst_tensor.Width()) {
FLT4 value = t2 * At[7] + t3 * At[9] + bias_val; FLT4 value = t2 * At[7] + t3 * At[9] + bias_val;
uint3 gid = uint3(tile_x + 1, tile_y, global_ids.z);
$2;
args.dst_tensor.Write(value, tile_x + 1, tile_y, global_ids.z); args.dst_tensor.Write(value, tile_x + 1, tile_y, global_ids.z);
} }
if (tile_x + 2 < args.dst_tensor.Width()) { if (tile_x + 2 < args.dst_tensor.Width()) {
FLT4 value = t0 * At[13] + t1 * At[15] + bias_val; FLT4 value = t0 * At[13] + t1 * At[15] + bias_val;
uint3 gid = uint3(tile_x + 2, tile_y, global_ids.z);
$2;
args.dst_tensor.Write(value, tile_x + 2, tile_y, global_ids.z); args.dst_tensor.Write(value, tile_x + 2, tile_y, global_ids.z);
} }
if (tile_x + 3 < args.dst_tensor.Width()) { if (tile_x + 3 < args.dst_tensor.Width()) {
FLT4 value = t2 * At[19] + t3 * At[21] + I5 + bias_val; FLT4 value = t2 * At[19] + t3 * At[21] + I5 + bias_val;
uint3 gid = uint3(tile_x + 3, tile_y, global_ids.z);
$2;
args.dst_tensor.Write(value, tile_x + 3, tile_y, global_ids.z); args.dst_tensor.Write(value, tile_x + 3, tile_y, global_ids.z);
} }
} }