Using common FullyConnected task and selector for Metal.

PiperOrigin-RevId: 355410583
Change-Id: Ieb7b9a337dc45b290df6dfb4a2176dced07ff75c
This commit is contained in:
Raman Sarokin 2021-02-03 09:14:23 -08:00 committed by TensorFlower Gardener
parent 54bcf2919f
commit 8130665856
11 changed files with 29 additions and 324 deletions

View File

@ -65,6 +65,7 @@ std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
result += "#define GROUP_SIZE_0 get_local_size(0)\n";
result += "#define GROUP_SIZE_1 get_local_size(1)\n";
result += "#define GROUP_SIZE_2 get_local_size(2)\n";
result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
result += "#define MAIN_FUNCTION __kernel void main_function\n";
result += "#define INIT_FLOAT(value) (float)(value)\n";

View File

@ -85,8 +85,8 @@ std::unique_ptr<GPUOperation> SelectFullyConnected(
const OperationDef& op_def, int batch_size) {
if (gpu_info.IsAdreno()) {
return SelectFullyConnectedAdreno(attr, gpu_info, op_def, batch_size);
} else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() ||
gpu_info.IsNvidia() || gpu_info.IsIntel()) {
} else if (gpu_info.IsPowerVR() || gpu_info.IsAMD() || gpu_info.IsNvidia() ||
gpu_info.IsIntel() || gpu_info.IsApple()) {
return SelectFullyConnectedPowerVR(attr, gpu_info, op_def, batch_size);
} else if (gpu_info.IsMali()) {
return SelectFullyConnectedMali(attr, gpu_info, op_def, batch_size);

View File

@ -374,6 +374,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
"//tensorflow/lite/delegates/gpu/common/task:storage_type_util",
"//tensorflow/lite/delegates/gpu/common/task:tensor_desc",
"//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_linear_desc.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
@ -31,7 +32,8 @@ namespace gpu {
namespace {
bool UseBufferForWeights(const GpuInfo& gpu_info) {
return gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsMali();
return gpu_info.IsAdreno() || gpu_info.IsAMD() || gpu_info.IsMali() ||
gpu_info.IsApple();
}
} // namespace
@ -46,11 +48,8 @@ FullyConnected::FullyConnected(const OperationDef& definition,
} else {
work_group_size_ = int3(32, 4, 1);
}
} else if (gpu_info.IsIntel()) {
work_group_size_ = int3(8, 4, 1);
} else if (gpu_info.IsNvidia()) {
work_group_size_ = int3(8, 4, 1);
} else if (gpu_info.IsPowerVR()) {
} else if (gpu_info.IsIntel() || gpu_info.IsNvidia() ||
gpu_info.IsPowerVR() || gpu_info.IsApple()) {
work_group_size_ = int3(8, 4, 1);
} else {
work_group_size_ = int3(16, 4, 1);
@ -76,6 +75,11 @@ FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) {
std::string FullyConnected::GetFullyConnectedKernelCode(
const OperationDef& op_def, const GpuInfo& gpu_info) {
const int wg_total_size = work_group_size_.x * work_group_size_.y;
const std::string barrier =
wg_total_size == 32 && gpu_info.IsWaveSizeEqualTo32()
? "SIMD_LOCAL_MEM_BARRIER"
: "LOCAL_MEM_BARRIER";
AddSrcTensor("src_tensor", op_def.src_tensors[0]);
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
@ -127,7 +131,9 @@ std::string FullyConnected::GetFullyConnectedKernelCode(
}
__local ACCUM_FLT4 temp[WG_X][WG_Y];
temp[tid.x][tid.y] = s;
LOCAL_MEM_BARRIER;
)";
c += " " + barrier + ";\n";
c += R"(
if (gid >= args.dst_tensor.Slices()) {
return;
}
@ -157,6 +163,10 @@ FullyConnected CreateFullyConnected(const GpuInfo& gpu_info,
TensorLinearDescriptor desc;
desc.storage_type = gpu_info.SupportsImages() ? LinearStorageType::TEXTURE_2D
: LinearStorageType::BUFFER;
if (gpu_info.IsApple()) {
desc.storage_type =
DeduceLinearStorageType(definition.GetPrimaryStorageType());
}
desc.element_type = definition.GetDataType();
desc.UploadLinearData(attr.bias);
result.args_.AddObject(

View File

@ -136,6 +136,7 @@ absl::Status ComputeTask::CompileProgram(MetalDevice* device,
@"TO_ACCUM_TYPE" : toAccumulatorType4,
@"TO_FLT4" : [NSString stringWithFormat:@"%@4", storageType],
@"SIMDGROUP_BARRIER" : barrier,
@"SIMD_LOCAL_MEM_BARRIER" : barrier,
@"MAIN_FUNCTION" : @"\"kernel void ComputeFunction\"",
@"GLOBAL_ID_0" : @"static_cast<int>(reserved_gid.x)",
@"GLOBAL_ID_1" : @"static_cast<int>(reserved_gid.y)",

View File

@ -20,7 +20,6 @@ cc_library(
deps = [
":conv",
":depthwise_conv",
":fully_connected",
":transpose_conv",
":winograd",
],
@ -180,30 +179,12 @@ ios_unit_test(
deps = [":elementwise_test_lib"],
)
cc_library(
name = "fully_connected",
srcs = ["fully_connected.cc"],
hdrs = ["fully_connected.h"],
deps = [
":util",
"//tensorflow/lite/delegates/gpu/common:gpu_info",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common/task:buffer_desc",
"//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
"@com_google_absl//absl/strings",
],
)
objc_library(
name = "fully_connected_test_lib",
testonly = 1,
srcs = ["fully_connected_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":fully_connected",
":test_util",
"//tensorflow/lite/delegates/gpu/common/tasks:fully_connected_test_util",
],

View File

@ -1,182 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/util.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
std::string GetFullyConnectedCode(const GpuInfo& gpu_info, int src_channels,
int dst_channels) {
bool shared_memory = gpu_info.IsApple() &&
gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal();
const std::string barrier = gpu_info.IsWaveSizeEqualTo32()
? "SIMDGROUP_BARRIER"
: "threadgroup_barrier";
const int src_depth = DivideRoundUp(src_channels, 4);
std::stringstream code;
code << R"(
kernel void ComputeFunction(
$$0
uint3 tid[[thread_position_in_threadgroup]],
uint tid_index[[thread_index_in_threadgroup]],
uint3 ugid[[thread_position_in_grid]]) {
)";
if (shared_memory) {
code << R"(
float summa = 0.0f;
threadgroup FLT4 local_vector[32];
for (int j = 0; j < args.src_depth_sub_groups; ++j) {
local_vector[tid_index] = j * 32 + tid_index >= args.src_tensor.Slices() ?
FLT4(0.0f) : args.src_tensor.Read(0, 0, j * 32 + tid_index);
$0(mem_flags::mem_threadgroup);
for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
summa += dot(local_vector[tid.y * 8 + i], args.weights.Read(counter * args.dst_channels_alignedx8 + ugid.x));
}
$0(mem_flags::mem_none);
}
)";
} else {
code << R"(
float summa = 0.0f;
int counter = int(ugid.y) * args.src_depth_sub_groups;
for (int i = 0; i < args.src_depth_sub_groups; ++i, ++counter) {
)";
if (src_depth % 4 != 0) {
code << " if (counter >= args.src_tensor.Slices()) continue;"
<< std::endl;
}
code << " summa += dot(args.src_tensor.Read(0, 0, counter), "
"args.weights.Read(counter * "
"args.dst_channels_alignedx8 + ugid.x));"
<< std::endl;
code << " }" << std::endl;
}
code << R"(
threadgroup float temp[8][4];
temp[tid.x][tid.y] = summa;
$0(mem_flags::mem_threadgroup);
if (tid.y == 0) {
summa += temp[tid.x][1];
summa += temp[tid.x][2];
summa += temp[tid.x][3];
temp[tid.x][0] = summa;
}
$0(mem_flags::mem_threadgroup);
int dst_s = ugid.x / 4;
if (tid.y == 0 && tid.x % 4 == 0 && dst_s < args.dst_tensor.Slices()) {
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
args.bias.Read(dst_s);
args.dst_tensor.Write(value, 0, 0, dst_s);
}
}
)";
return absl::Substitute(code.str(), barrier);
}
} // namespace
int3 FullyConnected::GetGridSize() const {
const int dst_channels_aligned = AlignByN(dst_[0]->Channels(), 8);
return int3(dst_channels_aligned, 1, 1);
}
FullyConnected CreateFullyConnected(const GpuInfo& gpu_info,
const OperationDef& definition,
const FullyConnectedAttributes& attr) {
FullyConnected desc(definition);
desc.code_ = GetFullyConnectedCode(gpu_info, attr.weights.shape.i,
attr.weights.shape.o);
bool shared_memory = gpu_info.IsApple() &&
gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal();
const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
const int src_depth_sub_groups = shared_memory ? DivideRoundUp(src_depth, 32)
: DivideRoundUp(src_depth, 4);
desc.args_.AddInt("dst_channels_alignedx8",
AlignByN(attr.weights.shape.o, 8));
desc.args_.AddInt("src_depth_sub_groups", src_depth_sub_groups);
desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
const int src_depth_aligned = AlignByN(src_depth, shared_memory ? 32 : 4);
const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
int counter = 0;
std::vector<float> filters_reordered(dst_channels_aligned *
src_depth_aligned * 4);
for (int j = 0; j < src_depth_aligned; ++j) {
for (int i = 0; i < dst_channels_aligned; ++i) {
for (int k = 0; k < 4; ++k) {
if (j * 4 + k >= attr.weights.shape.i || i >= attr.weights.shape.o) {
filters_reordered[counter++] = 0.0f;
} else {
const int f_index =
attr.weights.shape.LinearIndex({i, 0, 0, j * 4 + k});
filters_reordered[counter++] = attr.weights.data[f_index];
}
}
}
}
auto data_type = DeduceDataTypeFromPrecision(definition.precision);
BufferDescriptor weights_desc;
weights_desc.element_type = data_type;
weights_desc.element_size = 4;
weights_desc.data = GetByteBufferConverted(filters_reordered, data_type);
weights_desc.size = weights_desc.data.size();
desc.args_.AddObject(
"weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
BufferDescriptor bias_desc;
bias_desc.element_type = data_type;
bias_desc.element_size = 4;
bias_desc.data = GetByteBufferConvertedResized(attr.bias.data, data_type,
dst_channels_aligned);
bias_desc.size = bias_desc.data.size();
desc.args_.AddObject(
"bias", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
desc.work_group_size_ = int3(8, 4, 1);
return desc;
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -1,63 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
namespace tflite {
namespace gpu {
namespace metal {
class FullyConnected : public GPUOperation {
public:
FullyConnected() = default;
void GetPossibleKernelWorkGroups(
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info,
std::vector<int3>* work_groups) const override {
work_groups->push_back(work_group_size_);
}
int3 GetGridSize() const override;
// Move only
FullyConnected(FullyConnected&& kernel) = default;
FullyConnected& operator=(FullyConnected&& kernel) = default;
FullyConnected(const FullyConnected&) = delete;
FullyConnected& operator=(const FullyConnected&) = delete;
private:
explicit FullyConnected(const OperationDef& definition)
: GPUOperation(definition) {}
friend FullyConnected CreateFullyConnected(
const GpuInfo& gpu_info, const OperationDef& definition,
const FullyConnectedAttributes& attr);
};
FullyConnected CreateFullyConnected(const GpuInfo& gpu_info,
const OperationDef& definition,
const FullyConnectedAttributes& attr);
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_KERNELS_FULLY_CONNECTED_H_

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tasks/fully_connected_test_util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
@interface FullyConnectedMetalTest : XCTestCase
@ -34,50 +33,6 @@ limitations under the License.
tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
}
namespace tflite {
namespace gpu {
namespace metal {
absl::Status FullyConnectedTest(TestExecutionEnvironment* env) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 1, 1, 2);
src_tensor.data = {1, 2};
FullyConnectedAttributes attr;
attr.weights.shape = OHWI(4, 1, 1, 2);
attr.weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
attr.bias.shape = Linear(4);
attr.bias.data = {1, 2, 3, 4};
for (auto storage : env->GetSupportedStorages()) {
for (auto precision : env->GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
FullyConnected operation = CreateFullyConnected(env->GetGpuInfo(), op_def, attr);
RETURN_IF_ERROR(env->ExecuteGPUOperation(
src_tensor, absl::make_unique<FullyConnected>(std::move(operation)), BHWC(1, 1, 1, 4),
&dst_tensor));
RETURN_IF_ERROR(PointWiseNear({6, 13, 20, 27}, dst_tensor.data, eps))
<< "Failed using precision " << ToString(precision);
}
}
return absl::OkStatus();
}
} // namespace metal
} // namespace gpu
} // namespace tflite
- (void)testFullyConnectedMetal {
auto status = tflite::gpu::metal::FullyConnectedTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
}
- (void)testFullyConnected {
auto status = tflite::gpu::FullyConnectedTest(&exec_env_);
XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());

View File

@ -18,6 +18,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/common:winograd_util",
"//tensorflow/lite/delegates/gpu/common/selectors:default_selector",
"//tensorflow/lite/delegates/gpu/common/selectors:fully_connected_selector",
"//tensorflow/lite/delegates/gpu/common/selectors:subgraph",
"//tensorflow/lite/delegates/gpu/common/tasks:add",
"//tensorflow/lite/delegates/gpu/common/tasks:concat_xy",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/default_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/fully_connected_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -53,7 +54,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
@ -399,11 +399,11 @@ absl::Status GPUOperationFromNode(const GpuInfo& gpu_info,
gpu_info);
break;
case OperationType::FULLY_CONNECTED: {
FullyConnected conv_op = CreateFullyConnected(
gpu_info, op_def,
absl::any_cast<FullyConnectedAttributes>(node.operation.attributes));
*gpu_op = absl::make_unique<FullyConnected>(std::move(conv_op));
break;
auto attr =
absl::any_cast<FullyConnectedAttributes>(node.operation.attributes);
*gpu_op = SelectFullyConnected(attr, gpu_info, op_def,
inputs[0]->tensor.shape.b);
return absl::OkStatus();
}
case OperationType::LSTM: {
*gpu_op = SelectLSTM(op_def, gpu_info);