Added 3rd pass to Softmax in OpenGL es 3.1 to improve numerical stability.

PiperOrigin-RevId: 348473873
Change-Id: I163710c9812e6f785ac36f231a689ef680b1450f
This commit is contained in:
Raman Sarokin 2020-12-21 08:47:52 -08:00 committed by TensorFlower Gardener
parent 373f458fb6
commit 2ccbbdb4b0
2 changed files with 140 additions and 35 deletions

View File

@ -66,25 +66,58 @@ class Softmax : public NodeShader {
};
std::vector<Variable> uniform_parameters = {
{"depth", depth},
{"depth_div_32", DivideRoundUp(depth, 32)},
{"mask", GetMask(ctx.output_shapes[0][3])},
};
std::string source_code = R"(
highp vec4 kOnes = vec4(1.0);
highp float sum = 0.0;
int offset = 0;
int s = 0;
int tid = int(gl_LocalInvocationID.x);
do {
int z = offset + tid;
if (z < $depth$) {
highp vec4 mask_temp = z == $depth$ - 1 ? $mask$ : kOnes;
highp vec4 src = $input_data_0[0, 0, z]$;
sum += dot(mask_temp, exp(src));
offset += 32;
}
s++;
} while (s < $depth_div_32$);
highp vec4 maxx4 = $input_data_0[0, 0, 0]$;
maxx4.y = maxx4.x;
maxx4.z = maxx4.x;
maxx4.w = maxx4.x;
for (int s = tid; s < $depth$; s += 32) {
highp vec4 mask_a = s == $depth$ - 1 ? $mask$ : kOnes;
highp vec4 mask_b = kOnes - mask_a;
highp vec4 src = $input_data_0[0, 0, s]$;
src = src * mask_a + mask_b * src.x;
maxx4 = max(maxx4, src);
}
highp float maximum = max(maxx4.x, maxx4.y);
maximum = max(maximum, maxx4.z);
maximum = max(maximum, maxx4.w);
partial_sum[tid / 4][tid % 4] = maximum;
memoryBarrierShared();
barrier();
if (tid == 0) {
maxx4 = max(partial_sum[0], partial_sum[1]);
maxx4 = max(maxx4, partial_sum[2]);
maxx4 = max(maxx4, partial_sum[3]);
maxx4 = max(maxx4, partial_sum[4]);
maxx4 = max(maxx4, partial_sum[5]);
maxx4 = max(maxx4, partial_sum[6]);
maxx4 = max(maxx4, partial_sum[7]);
maximum = max(maxx4.x, maxx4.y);
maximum = max(maximum, maxx4.z);
maximum = max(maximum, maxx4.w);
partial_sum[0][0] = maximum;
}
memoryBarrierShared();
barrier();
maximum = partial_sum[0][0];
highp float sum = 0.0;
for (int s = tid; s < $depth$; s += 32) {
highp vec4 mask_temp = s == $depth$ - 1 ? $mask$ : kOnes;
highp vec4 src = $input_data_0[0, 0, s]$ - vec4(maximum);
sum += dot(mask_temp, exp(src));
}
memoryBarrierShared();
barrier();
partial_sum[tid / 4][tid % 4] = sum;
@ -108,24 +141,19 @@ class Softmax : public NodeShader {
sum = partial_sum[0][0];
offset = 0;
s = 0;
do {
int z = offset + tid;
if (z < $depth$) {
highp vec4 src = $input_data_0[0, 0, z]$;
highp vec4 temp = exp(src) * sum;
$output_data_0[0, 0, z] = temp$;
offset += 32;
}
s++;
} while (s < $depth_div_32$);
int dst_s = int(gl_GlobalInvocationID.x);
if (dst_s < $depth$) {
highp vec4 src = $input_data_0[0, 0, dst_s]$ - vec4(maximum);
highp vec4 temp = exp(src) * sum;
$output_data_0[0, 0, dst_s] = temp$;
}
)";
*generated_code = {
/*parameters=*/std::move(uniform_parameters),
/*objects=*/{},
/*shared_variables=*/std::move(shared_variables),
/*workload=*/uint3(32, 1, 1),
/*workload=*/uint3(depth, 1, 1),
/*workgroup=*/uint3(32, 1, 1),
/*source_code=*/std::move(source_code),
/*input=*/IOStructure::ONLY_DEFINITIONS,
@ -145,17 +173,24 @@ class Softmax : public NodeShader {
std::string source_code = R"(
highp vec4 kOnes = vec4(1.0);
highp float sum = 0.0;
for (int d = 0; d < $src_depth$ - 1; ++d) {
highp float maximum = $input_data_0[gid.x, gid.y, 0]$.x;
for (int d = 0; d < $src_depth$; ++d) {
highp vec4 mask_a = d == $src_depth$ - 1 ? $mask$ : kOnes;
highp vec4 mask_b = kOnes - mask_a;
highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
sum += dot(kOnes, exp(src));
}
{
int d = $src_depth$ - 1;
highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
sum += dot($mask$, exp(src));
src = src * mask_a + mask_b * src.x;
maximum = max(maximum, src.x);
maximum = max(maximum, src.y);
maximum = max(maximum, src.z);
maximum = max(maximum, src.w);
}
for (int d = 0; d < $src_depth$; ++d) {
highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
highp vec4 mask_temp = d == $src_depth$ - 1 ? $mask$ : kOnes;
highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
sum += dot(mask_temp, exp(src));
}
for (int d = 0; d < $src_depth$; ++d) {
highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
highp vec4 temp_sum = exp(src) / sum;
$output_data_0[gid.x, gid.y, d] = temp_sum$;
}

View File

@ -121,6 +121,76 @@ TEST(SoftmaxTest, Softmax1x1) {
std::exp(0.3f) / sum, std::exp(0.4f) / sum}));
}
TEST(SoftmaxTest, SoftmaxBigNumber) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 2, 1, 2);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 2, 1, 2);
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS;
double doubles[4] = {1.0, 2.0, 3.0, 100.0};
// exp(100) is inf in float (32 bit) but representable in double (64 bit)
ASSERT_TRUE(std::isinf(std::exp(static_cast<float>(doubles[3]))));
ASSERT_FALSE(std::isinf(std::exp(doubles[3])));
double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(
0, {static_cast<float>(doubles[0]), static_cast<float>(doubles[1]),
static_cast<float>(doubles[2]), static_cast<float>(doubles[3])}));
ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6f),
{static_cast<float>(std::exp(doubles[0]) / s0),
static_cast<float>(std::exp(doubles[1]) / s0),
static_cast<float>(std::exp(doubles[2]) / s1),
static_cast<float>(std::exp(doubles[3]) / s1)}));
}
TEST(SoftmaxTest, Softmax1x1BigNumber) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 1, 4);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 1, 1, 4);
SoftmaxAttributes attr;
attr.axis = Axis::CHANNELS;
double doubles[4] = {1.0, 2.0, 3.0, 100.0};
// exp(100) is inf in float (32 bit) but representable in double (64 bit)
ASSERT_TRUE(std::isinf(std::exp(static_cast<float>(doubles[3]))));
ASSERT_FALSE(std::isinf(std::exp(doubles[3])));
double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
std::exp(doubles[2]) + std::exp(doubles[3]);
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
{output});
ASSERT_TRUE(model.PopulateTensor(
0, {static_cast<float>(doubles[0]), static_cast<float>(doubles[1]),
static_cast<float>(doubles[2]), static_cast<float>(doubles[3])}));
ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6f),
{static_cast<float>(std::exp(doubles[0]) / s0),
static_cast<float>(std::exp(doubles[1]) / s0),
static_cast<float>(std::exp(doubles[2]) / s0),
static_cast<float>(std::exp(doubles[3]) / s0)}));
}
} // namespace
} // namespace gl
} // namespace gpu