Added 3rd pass to Softmax in Metal to improve numerical stability.

PiperOrigin-RevId: 348569651
Change-Id: I309d607747ae50f4b30cc46501f128abf3e6d4a0
Raman Sarokin 2020-12-21 20:07:46 -08:00 committed by TensorFlower Gardener
parent d712b970c9
commit 646d25d159
6 changed files with 202 additions and 78 deletions
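The change replaces the plain two-pass softmax (sum the exponents, then normalize) with the standard max-shifted formulation softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)): subtracting the channel maximum keeps every exponent at or below zero, so exp() cannot overflow 32-bit float. A minimal CPU-side sketch of the same three passes (illustration only, not the Metal kernel; StableSoftmax is a hypothetical name):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax in three passes:
//   1) find the maximum element,
//   2) accumulate exp(x - max),
//   3) normalize by the accumulated sum.
// Since exp(x - max) <= 1, the sum cannot overflow float even for inputs
// like 100.0f, where a naive exp(100.0f) is already +inf in 32-bit float.
std::vector<float> StableSoftmax(const std::vector<float>& x) {
  const float maximum = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - maximum);
  std::vector<float> result(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    result[i] = std::exp(x[i] - maximum) / sum;
  }
  return result;
}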


@@ -124,10 +124,10 @@ std::unique_ptr<ComputeTaskDescriptor> SelectSoftmax(const OperationDef& op_def,
                                                      const BHWC& src_shape,
                                                      const GpuInfo& gpu_info) {
   if (src_shape.w == 1 && src_shape.h == 1) {
-    auto gpu_op = Softmax1x1(op_def, gpu_info, src_shape.c);
+    auto gpu_op = Softmax1x1(op_def, gpu_info);
     return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
   } else {
-    auto gpu_op = Softmax(op_def, src_shape.c);
+    auto gpu_op = Softmax(op_def);
     return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
   }
 }


@@ -39,10 +39,10 @@ using ::tflite::gpu::TensorRef;
 using ::tflite::gpu::metal::CompareVectors;
 using ::tflite::gpu::metal::SingleOpModel;
 
-@interface SoftmaxTest : XCTestCase
+@interface PReLUTest : XCTestCase
 @end
 
-@implementation SoftmaxTest
+@implementation PReLUTest
 - (void)setUp {
   [super setUp];
 }


@@ -36,10 +36,10 @@ using ::tflite::gpu::TensorRef;
 using ::tflite::gpu::metal::CompareVectors;
 using ::tflite::gpu::metal::SingleOpModel;
 
-@interface SliceTest : XCTestCase
+@interface ReLUTest : XCTestCase
 @end
 
-@implementation SliceTest
+@implementation ReLUTest
 - (void)setUp {
   [super setUp];
 }


@@ -50,21 +50,55 @@ kernel void ComputeFunction($1
                             uint tid[[thread_index_in_threadgroup]],
                             uint3 ugid[[thread_position_in_grid]])
 {
-  int offset = 0;
-  float sum = 0.0f;
-  int s = 0;
-  do {
-    if (offset + tid < params.size.x) {
-      float4 mask_temp = offset + tid == params.size.x - 1 ? params.mask : float4(1.0h);
-      float4 src = float4(src_tensor[offset + tid]);
-      sum += dot(mask_temp, exp(src));
-      offset += 32;
-    }
-    s++;
-  } while (s < params.size.y);
+  float4 maxx4 = float4(src_tensor[0].x);
+  for (int s = int(tid); s < params.size.x; s += 32) {
+    float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
+    float4 mask_b = float4(1.0f) - mask_a;
+    float4 src = float4(src_tensor[s]);
+    src = src * mask_a + mask_b * src.x;
+    maxx4 = max(maxx4, src);
+  }
+  float maximum = max(maxx4.x, maxx4.y);
+  maximum = max(maximum, maxx4.z);
+  maximum = max(maximum, maxx4.w);
 
   threadgroup float4 tmp[8];
   threadgroup float* tmpx1 = (threadgroup float*)tmp;
+  tmpx1[tid] = maximum;
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+  if (tid == 0) {
+    maxx4 = max(tmp[0], tmp[1]);
+    maxx4 = max(maxx4, tmp[2]);
+    maxx4 = max(maxx4, tmp[3]);
+    maxx4 = max(maxx4, tmp[4]);
+    maxx4 = max(maxx4, tmp[5]);
+    maxx4 = max(maxx4, tmp[6]);
+    maxx4 = max(maxx4, tmp[7]);
+    maximum = max(maxx4.x, maxx4.y);
+    maximum = max(maximum, maxx4.z);
+    maximum = max(maximum, maxx4.w);
+    tmpx1[0] = maximum;
+  }
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+  maximum = tmpx1[0];
+  float sum = 0.0f;
+  for (int s = int(tid); s < params.size.x; s += 32) {
+    float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(src_tensor[s]) - float4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
   tmpx1[tid] = sum;
 )";
   code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
@@ -85,74 +119,90 @@ kernel void ComputeFunction($1
   code += R"(
   sum = tmpx1[0];
-  offset = 0;
-  s = 0;
-  do {
-    if (offset + tid < params.size.x) {
-      int linear_index = offset + tid;
-      FLT4 value = FLT4(exp(float4(src_tensor[linear_index])) * sum);
-      uint3 gid = uint3(0, 0, linear_index);
-      $2
-      dst_tensor[linear_index] = value;
-      offset += 32;
-    }
-    s++;
-  } while (s < params.size.y);
+  int dst_s = int(ugid.x);
+  if (dst_s < params.size.x) {
+    int linear_index = dst_s;
+    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+    FLT4 value = FLT4(exp(src) * sum);
+    uint3 gid = uint3(0, 0, linear_index);
+    $2
+    dst_tensor[linear_index] = value;
+  }
 })";
   return code;
 }
 }  // namespace
 
-ComputeTaskDescriptor Softmax(const OperationDef& definition,
-                              int channels_count) {
+ComputeTaskDescriptor Softmax(const OperationDef& definition) {
   ComputeTaskDescriptor desc(definition);
   desc.shader_source = R"(
-#include <metal_stdlib>
-using namespace metal;
-constant int src_channels = )";
-  desc.shader_source += std::to_string(channels_count);
-  desc.shader_source += R"(;
-$0
-kernel void ComputeFunction(
-                            $1
-                            uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
-    return;
-  }
-  float shift = 0.0f;
-  int remaining_channels = src_channels % 4;
+#include <metal_stdlib>
+using namespace metal;
 
-  float sum = 0.0f;
-  for (int d = 0; d < src_channels / 4; ++d) {
-    int buffer_index = (d * size.y + gid.y) * size.x + gid.x;
-    sum += dot(float4(1.0f), exp(float4(src_tensor[buffer_index]) - shift));
-  }
-  if (remaining_channels > 0) {
-    int buffer_index = ((src_channels / 4) * size.y + gid.y) * size.x + gid.x;
-    float4 last_element = float4(src_tensor[buffer_index]);
-    sum += exp(last_element.x - shift);
-    if (remaining_channels > 1) sum += exp(last_element.y - shift);
-    if (remaining_channels == 3) sum += exp(last_element.z - shift);
-  }
+struct uniforms {
+  int4 size;
+  float4 mask;
+};
+$0
+kernel void ComputeFunction(
+                            $1
+                            uint3 gid[[thread_position_in_grid]]) {
+  if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
+    return;
+  }
 
-  for (int d = 0; d < (src_channels + 3) / 4; ++d) {
-    const int linear_index = (d * size.y + gid.y) * size.x + gid.x;
-    FLT4 value = FLT4(exp(float4(src_tensor[linear_index]) - shift) / sum);
-    $2
-    dst_tensor[linear_index] = value;
-  }
-}
+  float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
+  for (int d = 0; d < params.size.z; ++d) {
+    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
+    float4 mask_b = float4(1.0f) - mask_a;
+    float4 src = float4(src_tensor[buffer_index]);
+    src = src * mask_a + mask_b * src.x;
+    maximum = max(maximum, src.x);
+    maximum = max(maximum, src.y);
+    maximum = max(maximum, src.z);
+    maximum = max(maximum, src.w);
+  }
+
+  float sum = 0.0f;
+  for (int d = 0; d < params.size.z; ++d) {
+    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+
+  for (int d = 0; d < params.size.z; ++d) {
+    const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+    FLT4 value = FLT4(exp(src) / sum);
+    $2
+    dst_tensor[linear_index] = value;
+  }
+}
 )";
   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
 
   desc.uniform_buffers = {
-      {"constant int2& size",
+      {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         std::vector<int> sizes{dst_shapes[0].w, dst_shapes[0].h};
-         return GetByteBuffer(sizes);
+         const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
+         struct uniforms {
+           int4 size;
+           float4 mask;
+         };
+         uniforms params;
+         params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
+         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
+         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
+         for (int i = 0; i < reminder; ++i) {
+           params.mask[i] = 1.0f;
+         }
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
       }},
   };
@@ -168,7 +218,7 @@ ComputeTaskDescriptor Softmax(const OperationDef& definition,
 }
 
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
-                                 const GpuInfo& gpu_info, int channels_count) {
+                                 const GpuInfo& gpu_info) {
   ComputeTaskDescriptor desc(definition);
   desc.shader_source = GetSoftmax1x1Code(gpu_info);
@@ -177,9 +227,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
 
   desc.uniform_buffers = {
       {"constant uniforms& params",
-       [channels_count](const std::vector<BHWC>& src_shapes,
-                        const std::vector<BHWC>& dst_shapes) {
-         const int src_depth = DivideRoundUp(channels_count, 4);
+       [](const std::vector<BHWC>& src_shapes,
+          const std::vector<BHWC>& dst_shapes) {
+         const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
          struct uniforms {
            int4 size;
            float4 mask;
@@ -187,7 +237,7 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
          uniforms params;
          params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
          params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         const int reminder = channels_count % 4 == 0 ? 4 : channels_count % 4;
+         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
          for (int i = 0; i < reminder; ++i) {
            params.mask[i] = 1.0f;
          }
@@ -198,7 +248,10 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
 
   desc.resize_function = [](const std::vector<BHWC>& src_shapes,
                             const std::vector<BHWC>& dst_shapes) {
-    return std::make_pair(uint3{32u, 1u, 1u}, uint3{1u, 1u, 1u});
+    uint3 groups_size{32, 1, 1};
+    uint3 groups_count{
+        DivideRoundUp(DivideRoundUp(dst_shapes[0].c, 4), groups_size.x), 1, 1};
+    return std::make_pair(groups_size, groups_count);
   };
 
   return desc;
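In the Softmax1x1 kernel above, each of the 32 threads in the group strides over the channel slices to find a per-thread maximum, the per-thread maxima are combined through the threadgroup array tmp/tmpx1, and only then are the exp(x - max) sums and final normalization computed; the resize_function now dispatches enough groups along x to cover every output slice. A rough CPU analogue of the strided max pass, a sketch under the kernel's 32-thread assumption (SimulateThreadgroupMax is a hypothetical name, not part of the commit):

#include <algorithm>
#include <cstddef>
#include <vector>

// CPU analogue of the kernel's new first pass: 32 "threads" each scan
// elements tid, tid + 32, tid + 64, ... keeping a running maximum, and the
// per-thread maxima are then reduced to one value (the role the threadgroup
// array tmp/tmpx1 plays on the GPU).
float SimulateThreadgroupMax(const std::vector<float>& slices) {
  constexpr int kThreads = 32;
  // Seed with the first element, mirroring float4 maxx4 = float4(src_tensor[0].x).
  std::vector<float> per_thread(kThreads, slices[0]);
  for (int tid = 0; tid < kThreads; ++tid) {
    for (std::size_t s = tid; s < slices.size(); s += kThreads) {
      per_thread[tid] = std::max(per_thread[tid], slices[s]);
    }
  }
  return *std::max_element(per_thread.begin(), per_thread.end());
}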


@@ -27,13 +27,12 @@ namespace tflite {
 namespace gpu {
 namespace metal {
 
-ComputeTaskDescriptor Softmax(const OperationDef& definition,
-                              int channels_count);
+ComputeTaskDescriptor Softmax(const OperationDef& definition);
 
 // Softmax for case when width = height = 1 and AXIS = CHANNELS
 // We have this case in MobilenetV1/V2.
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
-                                 const GpuInfo& gpu_info, int channels_count);
+                                 const GpuInfo& gpu_info);
 
 }  // namespace metal
 }  // namespace gpu


@@ -133,4 +133,76 @@ using ::tflite::gpu::metal::SingleOpModel;
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
+- (void)testSoftmaxBigNumber {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 2, 1, 2);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 2, 1, 2);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  XCTAssertFalse(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
+  double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
+  XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
+                                         static_cast<float>(doubles[1]),
+                                         static_cast<float>(doubles[2]),
+                                         static_cast<float>(doubles[3])}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
+                           static_cast<float>(std::exp(doubles[1]) / s0),
+                           static_cast<float>(std::exp(doubles[2]) / s1),
+                           static_cast<float>(std::exp(doubles[3]) / s1)},
+                          model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testSoftmax1x1BigNumber {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 1, 1, 4);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 1, 1, 4);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  XCTAssertFalse(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+              std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
+  XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
+                                         static_cast<float>(doubles[1]),
+                                         static_cast<float>(doubles[2]),
+                                         static_cast<float>(doubles[3])}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
+                           static_cast<float>(std::exp(doubles[1]) / s0),
+                           static_cast<float>(std::exp(doubles[2]) / s0),
+                           static_cast<float>(std::exp(doubles[3]) / s0)},
+                          model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end