Fixed threadgroup memory barriers for FullyConnected and Softmax.
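
The generated Metal shaders for FullyConnected and Softmax1x1 previously hard-coded the BARRIER macro for threadgroup synchronization. The barrier token is now chosen per device ("BARRIER" on Apple GPUs, plain "threadgroup_barrier" otherwise) and spliced into the generated shader source, which requires threading DeviceInfo through SelectSoftmax, Softmax1x1, and GetFullyConnectedCode.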

PiperOrigin-RevId: 304313053
Change-Id: I944e2d05c65cd502e35e550b0f6c6c8a2efd01ff
Raman Sarokin, 2020-04-01 19:54:09 -07:00; committed by TensorFlower Gardener
parent 3e8cd000b9
commit 0e531f62a9
5 changed files with 39 additions and 20 deletions

--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -129,12 +129,12 @@ std::vector<ComputeTaskDescriptorPtr> SelectReshape(
   }
 }
 
-std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(const GraphFloat32& graph,
-                                                    int id, ValueId input_id,
-                                                    ValueId output_id) {
+std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(
+    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
+    const DeviceInfo& device_info) {
   const auto src_shape = graph.FindInputs(id)[0]->tensor.shape;
   if (src_shape.w == 1 && src_shape.h == 1) {
-    return Softmax1x1(id, input_id, output_id, src_shape.c);
+    return Softmax1x1(id, input_id, output_id, device_info, src_shape.c);
   } else {
     return Softmax(id, input_id, output_id, src_shape.c);
   }
@@ -334,7 +334,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
         return absl::UnimplementedError(
             "Softmax supports only CHANNELS dimension");
       }
-      *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
+      *tasks =
+          SelectSoftmax(graph, node_id, inputs[0], outputs[0], device_info);
       break;
     }
     case OperationType::SPACE_TO_DEPTH:

--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -707,6 +707,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal:environment",
        "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],

--- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc
@@ -37,8 +37,13 @@ namespace gpu {
 namespace metal {
 namespace {
 
-std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
-                                  int dst_channels) {
+std::string GetFullyConnectedCode(const DeviceInfo& device_info,
+                                  int src_channels, int dst_channels) {
+  bool shared_memory =
+      device_info.IsAppleGPU() &&
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
+  const std::string barrier =
+      device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
   const int src_depth = IntegralDivideRoundUp(src_channels, 4);
   std::stringstream code;
   code << R"(
@@ -67,11 +72,11 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
     for (int j = 0; j < $0; ++j) {
       local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
         FLT4(0.0f) : vector[j * 32 + tid_index];
-      BARRIER(mem_flags::mem_threadgroup);
+      $1(mem_flags::mem_threadgroup);
       for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
         summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]);
       }
-      BARRIER(mem_flags::mem_none);
+      $1(mem_flags::mem_none);
     }
 )";
   } else {
@@ -92,14 +97,14 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
 
     threadgroup float temp[8][4];
     temp[tid.x][tid.y] = summa;
-    BARRIER(mem_flags::mem_threadgroup);
+    $1(mem_flags::mem_threadgroup);
     if (tid.y == 0) {
       summa += temp[tid.x][1];
       summa += temp[tid.x][2];
       summa += temp[tid.x][3];
       temp[tid.x][0] = summa;
     }
-    BARRIER(mem_flags::mem_threadgroup);
+    $1(mem_flags::mem_threadgroup);
     if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
       const int linear_index = ugid.x / 4;
       FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
@@ -113,7 +118,7 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
   const int src_depth_sub_groups = shared_memory
                                        ? IntegralDivideRoundUp(src_depth, 32)
                                        : IntegralDivideRoundUp(src_depth, 4);
-  return absl::Substitute(code.str(), src_depth_sub_groups);
+  return absl::Substitute(code.str(), src_depth_sub_groups, barrier);
 }
 
 }  // namespace
@@ -124,9 +129,8 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  bool shared = device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
-  desc->shader_source =
-      GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o);
+  desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
+                                              attr.weights.shape.o);
 
   desc->input_buffers = {
       {input_id, "device FLT4* const vector"},
@@ -138,8 +142,11 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
         return CalculateOutputShape(buffers.find(input_id)->second, attr);
       }};
 
+  bool shared_memory =
+      device_info.IsAppleGPU() &&
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
   const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
-  const int src_depth_aligned = AlignByN(src_depth, shared ? 32 : 4);
+  const int src_depth_aligned = AlignByN(src_depth, shared_memory ? 32 : 4);
   const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
 
   int counter = 0;

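For reference, a minimal standalone sketch (not the TensorFlow sources; the device flag and loop bound below are made-up values) of how the $1 placeholder above becomes a device-specific barrier call via absl::Substitute:

// sketch.cc: splice a device-dependent barrier token into generated
// Metal source, mirroring the $0/$1 placeholders in the diff above.
#include <iostream>
#include <string>

#include "absl/strings/substitute.h"

int main() {
  const bool is_apple_gpu = false;  // assumption: a non-Apple GPU
  const std::string barrier = is_apple_gpu ? "BARRIER" : "threadgroup_barrier";
  const int src_depth_sub_groups = 2;  // hypothetical loop bound
  const std::string shader = absl::Substitute(R"(
  for (int j = 0; j < $0; ++j) {
    $1(mem_flags::mem_threadgroup);
  })",
                                              src_depth_sub_groups, barrier);
  // Prints threadgroup_barrier(...) here; with is_apple_gpu = true it
  // would print the BARRIER(...) macro call instead.
  std::cout << shader << std::endl;
}
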
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.cc
@@ -25,13 +25,16 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
 namespace gpu {
 namespace metal {
 namespace {
-std::string GetSoftmax1x1Code() {
+std::string GetSoftmax1x1Code(const DeviceInfo& device_info) {
+  const std::string barrier =
+      device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
   std::string code = R"(
 #include <metal_stdlib>
 using namespace metal;
@@ -63,7 +66,9 @@ kernel void ComputeFunction($1
     threadgroup float4 tmp[8];
     threadgroup float* tmpx1 = (threadgroup float*)tmp;
     tmpx1[tid] = sum;
-    BARRIER(mem_flags::mem_threadgroup);
+)";
+  code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
     if (tid == 0) {
       sum = dot(float4(1.0f), tmp[0]);
       sum += dot(float4(1.0f), tmp[1]);
@@ -75,7 +80,9 @@ kernel void ComputeFunction($1
       sum += dot(float4(1.0f), tmp[7]);
       tmpx1[0] = 1.0 / sum;
     }
-    BARRIER(mem_flags::mem_threadgroup);
+)";
+  code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
     sum = tmpx1[0];
 
     offset = 0;
@@ -171,11 +178,12 @@ std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
 
 std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
                                                  ValueId output_id,
+                                                 const DeviceInfo& device_info,
                                                  int channels_count) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  desc->shader_source = GetSoftmax1x1Code();
+  desc->shader_source = GetSoftmax1x1Code(device_info);
 
   desc->input_buffers = {
       {input_id, "device FLT4* const src_buffer"},

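Note the different splicing mechanism here: GetSoftmax1x1Code closes the raw string and concatenates the barrier call rather than using absl::Substitute, presumably because the softmax template's $N placeholders (such as the $1 in kernel void ComputeFunction($1) are already consumed by the compute-task descriptor machinery. A minimal sketch of the pattern, with abbreviated placeholder shader fragments and the barrier value hard-coded for illustration:

// sketch of the concatenation pattern used in GetSoftmax1x1Code above;
// the shader fragments are placeholders, not the real kernel body.
#include <iostream>
#include <string>

int main() {
  const std::string barrier = "threadgroup_barrier";  // non-Apple case
  std::string code = R"(
  tmpx1[tid] = sum;
)";
  code += " " + barrier + "(mem_flags::mem_threadgroup);\n";
  code += R"(
  if (tid == 0) { /* ... */ }
)";
  std::cout << code;  // the barrier call lands between the two fragments
}
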
--- a/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/softmax.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -35,6 +36,7 @@ std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
 // We have this case in MobilenetV1/V2.
 std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
                                                  ValueId output_id,
+                                                 const DeviceInfo& device_info,
                                                  int channels_count);
 
 }  // namespace metal