Using native_ equivalents in tanh for FP16.

PiperOrigin-RevId: 333154326
Change-Id: I83ed06d13d1c223cc6d72ca27bb8b4279a14669a
This commit is contained in:
Raman Sarokin 2020-09-22 14:00:00 -07:00 committed by TensorFlower Gardener
parent 50acd58a69
commit ff8cec00c2

View File

@ -67,17 +67,8 @@ std::string GetOneInputCode(const OperationType& op_type,
case OperationType::SIGMOID:
if (precision != CalculationsPrecision::F32) {
result =
"$0.x = convert_half(native_recip(1.0f + "
"native_exp(convert_float(-$0.x))));\n";
result +=
"$0.y = convert_half(native_recip(1.0f + "
"native_exp(convert_float(-$0.y))));\n";
result +=
"$0.z = convert_half(native_recip(1.0f + "
"native_exp(convert_float(-$0.z))));\n";
result +=
"$0.w = convert_half(native_recip(1.0f + "
"native_exp(convert_float(-$0.w))));\n";
"$0 = convert_half4(native_recip(1.0f + "
"native_exp(convert_float4(-$0))));\n";
} else {
result = "$0 = (FLT4)(1.0f) / ((FLT4)(1.0f) + exp(-($0)));\n";
}
@ -92,7 +83,12 @@ std::string GetOneInputCode(const OperationType& op_type,
result = "$0 *= $0;\n";
break;
case OperationType::TANH:
result = "$0 = tanh($0);\n";
if (precision != CalculationsPrecision::F32) {
result = "float4 t = native_exp(convert_float4($0 * 2.0h));\n";
result += "$0 = convert_half4(native_divide(t - 1.0f, t + 1.0f));\n";
} else {
result = "$0 = tanh($0);\n";
}
break;
default:
return "Unknown operation type;\n";