Using common names (as in common/tasks) in Metal elementwise ops.
PiperOrigin-RevId: 351397942 Change-Id: Id292f0512eeba335c34b8740f794f33b4b562ae9
This commit is contained in:
parent
db98378537
commit
c62635d663
@ -39,8 +39,8 @@ ComputeTaskDescriptor Add(const OperationDef& definition) {
|
||||
for (int i = 1; i < definition.src_tensors.size(); ++i) {
|
||||
const std::string tensor_name = "src_tensor_" + std::to_string(i);
|
||||
desc.AddSrcTensor(tensor_name, definition.src_tensors[i]);
|
||||
desc.shader_source +=
|
||||
" value += args." + tensor_name + ".Read(gid.x, gid.y, gid.z);\n";
|
||||
desc.shader_source += " in_out_value += args." + tensor_name +
|
||||
".Read(X_COORD, Y_COORD, S_COORD);\n";
|
||||
}
|
||||
|
||||
return desc;
|
||||
|
||||
@ -88,9 +88,9 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,
|
||||
OperationType op_type) {
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.is_linkable = true;
|
||||
const std::string x_coord = second_shape.w == 1 ? "0" : "gid.x";
|
||||
const std::string y_coord = second_shape.h == 1 ? "0" : "gid.y";
|
||||
const std::string s_coord = second_shape.c == 1 ? "0" : "gid.z";
|
||||
const std::string x_coord = second_shape.w == 1 ? "0" : "X_COORD";
|
||||
const std::string y_coord = second_shape.h == 1 ? "0" : "Y_COORD";
|
||||
const std::string s_coord = second_shape.c == 1 ? "0" : "S_COORD";
|
||||
std::string code;
|
||||
code = " FLT4 src_1 = args.second_tensor.Read(" + x_coord + ", " + y_coord +
|
||||
", " + s_coord + ");\n";
|
||||
@ -99,7 +99,9 @@ ComputeTaskDescriptor ElementwiseWithTwoInputs(const OperationDef& definition,
|
||||
code += " src_1.z = src_1.x;\n";
|
||||
code += " src_1.w = src_1.x;\n";
|
||||
}
|
||||
code += " value = " + TwoInputFunctor(op_type, "value", "src_1") + ";\n";
|
||||
code +=
|
||||
" in_out_value = " + TwoInputFunctor(op_type, "in_out_value", "src_1") +
|
||||
";\n";
|
||||
|
||||
desc.shader_source = code;
|
||||
|
||||
@ -112,7 +114,7 @@ ComputeTaskDescriptor ElementwiseWithOneInput(const OperationDef& definition,
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.is_linkable = true;
|
||||
desc.shader_source =
|
||||
" value = " + OneInputFunctor(op_type, "value") + ";\n";
|
||||
" in_out_value = " + OneInputFunctor(op_type, "in_out_value") + ";\n";
|
||||
return desc;
|
||||
}
|
||||
|
||||
@ -138,7 +140,7 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
|
||||
linear_desc.size = linear_desc.data.size();
|
||||
desc.args.AddObject(
|
||||
"linear", absl::make_unique<BufferDescriptor>(std::move(linear_desc)));
|
||||
desc.shader_source += " FLT4 second_arg = args.linear.Read(gid.z);\n";
|
||||
desc.shader_source += " FLT4 second_arg = args.linear.Read(S_COORD);\n";
|
||||
} else if (hwc_buf) {
|
||||
TensorDescriptor hwc_desc{definition.GetDataType(),
|
||||
TensorStorageType::BUFFER, Layout::HWC};
|
||||
@ -146,9 +148,9 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
|
||||
desc.args.AddObject(
|
||||
"hwc", absl::make_unique<TensorDescriptor>(std::move(hwc_desc)));
|
||||
|
||||
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "gid.x";
|
||||
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "gid.y";
|
||||
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "gid.z";
|
||||
const std::string x_coord = hwc_buf->shape.w == 1 ? "0" : "X_COORD";
|
||||
const std::string y_coord = hwc_buf->shape.h == 1 ? "0" : "Y_COORD";
|
||||
const std::string s_coord = hwc_buf->shape.c == 1 ? "0" : "S_COORD";
|
||||
desc.shader_source += " FLT4 second_arg = args.hwc.Read(" + x_coord +
|
||||
", " + y_coord + ", " + s_coord + ");\n";
|
||||
if (hwc_buf->shape.c == 1) {
|
||||
@ -157,8 +159,9 @@ ComputeTaskDescriptor ElementwiseWithOneInputAndConstantArguent(
|
||||
desc.shader_source += " second_arg.w = second_arg.x;\n";
|
||||
}
|
||||
}
|
||||
desc.shader_source +=
|
||||
" value = " + TwoInputFunctor(op_type, "value", "second_arg") + ";\n";
|
||||
desc.shader_source += " in_out_value = " +
|
||||
TwoInputFunctor(op_type, "in_out_value", "second_arg") +
|
||||
";\n";
|
||||
return desc;
|
||||
}
|
||||
|
||||
|
||||
@ -46,12 +46,12 @@ ComputeTaskDescriptor PReLU(const OperationDef& definition,
|
||||
desc.args.AddFloat("clip", attr.clip);
|
||||
desc.shader_source =
|
||||
R"(
|
||||
value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
|
||||
in_out_value = FLT4(clamp(in_out_value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(S_COORD) * min(FLT4(0.0f), in_out_value));
|
||||
)";
|
||||
} else {
|
||||
desc.shader_source =
|
||||
R"(
|
||||
value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.z) * min(FLT4(0.0f), value));
|
||||
in_out_value = FLT4(max(FLT4(0.0f), in_out_value) + args.alpha.Read(S_COORD) * min(FLT4(0.0f), in_out_value));
|
||||
)";
|
||||
}
|
||||
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
|
||||
@ -79,12 +79,12 @@ ComputeTaskDescriptor PReLUFull(const OperationDef& definition,
|
||||
desc.args.AddFloat("clip", attr.clip);
|
||||
desc.shader_source =
|
||||
R"(
|
||||
value = FLT4(clamp(value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
|
||||
in_out_value = FLT4(clamp(in_out_value, FLT4(0.0f), FLT4(args.clip)) + args.alpha.Read(X_COORD, Y_COORD, S_COORD) * min(FLT4(0.0f), in_out_value));
|
||||
)";
|
||||
} else {
|
||||
desc.shader_source =
|
||||
R"(
|
||||
value = FLT4(max(FLT4(0.0f), value) + args.alpha.Read(gid.x, gid.y, gid.z) * min(FLT4(0.0f), value));
|
||||
in_out_value = FLT4(max(FLT4(0.0f), in_out_value) + args.alpha.Read(X_COORD, Y_COORD, S_COORD) * min(FLT4(0.0f), in_out_value));
|
||||
)";
|
||||
}
|
||||
TensorDescriptor alpha_desc{definition.GetDataType(),
|
||||
|
||||
@ -29,9 +29,9 @@ ComputeTaskDescriptor QuantizeAndDequantize(
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.is_linkable = true;
|
||||
desc.shader_source = R"(
|
||||
value = clamp(value, FLT4(args.qmin), FLT4(args.qmax));
|
||||
value = (value - FLT4(args.qmin)) / FLT4(args.qscale);
|
||||
value = round(value) * FLT4(args.qscale) + FLT4(args.qmin);)";
|
||||
in_out_value = clamp(in_out_value, FLT4(args.qmin), FLT4(args.qmax));
|
||||
in_out_value = (in_out_value - FLT4(args.qmin)) / FLT4(args.qscale);
|
||||
in_out_value = round(in_out_value) * FLT4(args.qscale) + FLT4(args.qmin);)";
|
||||
desc.args.AddFloat("qmax", attr.max);
|
||||
desc.args.AddFloat("qmin", attr.min);
|
||||
desc.args.AddFloat("qscale", attr.scale);
|
||||
|
||||
@ -35,12 +35,13 @@ ComputeTaskDescriptor ReLU(const OperationDef& definition,
|
||||
ComputeTaskDescriptor desc(definition);
|
||||
desc.is_linkable = true;
|
||||
const std::string min_func =
|
||||
attr.alpha == 0 ? "FLT4(0.0f)" : "min(value * args.alpha, 0.0f)";
|
||||
attr.alpha == 0 ? "FLT4(0.0f)" : "min(in_out_value * args.alpha, 0.0f)";
|
||||
if (attr.clip != 0.0) {
|
||||
desc.shader_source =
|
||||
"value = FLT4(clamp(value, " + min_func + ", FLT4(args.clip)));";
|
||||
desc.shader_source = "in_out_value = FLT4(clamp(in_out_value, " + min_func +
|
||||
", FLT4(args.clip)));";
|
||||
} else {
|
||||
desc.shader_source = "value = FLT4(max(value, " + min_func + "));";
|
||||
desc.shader_source =
|
||||
"in_out_value = FLT4(max(in_out_value, " + min_func + "));";
|
||||
}
|
||||
desc.args.AddFloat("alpha", attr.alpha);
|
||||
desc.args.AddFloat("clip", attr.clip);
|
||||
|
||||
@ -506,10 +506,10 @@ absl::Status MetalArguments::ResolveSelector(
|
||||
// x_coord can have batch size property of link_object
|
||||
ResolveObjectNames(object_name, names, &x_coord);
|
||||
*result = it->second;
|
||||
ReplaceAllWords("value", value_name, result);
|
||||
ReplaceAllWords("gid.x", x_coord, result);
|
||||
ReplaceAllWords("gid.y", y_coord, result);
|
||||
ReplaceAllWords("gid.z", s_coord, result);
|
||||
ReplaceAllWords("in_out_value", value_name, result);
|
||||
ReplaceAllWords("X_COORD", x_coord, result);
|
||||
ReplaceAllWords("Y_COORD", y_coord, result);
|
||||
ReplaceAllWords("S_COORD", s_coord, result);
|
||||
RETURN_IF_ERROR(ResolveSelectorsPass(args, {}, result));
|
||||
if (selector == "Linking") {
|
||||
return absl::OkStatus();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user