Using native_ equivalents in tanh for FP16.
PiperOrigin-RevId: 333154326 Change-Id: I83ed06d13d1c223cc6d72ca27bb8b4279a14669a
This commit is contained in:
parent
50acd58a69
commit
ff8cec00c2
@ -67,17 +67,8 @@ std::string GetOneInputCode(const OperationType& op_type,
|
||||
case OperationType::SIGMOID:
|
||||
if (precision != CalculationsPrecision::F32) {
|
||||
result =
|
||||
"$0.x = convert_half(native_recip(1.0f + "
|
||||
"native_exp(convert_float(-$0.x))));\n";
|
||||
result +=
|
||||
"$0.y = convert_half(native_recip(1.0f + "
|
||||
"native_exp(convert_float(-$0.y))));\n";
|
||||
result +=
|
||||
"$0.z = convert_half(native_recip(1.0f + "
|
||||
"native_exp(convert_float(-$0.z))));\n";
|
||||
result +=
|
||||
"$0.w = convert_half(native_recip(1.0f + "
|
||||
"native_exp(convert_float(-$0.w))));\n";
|
||||
"$0 = convert_half4(native_recip(1.0f + "
|
||||
"native_exp(convert_float4(-$0))));\n";
|
||||
} else {
|
||||
result = "$0 = (FLT4)(1.0f) / ((FLT4)(1.0f) + exp(-($0)));\n";
|
||||
}
|
||||
@ -92,7 +83,12 @@ std::string GetOneInputCode(const OperationType& op_type,
|
||||
result = "$0 *= $0;\n";
|
||||
break;
|
||||
case OperationType::TANH:
|
||||
result = "$0 = tanh($0);\n";
|
||||
if (precision != CalculationsPrecision::F32) {
|
||||
result = "float4 t = native_exp(convert_float4($0 * 2.0h));\n";
|
||||
result += "$0 = convert_half4(native_divide(t - 1.0f, t + 1.0f));\n";
|
||||
} else {
|
||||
result = "$0 = tanh($0);\n";
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return "Unknown operation type;\n";
|
||||
|
Loading…
Reference in New Issue
Block a user