Added 3rd pass to Softmax in Metal to improve numerical stability.
PiperOrigin-RevId: 348569651
Change-Id: I309d607747ae50f4b30cc46501f128abf3e6d4a0
commit 646d25d159 (parent d712b970c9)
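For context, the fix is the standard max-subtraction trick: softmax(x)[i] = exp(x[i] - m) / sum_j exp(x[j] - m) with m = max_j x[j] is algebraically identical to the naive form, but every exp() argument is <= 0, so float32 never overflows to inf. A minimal CPU sketch of the three passes the shaders below now perform (max, sum, normalize); the function name is illustrative and not part of the patch:

#include <algorithm>
#include <cmath>
#include <vector>

// Three-pass softmax mirroring the shader structure: pass 1 finds the
// max, pass 2 accumulates exp(x - max), pass 3 writes the normalized
// values. Subtracting the max keeps every exp() argument <= 0, so the
// float result never overflows. Assumes x is non-empty.
std::vector<float> StableSoftmax(const std::vector<float>& x) {
  const float maximum = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - maximum);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - maximum) / sum;
  }
  return out;
}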
@@ -124,10 +124,10 @@ std::unique_ptr<ComputeTaskDescriptor> SelectSoftmax(const OperationDef& op_def,
                                                      const BHWC& src_shape,
                                                      const GpuInfo& gpu_info) {
   if (src_shape.w == 1 && src_shape.h == 1) {
-    auto gpu_op = Softmax1x1(op_def, gpu_info, src_shape.c);
+    auto gpu_op = Softmax1x1(op_def, gpu_info);
     return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
   } else {
-    auto gpu_op = Softmax(op_def, src_shape.c);
+    auto gpu_op = Softmax(op_def);
     return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
   }
 }
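Both call sites drop the explicit channel count because the reworked descriptors now read it from the tensor shapes handed to their uniform lambdas, packing channels into float4 slices with DivideRoundUp. A one-line equivalent of that helper for reference (a sketch; the delegate presumably has its own utility for this):

// Ceiling division matching the DivideRoundUp usage in the hunks below:
// DivideRoundUp(c, 4) is the number of float4 slices covering c channels.
int DivideRoundUp(int numerator, int denominator) {
  return (numerator + denominator - 1) / denominator;
}
// e.g. DivideRoundUp(6, 4) == 2; the two padding lanes of the last
// slice are zeroed out via params.mask (see the uniform lambdas below).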
@@ -39,10 +39,10 @@ using ::tflite::gpu::TensorRef;
 using ::tflite::gpu::metal::CompareVectors;
 using ::tflite::gpu::metal::SingleOpModel;

-@interface SoftmaxTest : XCTestCase
+@interface PReLUTest : XCTestCase
 @end

-@implementation SoftmaxTest
+@implementation PReLUTest
 - (void)setUp {
   [super setUp];
 }
@@ -36,10 +36,10 @@ using ::tflite::gpu::TensorRef;
 using ::tflite::gpu::metal::CompareVectors;
 using ::tflite::gpu::metal::SingleOpModel;

-@interface SliceTest : XCTestCase
+@interface ReLUTest : XCTestCase
 @end

-@implementation SliceTest
+@implementation ReLUTest
 - (void)setUp {
   [super setUp];
 }
@@ -50,21 +50,55 @@ kernel void ComputeFunction($1
                             uint tid[[thread_index_in_threadgroup]],
                             uint3 ugid[[thread_position_in_grid]])
 {
-  int offset = 0;
-  float sum = 0.0f;
-  int s = 0;
-  do {
-    if (offset + tid < params.size.x) {
-      float4 mask_temp = offset + tid == params.size.x - 1 ? params.mask : float4(1.0h);
-      float4 src = float4(src_tensor[offset + tid]);
-      sum += dot(mask_temp, exp(src));
-      offset += 32;
-    }
-    s++;
-  } while (s < params.size.y);
+  float4 maxx4 = float4(src_tensor[0].x);
+  for (int s = int(tid); s < params.size.x; s += 32) {
+    float4 mask_a = s == params.size.x - 1 ? params.mask : float4(1.0f);
+    float4 mask_b = float4(1.0f) - mask_a;
+    float4 src = float4(src_tensor[s]);
+    src = src * mask_a + mask_b * src.x;
+    maxx4 = max(maxx4, src);
+  }
+  float maximum = max(maxx4.x, maxx4.y);
+  maximum = max(maximum, maxx4.z);
+  maximum = max(maximum, maxx4.w);

   threadgroup float4 tmp[8];
   threadgroup float* tmpx1 = (threadgroup float*)tmp;

+  tmpx1[tid] = maximum;
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+  if (tid == 0) {
+    maxx4 = max(tmp[0], tmp[1]);
+    maxx4 = max(maxx4, tmp[2]);
+    maxx4 = max(maxx4, tmp[3]);
+    maxx4 = max(maxx4, tmp[4]);
+    maxx4 = max(maxx4, tmp[5]);
+    maxx4 = max(maxx4, tmp[6]);
+    maxx4 = max(maxx4, tmp[7]);
+    maximum = max(maxx4.x, maxx4.y);
+    maximum = max(maximum, maxx4.z);
+    maximum = max(maximum, maxx4.w);
+    tmpx1[0] = maximum;
+  }
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
+  maximum = tmpx1[0];
+
+  float sum = 0.0f;
+  for (int s = int(tid); s < params.size.x; s += 32) {
+    float4 mask_temp = s == params.size.x - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(src_tensor[s]) - float4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(

   tmpx1[tid] = sum;
 )";
   code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
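The rewritten 1x1 kernel now does a strided per-thread max, a threadgroup-memory reduction (tmpx1 plus barriers), and only then the strided sum of exp(x - max). A rough CPU analogue of that structure, with a plain array standing in for threadgroup memory; names here are illustrative, not from the patch:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr int kThreads = 32;  // width of the kernel's threadgroup

// CPU analogue of the kernel's first two passes. Each "thread" scans a
// strided slice of the channel vector; the partials array stands in for
// threadgroup memory (tmp/tmpx1), and the single combine step stands in
// for the barrier + tid==0 reduction. Assumes x is non-empty.
float StableExpSum(const std::vector<float>& x, float* out_max) {
  float partials[kThreads];
  for (int tid = 0; tid < kThreads; ++tid) {
    float m = x[0];  // mirrors: float4 maxx4 = float4(src_tensor[0].x);
    for (std::size_t s = tid; s < x.size(); s += kThreads) {
      m = std::max(m, x[s]);
    }
    partials[tid] = m;  // mirrors: tmpx1[tid] = maximum;
  }
  const float maximum = *std::max_element(partials, partials + kThreads);
  // Pass 2: sum of exp(x - max); every argument is <= 0, so no overflow.
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - maximum);
  *out_max = maximum;
  return sum;
}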
@@ -85,74 +119,90 @@ kernel void ComputeFunction($1
   code += R"(
   sum = tmpx1[0];

-  offset = 0;
-  s = 0;
-  do {
-    if (offset + tid < params.size.x) {
-      int linear_index = offset + tid;
-      FLT4 value = FLT4(exp(float4(src_tensor[linear_index])) * sum);
-      uint3 gid = uint3(0, 0, linear_index);
-      $2
-      dst_tensor[linear_index] = value;
-      offset += 32;
-    }
-    s++;
-  } while (s < params.size.y);
+  int dst_s = int(ugid.x);
+  if (dst_s < params.size.x) {
+    int linear_index = dst_s;
+    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+    FLT4 value = FLT4(exp(src) * sum);
+    uint3 gid = uint3(0, 0, linear_index);
+    $2
+    dst_tensor[linear_index] = value;
+  }
 })";
   return code;
 }
 }  // namespace

-ComputeTaskDescriptor Softmax(const OperationDef& definition,
-                              int channels_count) {
+ComputeTaskDescriptor Softmax(const OperationDef& definition) {
   ComputeTaskDescriptor desc(definition);
   desc.shader_source = R"(
-#include <metal_stdlib>
-using namespace metal;
-constant int src_channels = )";
-  desc.shader_source += std::to_string(channels_count);
-  desc.shader_source += R"(;
-$0
-kernel void ComputeFunction(
-                            $1
-                            uint3 gid[[thread_position_in_grid]]) {
-  if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
-    return;
-  }
-  float shift = 0.0f;
-  int remaining_channels = src_channels % 4;
-
-  float sum = 0.0f;
-  for (int d = 0; d < src_channels / 4; ++d) {
-    int buffer_index = (d * size.y + gid.y) * size.x + gid.x;
-    sum += dot(float4(1.0f), exp(float4(src_tensor[buffer_index]) - shift));
-  }
-  if (remaining_channels > 0) {
-    int buffer_index = ((src_channels / 4) * size.y + gid.y) * size.x + gid.x;
-    float4 last_element = float4(src_tensor[buffer_index]);
-    sum += exp(last_element.x - shift);
-    if (remaining_channels > 1) sum += exp(last_element.y - shift);
-    if (remaining_channels == 3) sum += exp(last_element.z - shift);
-  }
-
-  for (int d = 0; d < (src_channels + 3) / 4; ++d) {
-    const int linear_index = (d * size.y + gid.y) * size.x + gid.x;
-    FLT4 value = FLT4(exp(float4(src_tensor[linear_index]) - shift) / sum);
-    $2
-    dst_tensor[linear_index] = value;
-  }
-}
+#include <metal_stdlib>
+using namespace metal;
+
+struct uniforms {
+  int4 size;
+  float4 mask;
+};
+$0
+kernel void ComputeFunction(
+                            $1
+                            uint3 gid[[thread_position_in_grid]]) {
+  if (int(gid.x) >= params.size.x || int(gid.y) >= params.size.y) {
+    return;
+  }
+
+  float maximum = src_tensor[gid.y * params.size.x + gid.x].x;
+  for (int d = 0; d < params.size.z; ++d) {
+    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 mask_a = d == params.size.z - 1 ? params.mask : float4(1.0f);
+    float4 mask_b = float4(1.0f) - mask_a;
+    float4 src = float4(src_tensor[buffer_index]);
+    src = src * mask_a + mask_b * src.x;
+    maximum = max(maximum, src.x);
+    maximum = max(maximum, src.y);
+    maximum = max(maximum, src.z);
+    maximum = max(maximum, src.w);
+  }
+
+  float sum = 0.0f;
+  for (int d = 0; d < params.size.z; ++d) {
+    int buffer_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 mask_temp = d == params.size.z - 1 ? params.mask : float4(1.0f);
+    float4 src = float4(src_tensor[buffer_index]) - float4(maximum);
+    sum += dot(mask_temp, exp(src));
+  }
+
+  for (int d = 0; d < params.size.z; ++d) {
+    const int linear_index = (d * params.size.y + gid.y) * params.size.x + gid.x;
+    float4 src = float4(src_tensor[linear_index]) - float4(maximum);
+    FLT4 value = FLT4(exp(src) / sum);
+    $2
+    dst_tensor[linear_index] = value;
+  }
+}
 )";

   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

   desc.uniform_buffers = {
-      {"constant int2& size",
+      {"constant uniforms& params",
        [](const std::vector<BHWC>& src_shapes,
           const std::vector<BHWC>& dst_shapes) {
-         std::vector<int> sizes{dst_shapes[0].w, dst_shapes[0].h};
-         return GetByteBuffer(sizes);
+         const int dst_depth = DivideRoundUp(dst_shapes[0].c, 4);
+         struct uniforms {
+           int4 size;
+           float4 mask;
+         };
+         uniforms params;
+         params.size = {dst_shapes[0].w, dst_shapes[0].h, dst_depth, 1};
+         params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
+         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
+         for (int i = 0; i < reminder; ++i) {
+           params.mask[i] = 1.0f;
+         }
+         const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
+         return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
        }},
   };
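The per-pixel Softmax kernel gets the same three-loop shape (max, sum, normalize), and the compile-time src_channels constant is replaced by params.size.z slices plus a params.mask covering the partial last slice. A small standalone sketch of the mask construction the lambda performs, assuming only what the diff shows:

#include <array>
#include <cstdio>

// Builds the float4-style mask the uniform lambda computes: ones for
// valid lanes of the last channel slice, zeros for the padding lanes.
std::array<float, 4> BuildMask(int channels) {
  std::array<float, 4> mask = {0.0f, 0.0f, 0.0f, 0.0f};
  int reminder = channels % 4 == 0 ? 4 : channels % 4;  // "reminder" as in the source
  for (int i = 0; i < reminder; ++i) mask[i] = 1.0f;
  return mask;
}

int main() {
  // channels = 6 -> two slices of 4, last slice has 2 valid lanes:
  // mask = {1, 1, 0, 0}, so dot(mask, exp(src)) ignores the padding.
  auto m = BuildMask(6);
  std::printf("%g %g %g %g\n", m[0], m[1], m[2], m[3]);
  return 0;
}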
@@ -168,7 +218,7 @@ ComputeTaskDescriptor Softmax(const OperationDef& definition,
 }

 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
-                                 const GpuInfo& gpu_info, int channels_count) {
+                                 const GpuInfo& gpu_info) {
   ComputeTaskDescriptor desc(definition);
   desc.shader_source = GetSoftmax1x1Code(gpu_info);
@@ -177,9 +227,9 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,

   desc.uniform_buffers = {
       {"constant uniforms& params",
-       [channels_count](const std::vector<BHWC>& src_shapes,
-                        const std::vector<BHWC>& dst_shapes) {
-         const int src_depth = DivideRoundUp(channels_count, 4);
+       [](const std::vector<BHWC>& src_shapes,
+          const std::vector<BHWC>& dst_shapes) {
+         const int src_depth = DivideRoundUp(dst_shapes[0].c, 4);
          struct uniforms {
            int4 size;
            float4 mask;
@@ -187,7 +237,7 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
          uniforms params;
          params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
          params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
-         const int reminder = channels_count % 4 == 0 ? 4 : channels_count % 4;
+         int reminder = dst_shapes[0].c % 4 == 0 ? 4 : dst_shapes[0].c % 4;
          for (int i = 0; i < reminder; ++i) {
            params.mask[i] = 1.0f;
          }
@@ -198,7 +248,10 @@ ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,

   desc.resize_function = [](const std::vector<BHWC>& src_shapes,
                             const std::vector<BHWC>& dst_shapes) {
-    return std::make_pair(uint3{32u, 1u, 1u}, uint3{1u, 1u, 1u});
+    uint3 groups_size{32, 1, 1};
+    uint3 groups_count{
+        DivideRoundUp(DivideRoundUp(dst_shapes[0].c, 4), groups_size.x), 1, 1};
+    return std::make_pair(groups_size, groups_count);
   };

   return desc;
@@ -27,13 +27,12 @@ namespace tflite {
 namespace gpu {
 namespace metal {

-ComputeTaskDescriptor Softmax(const OperationDef& definition,
-                              int channels_count);
+ComputeTaskDescriptor Softmax(const OperationDef& definition);

 // Softmax for case when width = height = 1 and AXIS = CHANNELS
 // We have this case in MobilenetV1/V2.
 ComputeTaskDescriptor Softmax1x1(const OperationDef& definition,
-                                 const GpuInfo& gpu_info, int channels_count);
+                                 const GpuInfo& gpu_info);

 }  // namespace metal
 }  // namespace gpu
@@ -133,4 +133,76 @@ using ::tflite::gpu::metal::SingleOpModel;
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }

+- (void)testSoftmaxBigNumber {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 2, 1, 2);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 2, 1, 2);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  XCTAssertFalse(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]);
+  double s1 = std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
+  XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
+                                         static_cast<float>(doubles[1]),
+                                         static_cast<float>(doubles[2]),
+                                         static_cast<float>(doubles[3])}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
+                           static_cast<float>(std::exp(doubles[1]) / s0),
+                           static_cast<float>(std::exp(doubles[2]) / s1),
+                           static_cast<float>(std::exp(doubles[3]) / s1)},
+                          model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testSoftmax1x1BigNumber {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 1, 1, 4);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 1, 1, 4);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  double doubles[4] = {1.0, 2.0, 3.0, 100.0};
+  // exp(100) is inf in float (32 bit) but representable in double (64 bit)
+  XCTAssertTrue(std::isinf(std::exp(static_cast<float>(doubles[3]))));
+  XCTAssertFalse(std::isinf(std::exp(doubles[3])));
+  double s0 = std::exp(doubles[0]) + std::exp(doubles[1]) +
+              std::exp(doubles[2]) + std::exp(doubles[3]);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input}, {output});
+  XCTAssertTrue(model.PopulateTensor(0, {static_cast<float>(doubles[0]),
+                                         static_cast<float>(doubles[1]),
+                                         static_cast<float>(doubles[2]),
+                                         static_cast<float>(doubles[3])}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({static_cast<float>(std::exp(doubles[0]) / s0),
+                           static_cast<float>(std::exp(doubles[1]) / s0),
+                           static_cast<float>(std::exp(doubles[2]) / s0),
+                           static_cast<float>(std::exp(doubles[3]) / s0)},
+                          model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end
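The new tests hinge on exp(100) overflowing float32 but not float64: exp(100) is about 2.69e43, while FLT_MAX is about 3.40e38 and DBL_MAX about 1.80e308. A tiny standalone check of that premise (not part of the patch):

#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  // exp(100) ~ 2.7e43 exceeds FLT_MAX ~ 3.4e38, so the float result is
  // inf, while the double result is finite; the tests assert exactly this.
  std::printf("float:   %g\n", std::exp(100.0f));  // inf
  std::printf("double:  %g\n", std::exp(100.0));   // ~2.68812e+43
  std::printf("FLT_MAX: %g\n", FLT_MAX);           // ~3.40282e+38
  return 0;
}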