Improving softmax precision.
PiperOrigin-RevId: 283449048 Change-Id: I336e2f7740305aabcea02dac22c7c47d92406bf8
This commit is contained in:
parent
250d9bc96b
commit
51249d605d
@ -62,14 +62,17 @@ class Softmax : public NodeShader {
|
||||
std::string source = R"(
|
||||
highp float sum = 0.0;
|
||||
for (int d = 0; d < $src_depth$ - 1; ++d) {
|
||||
sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
|
||||
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
|
||||
sum += dot(vec4(1.0), exp(v));
|
||||
}
|
||||
{
|
||||
int d = $src_depth$ - 1;
|
||||
sum += dot($mask$, exp($input_data_0[gid.x, gid.y, d]$));
|
||||
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
|
||||
sum += dot($mask$, exp(v));
|
||||
}
|
||||
for (int d = 0; d < $src_depth$; ++d) {
|
||||
vec4 temp_sum = exp($input_data_0[gid.x, gid.y, d]$) / sum;
|
||||
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
|
||||
vec4 temp_sum = exp(v) / sum;
|
||||
$output_data_0[gid.x, gid.y, d] = temp_sum$;
|
||||
}
|
||||
)";
|
||||
|
Loading…
x
Reference in New Issue
Block a user