Fixed threadgroup memory barriers for FullyConnected and Softmax.

PiperOrigin-RevId: 304313053
Change-Id: I944e2d05c65cd502e35e550b0f6c6c8a2efd01ff
Author: Raman Sarokin (2020-04-01 19:54:09 -07:00), committed by TensorFlower Gardener
Parent: 3e8cd000b9
Commit: 0e531f62a9
5 changed files with 39 additions and 20 deletions
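The generated FullyConnected and Softmax1x1 shaders hard-coded the BARRIER macro, which (as the diffs below imply) the Metal backend only sets up for Apple GPUs; on other devices the generated source must call Metal's standard threadgroup_barrier directly. The fix threads DeviceInfo into both code generators and splices in whichever spelling the device needs. A minimal, self-contained sketch of that selection (DeviceInfoStub is a stand-in for the real DeviceInfo from metal/environment.h):

#include <string>

struct DeviceInfoStub {
  bool is_apple_gpu = true;
  bool IsAppleGPU() const { return is_apple_gpu; }
};

// Returns the barrier spelling to splice into the generated Metal source.
std::string BarrierToken(const DeviceInfoStub& device_info) {
  return device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
}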


@@ -129,12 +129,12 @@ std::vector<ComputeTaskDescriptorPtr> SelectReshape(
   }
 }
-std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(const GraphFloat32& graph,
-                                                    int id, ValueId input_id,
-                                                    ValueId output_id) {
+std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(
+    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
+    const DeviceInfo& device_info) {
   const auto src_shape = graph.FindInputs(id)[0]->tensor.shape;
   if (src_shape.w == 1 && src_shape.h == 1) {
-    return Softmax1x1(id, input_id, output_id, src_shape.c);
+    return Softmax1x1(id, input_id, output_id, device_info, src_shape.c);
   } else {
     return Softmax(id, input_id, output_id, src_shape.c);
   }
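Only the 1x1 variant is affected by the barrier choice: when the spatial size is 1x1, the kernel reduces over channels cooperatively within a threadgroup, so its generated source contains barriers; the general Softmax kernel works one pixel per thread and needs none. A hedged sketch of that dispatch with stub types (BHWCStub and DeviceInfoStub stand in for the real BHWC and DeviceInfo; the returned strings are illustrative labels, not real kernel names):

#include <string>

struct BHWCStub { int h = 1, w = 1, c = 1001; };  // stand-in for BHWC
struct DeviceInfoStub { bool IsAppleGPU() const { return true; } };

// Mirrors the branch in SelectSoftmax above: only the 1x1 path consults the
// device, because only its shader emits threadgroup barriers.
std::string SelectSoftmaxVariant(const BHWCStub& src_shape,
                                 const DeviceInfoStub& device_info) {
  if (src_shape.w == 1 && src_shape.h == 1) {
    return device_info.IsAppleGPU() ? "Softmax1x1 with BARRIER"
                                    : "Softmax1x1 with threadgroup_barrier";
  }
  return "Softmax (no barriers needed)";
}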
@@ -334,7 +334,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
         return absl::UnimplementedError(
             "Softmax supports only CHANNELS dimension");
       }
-      *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
+      *tasks =
+          SelectSoftmax(graph, node_id, inputs[0], outputs[0], device_info);
       break;
     }
     case OperationType::SPACE_TO_DEPTH:


@@ -707,6 +707,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal:environment",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],


@@ -37,8 +37,13 @@ namespace gpu {
 namespace metal {
 namespace {
-std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
-                                  int dst_channels) {
+std::string GetFullyConnectedCode(const DeviceInfo& device_info,
+                                  int src_channels, int dst_channels) {
+  bool shared_memory =
+      device_info.IsAppleGPU() &&
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
+  const std::string barrier =
+      device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
   const int src_depth = IntegralDivideRoundUp(src_channels, 4);
   std::stringstream code;
   code << R"(
@@ -67,11 +72,11 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
     for (int j = 0; j < $0; ++j) {
       local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
         FLT4(0.0f) : vector[j * 32 + tid_index];
-      BARRIER(mem_flags::mem_threadgroup);
+      $1(mem_flags::mem_threadgroup);
       for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
         summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]);
       }
-      BARRIER(mem_flags::mem_none);
+      $1(mem_flags::mem_none);
     }
   )";
   } else {
@@ -92,14 +97,14 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
   threadgroup float temp[8][4];
   temp[tid.x][tid.y] = summa;
-  BARRIER(mem_flags::mem_threadgroup);
+  $1(mem_flags::mem_threadgroup);
   if (tid.y == 0) {
     summa += temp[tid.x][1];
     summa += temp[tid.x][2];
     summa += temp[tid.x][3];
     temp[tid.x][0] = summa;
   }
-  BARRIER(mem_flags::mem_threadgroup);
+  $1(mem_flags::mem_threadgroup);
   if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
     const int linear_index = ugid.x / 4;
     FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
@@ -113,7 +118,7 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
   const int src_depth_sub_groups = shared_memory
                                        ? IntegralDivideRoundUp(src_depth, 32)
                                        : IntegralDivideRoundUp(src_depth, 4);
-  return absl::Substitute(code.str(), src_depth_sub_groups);
+  return absl::Substitute(code.str(), src_depth_sub_groups, barrier);
 }
 }  // namespace
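For reference, absl::Substitute fills the positional placeholders in the template: $0 receives src_depth_sub_groups and, after this change, $1 receives the barrier token. A small standalone illustration (the values 8 and "threadgroup_barrier" are made up):

#include <iostream>
#include <string>
#include "absl/strings/substitute.h"

int main() {
  const std::string tmpl =
      "for (int j = 0; j < $0; ++j) { $1(mem_flags::mem_threadgroup); }";
  std::cout << absl::Substitute(tmpl, 8, "threadgroup_barrier") << "\n";
  // Prints:
  // for (int j = 0; j < 8; ++j) { threadgroup_barrier(mem_flags::mem_threadgroup); }
}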
@@ -124,9 +129,8 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  bool shared = device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
-  desc->shader_source =
-      GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o);
+  desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
+                                              attr.weights.shape.o);
   desc->input_buffers = {
       {input_id, "device FLT4* const vector"},
@@ -138,8 +142,11 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
         return CalculateOutputShape(buffers.find(input_id)->second, attr);
       }};
+  bool shared_memory =
+      device_info.IsAppleGPU() &&
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
   const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
-  const int src_depth_aligned = AlignByN(src_depth, shared ? 32 : 4);
+  const int src_depth_aligned = AlignByN(src_depth, shared_memory ? 32 : 4);
   const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
   int counter = 0;
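The alignment arithmetic above, worked through for a concrete size. The helpers below mirror IntegralDivideRoundUp and AlignByN from common/util.h; the channel count 1001 is illustrative:

// (a + n - 1) / n rounds the division up; AlignByN rounds a up to a
// multiple of n, matching the helpers used in the diff.
int IntegralDivideRoundUp(int a, int n) { return (a + n - 1) / n; }
int AlignByN(int a, int n) { return IntegralDivideRoundUp(a, n) * n; }

// For attr.weights.shape.i == 1001 input channels:
//   src_depth = IntegralDivideRoundUp(1001, 4) = 251 FLT4 slices
//   shared-memory path:  AlignByN(251, 32) = 256 (8 tiles of 32 slices)
//   global-memory path:  AlignByN(251, 4)  = 252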


@@ -25,13 +25,16 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 namespace tflite {
 namespace gpu {
 namespace metal {
 namespace {
-std::string GetSoftmax1x1Code() {
+std::string GetSoftmax1x1Code(const DeviceInfo& device_info) {
+  const std::string barrier =
+      device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
   std::string code = R"(
 #include <metal_stdlib>
 using namespace metal;
@@ -63,7 +66,9 @@ kernel void ComputeFunction($1
   threadgroup float4 tmp[8];
   threadgroup float* tmpx1 = (threadgroup float*)tmp;
   tmpx1[tid] = sum;
-  BARRIER(mem_flags::mem_threadgroup);
+)";
+  code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
   if (tid == 0) {
     sum = dot(float4(1.0f), tmp[0]);
     sum += dot(float4(1.0f), tmp[1]);
@@ -75,7 +80,9 @@ kernel void ComputeFunction($1
     sum += dot(float4(1.0f), tmp[7]);
     tmpx1[0] = 1.0 / sum;
   }
-  BARRIER(mem_flags::mem_threadgroup);
+)";
+  code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
   sum = tmpx1[0];
   offset = 0;
@@ -171,11 +178,12 @@ std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
 std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
                                                  ValueId output_id,
+                                                 const DeviceInfo& device_info,
                                                  int channels_count) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  desc->shader_source = GetSoftmax1x1Code();
+  desc->shader_source = GetSoftmax1x1Code(device_info);
   desc->input_buffers = {
       {input_id, "device FLT4* const src_buffer"},


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 namespace tflite {
@@ -35,6 +36,7 @@ std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
 // We have this case in MobilenetV1/V2.
 std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
                                                  ValueId output_id,
+                                                 const DeviceInfo& device_info,
                                                  int channels_count);
 }  // namespace metal
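Callers now pass the device description through to Softmax1x1. A hedged usage sketch — the include paths are our assumption about where these headers live in the tree, and the ids and channel count are illustrative:

#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"

namespace tflite {
namespace gpu {
namespace metal {

// Builds the 1x1 softmax task with the barrier spelling picked from
// device_info (assumed to have been queried from the MTLDevice at setup).
std::vector<ComputeTaskDescriptorPtr> MakeSoftmax1x1(
    const DeviceInfo& device_info) {
  return Softmax1x1(/*id=*/0, /*input_id=*/1, /*output_id=*/2, device_info,
                    /*channels_count=*/1001);
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite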