Improving softmax precision.
PiperOrigin-RevId: 283449048 Change-Id: I336e2f7740305aabcea02dac22c7c47d92406bf8
This commit is contained in:
parent
250d9bc96b
commit
51249d605d
@ -62,14 +62,17 @@ class Softmax : public NodeShader {
|
|||||||
std::string source = R"(
|
std::string source = R"(
|
||||||
highp float sum = 0.0;
|
highp float sum = 0.0;
|
||||||
for (int d = 0; d < $src_depth$ - 1; ++d) {
|
for (int d = 0; d < $src_depth$ - 1; ++d) {
|
||||||
sum += dot(vec4(1.0), exp($input_data_0[gid.x, gid.y, d]$));
|
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
|
||||||
|
sum += dot(vec4(1.0), exp(v));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
int d = $src_depth$ - 1;
|
int d = $src_depth$ - 1;
|
||||||
sum += dot($mask$, exp($input_data_0[gid.x, gid.y, d]$));
|
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
|
||||||
|
sum += dot($mask$, exp(v));
|
||||||
}
|
}
|
||||||
for (int d = 0; d < $src_depth$; ++d) {
|
for (int d = 0; d < $src_depth$; ++d) {
|
||||||
vec4 temp_sum = exp($input_data_0[gid.x, gid.y, d]$) / sum;
|
highp vec4 v = $input_data_0[gid.x, gid.y, d]$;
|
||||||
|
vec4 temp_sum = exp(v) / sum;
|
||||||
$output_data_0[gid.x, gid.y, d] = temp_sum$;
|
$output_data_0[gid.x, gid.y, d] = temp_sum$;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
Loading…
x
Reference in New Issue
Block a user