Juhyun Lee b26b6b5669 Rename IntegralDivideRoundUp to DivideRoundUp.
PiperOrigin-RevId: 307447663
Change-Id: I1e0f6c9f058e3f0457a7522f1d10f7da8ab8610d
2020-04-20 12:06:03 -07:00

228 lines
7.5 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
std::string GetSoftmax1x1Code(const DeviceInfo& device_info) {
const std::string barrier = device_info.IsWaveSizeEqualTo32()
? "SIMDGROUP_BARRIER"
: "threadgroup_barrier";
std::string code = R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 size;
float4 mask;
};
$0
kernel void ComputeFunction($1
uint tid[[thread_index_in_threadgroup]],
uint3 ugid[[thread_position_in_grid]])
{
int offset = 0;
float sum = 0.0f;
int s = 0;
do {
if (offset + tid < params.size.x) {
float4 mask_temp = offset + tid == params.size.x - 1 ? params.mask : float4(1.0h);
float4 src = float4(src_buffer[offset + tid]);
sum += dot(mask_temp, exp(src));
offset += 32;
}
s++;
} while (s < params.size.y);
threadgroup float4 tmp[8];
threadgroup float* tmpx1 = (threadgroup float*)tmp;
tmpx1[tid] = sum;
)";
code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
code += R"(
if (tid == 0) {
sum = dot(float4(1.0f), tmp[0]);
sum += dot(float4(1.0f), tmp[1]);
sum += dot(float4(1.0f), tmp[2]);
sum += dot(float4(1.0f), tmp[3]);
sum += dot(float4(1.0f), tmp[4]);
sum += dot(float4(1.0f), tmp[5]);
sum += dot(float4(1.0f), tmp[6]);
sum += dot(float4(1.0f), tmp[7]);
tmpx1[0] = 1.0 / sum;
}
)";
code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
code += R"(
sum = tmpx1[0];
offset = 0;
s = 0;
do {
if (offset + tid < params.size.x) {
int linear_index = offset + tid;
FLT4 value = FLT4(exp(float4(src_buffer[linear_index])) * sum);
uint3 gid = uint3(0, 0, linear_index);
$2
dst_buffer[linear_index] = value;
offset += 32;
}
s++;
} while (s < params.size.y);
})";
return code;
}
} // namespace
std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
ValueId output_id,
int channels_count) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = R"(
#include <metal_stdlib>
using namespace metal;
constant int src_channels = )";
desc->shader_source += std::to_string(channels_count);
desc->shader_source += R"(;
$0
kernel void ComputeFunction(
$1
uint3 gid[[thread_position_in_grid]]) {
if (int(gid.x) >= size.x || int(gid.y) >= size.y) {
return;
}
float shift = 0.0f;
int remaining_channels = src_channels % 4;
float sum = 0.0f;
for (int d = 0; d < src_channels / 4; ++d) {
int buffer_index = (d * size.y + gid.y) * size.x + gid.x;
sum += dot(float4(1.0f), exp(float4(input_buffer[buffer_index]) - shift));
}
if (remaining_channels > 0) {
int buffer_index = ((src_channels / 4) * size.y + gid.y) * size.x + gid.x;
float4 last_element = float4(input_buffer[buffer_index]);
sum += exp(last_element.x - shift);
if (remaining_channels > 1) sum += exp(last_element.y - shift);
if (remaining_channels == 3) sum += exp(last_element.z - shift);
}
for (int d = 0; d < (src_channels + 3) / 4; ++d) {
const int linear_index = (d * size.y + gid.y) * size.x + gid.x;
FLT4 value = FLT4(exp(float4(input_buffer[linear_index]) - shift) / sum);
$2
output_buffer[linear_index] = value;
}
}
)";
desc->input_buffers = {
{input_id, "device FLT4* const input_buffer"},
};
desc->output_buffer = {output_id, "device FLT4* output_buffer",
[input_id](const std::map<ValueId, BHWC>& buffers) {
return buffers.find(input_id)->second;
}};
desc->uniform_buffers = {
{"constant int2& size",
[output_id](const std::map<ValueId, BHWC>& buffers) {
const auto& dimension = buffers.find(output_id)->second;
std::vector<int> sizes{dimension.w, dimension.h};
return GetByteBuffer(sizes);
}},
};
desc->resize_function = [output_id](const std::map<ValueId, BHWC>& buffers) {
uint3 groups_size{8, 4, 1};
const auto& dimension = buffers.find(output_id)->second;
uint3 groups_count{DivideRoundUp(dimension.w, groups_size.x),
DivideRoundUp(dimension.h, groups_size.y), 1};
return std::make_pair(groups_size, groups_count);
};
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
ValueId output_id,
const DeviceInfo& device_info,
int channels_count) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetSoftmax1x1Code(device_info);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {output_id, "device FLT4* dst_buffer",
[input_id](const std::map<ValueId, BHWC>& buffers) {
return buffers.find(input_id)->second;
}};
desc->uniform_buffers = {
{"constant uniforms& params",
[channels_count](const std::map<ValueId, BHWC>& buffers) {
const int src_depth = DivideRoundUp(channels_count, 4);
struct uniforms {
int4 size;
float4 mask;
};
uniforms params;
params.size = {src_depth, DivideRoundUp(src_depth, 32), 1, 1};
params.mask = {0.0f, 0.0f, 0.0f, 0.0f};
const int reminder = channels_count % 4 == 0 ? 4 : channels_count % 4;
for (int i = 0; i < reminder; ++i) {
params.mask[i] = 1.0f;
}
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&params);
return std::vector<uint8_t>(ptr, ptr + sizeof(uniforms));
}},
};
desc->resize_function = [](const std::map<ValueId, BHWC>& buffers) {
return std::make_pair(uint3{32u, 1u, 1u}, uint3{1u, 1u, 1u});
};
return {desc};
}
} // namespace metal
} // namespace gpu
} // namespace tflite