Fixed threadgroup memory barriers for FullyConnected and Softmax.
PiperOrigin-RevId: 304313053 Change-Id: I944e2d05c65cd502e35e550b0f6c6c8a2efd01ff
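In outline: the Metal code generators previously hard-coded the BARRIER macro for threadgroup synchronization; they now take a DeviceInfo and emit either BARRIER (on Apple GPUs) or the plain Metal builtin threadgroup_barrier. A minimal sketch of the selection pattern, assuming only DeviceInfo and IsAppleGPU() from tensorflow/lite/delegates/gpu/metal/environment.h (SelectBarrierToken is a hypothetical name used here for illustration):

    #include <string>

    #include "tensorflow/lite/delegates/gpu/metal/environment.h"

    namespace tflite {
    namespace gpu {
    namespace metal {

    // Picks the barrier spelling spliced into generated Metal source.
    // Mirrors the ternary used in GetFullyConnectedCode/GetSoftmax1x1Code below.
    std::string SelectBarrierToken(const DeviceInfo& device_info) {
      return device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
    }

    }  // namespace metal
    }  // namespace gpu
    }  // namespace tflite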
This commit is contained in:
Parent: 3e8cd000b9
Commit: 0e531f62a9
@@ -129,12 +129,12 @@ std::vector<ComputeTaskDescriptorPtr> SelectReshape(
   }
 }
 
-std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(const GraphFloat32& graph,
-                                                    int id, ValueId input_id,
-                                                    ValueId output_id) {
+std::vector<ComputeTaskDescriptorPtr> SelectSoftmax(
+    const GraphFloat32& graph, int id, ValueId input_id, ValueId output_id,
+    const DeviceInfo& device_info) {
   const auto src_shape = graph.FindInputs(id)[0]->tensor.shape;
   if (src_shape.w == 1 && src_shape.h == 1) {
-    return Softmax1x1(id, input_id, output_id, src_shape.c);
+    return Softmax1x1(id, input_id, output_id, device_info, src_shape.c);
   } else {
     return Softmax(id, input_id, output_id, src_shape.c);
   }
@@ -334,7 +334,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
         return absl::UnimplementedError(
             "Softmax supports only CHANNELS dimension");
       }
-      *tasks = SelectSoftmax(graph, node_id, inputs[0], outputs[0]);
+      *tasks =
+          SelectSoftmax(graph, node_id, inputs[0], outputs[0], device_info);
       break;
     }
     case OperationType::SPACE_TO_DEPTH:
@@ -707,6 +707,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
+        "//tensorflow/lite/delegates/gpu/metal:environment",
         "//tensorflow/lite/delegates/gpu/metal:runtime_options",
         "@com_google_absl//absl/strings",
     ],
@@ -37,8 +37,13 @@ namespace gpu {
 namespace metal {
 namespace {
 
-std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
-                                  int dst_channels) {
+std::string GetFullyConnectedCode(const DeviceInfo& device_info,
+                                  int src_channels, int dst_channels) {
+  bool shared_memory =
+      device_info.IsAppleGPU() &&
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
+  const std::string barrier =
+      device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
   const int src_depth = IntegralDivideRoundUp(src_channels, 4);
   std::stringstream code;
   code << R"(
@@ -67,11 +72,11 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
     for (int j = 0; j < $0; ++j) {
       local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
         FLT4(0.0f) : vector[j * 32 + tid_index];
-      BARRIER(mem_flags::mem_threadgroup);
+      $1(mem_flags::mem_threadgroup);
       for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
         summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]);
       }
-      BARRIER(mem_flags::mem_none);
+      $1(mem_flags::mem_none);
     }
 )";
   } else {
@@ -92,14 +97,14 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
 
     threadgroup float temp[8][4];
     temp[tid.x][tid.y] = summa;
-    BARRIER(mem_flags::mem_threadgroup);
+    $1(mem_flags::mem_threadgroup);
     if (tid.y == 0) {
       summa += temp[tid.x][1];
       summa += temp[tid.x][2];
       summa += temp[tid.x][3];
       temp[tid.x][0] = summa;
     }
-    BARRIER(mem_flags::mem_threadgroup);
+    $1(mem_flags::mem_threadgroup);
     if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
       const int linear_index = ugid.x / 4;
       FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
@@ -113,7 +118,7 @@ std::string GetFullyConnectedCode(bool shared_memory, int src_channels,
   const int src_depth_sub_groups = shared_memory
                                        ? IntegralDivideRoundUp(src_depth, 32)
                                        : IntegralDivideRoundUp(src_depth, 4);
-  return absl::Substitute(code.str(), src_depth_sub_groups);
+  return absl::Substitute(code.str(), src_depth_sub_groups, barrier);
 }
 }  // namespace
 
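For context, the shader templates above use absl::Substitute placeholders: $0 receives src_depth_sub_groups and $1 the barrier token. A small self-contained illustration of the same call (the loop body is abbreviated, values made up):

    #include <iostream>
    #include <string>

    #include "absl/strings/substitute.h"

    int main() {
      const int src_depth_sub_groups = 8;
      const std::string barrier = "threadgroup_barrier";  // "BARRIER" on Apple GPUs
      // $0 and $1 are positional placeholders, exactly as in the template.
      std::cout << absl::Substitute(
          "    for (int j = 0; j < $0; ++j) {\n"
          "      $1(mem_flags::mem_threadgroup);\n"
          "    }\n",
          src_depth_sub_groups, barrier);
      return 0;
    }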
@@ -124,9 +129,8 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  bool shared = device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
-  desc->shader_source =
-      GetFullyConnectedCode(shared, attr.weights.shape.i, attr.weights.shape.o);
+  desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
+                                              attr.weights.shape.o);
 
   desc->input_buffers = {
       {input_id, "device FLT4* const vector"},
@@ -138,8 +142,11 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
         return CalculateOutputShape(buffers.find(input_id)->second, attr);
       }};
 
+  bool shared_memory =
+      device_info.IsAppleGPU() &&
+      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
   const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
-  const int src_depth_aligned = AlignByN(src_depth, shared ? 32 : 4);
+  const int src_depth_aligned = AlignByN(src_depth, shared_memory ? 32 : 4);
   const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
 
   int counter = 0;
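A note on the helpers above: IntegralDivideRoundUp and AlignByN come from tensorflow/lite/delegates/gpu/common/util.h and behave as ceil-division and round-up-to-a-multiple. A hedged sketch with local stand-ins, showing why the shared-memory path pads src_depth to a multiple of 32:

    #include <cassert>

    // Local stand-ins; assumed to match the semantics of the helpers in
    // tensorflow/lite/delegates/gpu/common/util.h.
    int IntegralDivideRoundUp(int n, int divisor) {
      return (n + divisor - 1) / divisor;
    }
    int AlignByN(int number, int n) { return IntegralDivideRoundUp(number, n) * n; }

    int main() {
      // E.g. 100 input channels pack into 25 FLT4 slices; the shared-memory
      // kernel consumes 32 slices per loop iteration, the global one 4.
      const int src_depth = IntegralDivideRoundUp(100, 4);
      assert(src_depth == 25);
      assert(AlignByN(src_depth, 32) == 32);  // shared_memory == true
      assert(AlignByN(src_depth, 4) == 28);   // shared_memory == false
      return 0;
    }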
@@ -25,13 +25,16 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
 namespace gpu {
 namespace metal {
 namespace {
-std::string GetSoftmax1x1Code() {
+std::string GetSoftmax1x1Code(const DeviceInfo& device_info) {
+  const std::string barrier =
+      device_info.IsAppleGPU() ? "BARRIER" : "threadgroup_barrier";
   std::string code = R"(
 #include <metal_stdlib>
 using namespace metal;
@@ -63,7 +66,9 @@ kernel void ComputeFunction($1
   threadgroup float4 tmp[8];
   threadgroup float* tmpx1 = (threadgroup float*)tmp;
   tmpx1[tid] = sum;
-  BARRIER(mem_flags::mem_threadgroup);
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
   if (tid == 0) {
     sum = dot(float4(1.0f), tmp[0]);
     sum += dot(float4(1.0f), tmp[1]);
@@ -75,7 +80,9 @@ kernel void ComputeFunction($1
     sum += dot(float4(1.0f), tmp[7]);
     tmpx1[0] = 1.0 / sum;
   }
-  BARRIER(mem_flags::mem_threadgroup);
+)";
+  code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
+  code += R"(
   sum = tmpx1[0];
 
   offset = 0;
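Because GetSoftmax1x1Code builds the kernel as one raw string literal, the two hunks above splice the barrier in by closing the R"(...)" literal, appending the chosen token as ordinary text, and reopening the literal. A stripped-down sketch of that pattern (shader fragments abbreviated):

    #include <iostream>
    #include <string>

    int main() {
      const std::string barrier = "threadgroup_barrier";  // device-dependent token
      std::string code = R"(
      tmpx1[tid] = sum;
    )";
      // Close the template, emit the barrier call, then resume the template.
      code += "  " + barrier + "(mem_flags::mem_threadgroup);\n";
      code += R"(
      sum = tmpx1[0];
    )";
      std::cout << code;
      return 0;
    }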
@@ -171,11 +178,12 @@ std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
 
 std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
                                                  ValueId output_id,
+                                                 const DeviceInfo& device_info,
                                                  int channels_count) {
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
-  desc->shader_source = GetSoftmax1x1Code();
+  desc->shader_source = GetSoftmax1x1Code(device_info);
 
   desc->input_buffers = {
       {input_id, "device FLT4* const src_buffer"},
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/environment.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -35,6 +36,7 @@ std::vector<ComputeTaskDescriptorPtr> Softmax(int id, ValueId input_id,
 // We have this case in MobilenetV1/V2.
 std::vector<ComputeTaskDescriptorPtr> Softmax1x1(int id, ValueId input_id,
                                                  ValueId output_id,
+                                                 const DeviceInfo& device_info,
                                                  int channels_count);
 
 }  // namespace metal