Improving softmax precision.

PiperOrigin-RevId: 283449048
Change-Id: I336e2f7740305aabcea02dac22c7c47d92406bf8
This commit is contained in:
A. Unique TensorFlower 2019-12-02 17:04:12 -08:00 committed by TensorFlower Gardener
parent 250d9bc96b
commit 51249d605d

View File

@ -62,14 +62,17 @@ class Softmax : public NodeShader {
std::string source = R"(
highp float sum = 0.0;
for (int d = 0; d < $src_depth$ - 1; ++d) {
sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
sum += dot(vec4(1.0), exp(v));
}
{
int d = $src_depth$ - 1;
sum += dot($mask$, exp($input_data_0[gid.x, gid.y, d]$));
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
sum += dot($mask$, exp(v));
}
for (int d = 0; d < $src_depth$; ++d) {
vec4 temp_sum = exp($input_data_0[gid.x, gid.y, d]$) / sum;
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
vec4 temp_sum = exp(v) / sum;
$output_data_0[gid.x, gid.y, d] = temp_sum$;
}
)";