Added a new specialized kernel that replaces FullyConnected + FullyConnected + Add.

PiperOrigin-RevId: 334262389
Change-Id: I6b900ed0626e49df3761af5b8f828ba4b436aee9
Raman Sarokin 2020-09-28 16:33:47 -07:00 committed by TensorFlower Gardener
parent e1b710dd04
commit 55956aced7
3 changed files with 423 additions and 0 deletions
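For orientation, here is a plain C++ reference sketch (illustrative only, not part of this change; the helper name and types are assumptions) of the computation the fused FCFCAdd kernel performs: two fully connected layers applied to two separate inputs, with their outputs and both biases summed into one result.

#include <cstddef>
#include <vector>

// Reference math only; the real kernel works on 4-channel slices on the GPU.
std::vector<float> FcFcAddReference(const std::vector<float>& x0,
                                    const std::vector<float>& x1,
                                    const std::vector<std::vector<float>>& w0,
                                    const std::vector<std::vector<float>>& w1,
                                    const std::vector<float>& b0,
                                    const std::vector<float>& b1) {
  const std::size_t out_channels = w0.size();  // both FCs share the output size
  std::vector<float> y(out_channels, 0.0f);
  for (std::size_t o = 0; o < out_channels; ++o) {
    float acc = b0[o] + b1[o];
    for (std::size_t i = 0; i < x0.size(); ++i) acc += w0[o][i] * x0[i];
    for (std::size_t i = 0; i < x1.size(); ++i) acc += w1[o][i] * x1[i];
    y[o] = acc;  // FullyConnected(x0) + FullyConnected(x1) + Add, fused
  }
  return y;
}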

@ -23,3 +23,30 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:types",
],
)
cc_library(
name = "fc_fc_add",
srcs = ["fc_fc_add.cc"],
hdrs = ["fc_fc_add.h"],
deps = [
"//tensorflow/lite/delegates/gpu/cl:arguments",
"//tensorflow/lite/delegates/gpu/cl:buffer",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/cl:device_info",
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
"//tensorflow/lite/delegates/gpu/cl:precision",
"//tensorflow/lite/delegates/gpu/cl:tensor",
"//tensorflow/lite/delegates/gpu/cl:tensor_type",
"//tensorflow/lite/delegates/gpu/cl:texture2d",
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
"//tensorflow/lite/delegates/gpu/cl/kernels:tuning_parameters",
"//tensorflow/lite/delegates/gpu/cl/kernels:util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
"@com_google_absl//absl/memory",
],
)

@ -0,0 +1,207 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h"
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
bool UseBufferForWeights(const DeviceInfo& device_info) {
return device_info.IsAdreno() || device_info.IsAMD() || device_info.IsMali();
}
} // namespace
FCFCAdd::FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info)
: GPUOperation(definition) {
if (device_info.IsAdreno()) {
if (device_info.IsAdreno3xx()) {
work_group_size_ = int3(16, 4, 1);
} else if (device_info.IsAdreno4xx()) {
work_group_size_ = int3(32, 4, 1);
} else {
work_group_size_ = int3(32, 4, 1);
}
} else if (device_info.IsIntel()) {
work_group_size_ = int3(8, 4, 1);
} else if (device_info.IsNvidia()) {
work_group_size_ = int3(8, 4, 1);
} else if (device_info.IsPowerVR()) {
work_group_size_ = int3(8, 4, 1);
} else {
work_group_size_ = int3(16, 4, 1);
}
code_ = GetFCFCAddKernelCode(definition_, device_info);
}
FCFCAdd::FCFCAdd(FCFCAdd&& kernel) : GPUOperation(std::move(kernel)) {}
FCFCAdd& FCFCAdd::operator=(FCFCAdd&& kernel) {
if (this != &kernel) {
GPUOperation::operator=(std::move(kernel));
}
return *this;
}
// We split the vec-vec dot product (each thread computes one vec-vec dot
// product of the basic vec-mat multiplication) into 4 parts to create more
// threads.
// Each tid.y processes every 4th element of the vec-vec dot product.
// Gives good results for sizes around 1024 x 1024; for other sizes more
// optimized shaders could be written.
std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def,
const DeviceInfo& device_info) {
AddSrcTensor("src_tensor_0", op_def.src_tensors[0]);
AddSrcTensor("src_tensor_1", op_def.src_tensors[1]);
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
const bool weights_are_buffer = UseBufferForWeights(device_info);
std::string c = GetCommonDefines(op_def.precision);
switch (op_def.precision) {
case CalculationsPrecision::F32:
c += "#define FLT16 float16\n";
break;
case CalculationsPrecision::F32_F16:
case CalculationsPrecision::F16:
c += "#define FLT16 half16\n";
break;
}
c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n";
c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n";
c += R"(__kernel void main_function($0) {
int gid = get_global_id(0);
int2 tid = (int2)(get_local_id(0), get_local_id(1));
ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);
if (gid < args.dst_tensor.Slices()) {
for (int c = tid.y; c < args.src_tensor_0.Slices(); c += WG_Y) {
FLT4 v = args.src_tensor_0.Read(0, 0, c);
)";
if (weights_are_buffer) {
c += R"(FLT16 w = args.weights0.Read(c * args.dst_tensor.Slices() + gid);
FLT4 partial = v.s0 * w.s0123;
partial = mad(v.s1, w.s4567, partial);
partial = mad(v.s2, w.s89ab, partial);
partial = mad(v.s3, w.scdef, partial);
s += TO_ACCUM_TYPE(partial);
)";
} else {
c += R"(FLT4 w0 = args.weights0.Read(c * 4 + 0, gid);
FLT4 w1 = args.weights0.Read(c * 4 + 1, gid);
FLT4 w2 = args.weights0.Read(c * 4 + 2, gid);
FLT4 w3 = args.weights0.Read(c * 4 + 3, gid);
FLT4 partial = v.s0 * w0;
partial = mad(v.s1, w1, partial);
partial = mad(v.s2, w2, partial);
partial = mad(v.s3, w3, partial);
s += TO_ACCUM_TYPE(partial);
)";
}
c += R"( }
for (int c = tid.y; c < args.src_tensor_1.Slices(); c += WG_Y) {
FLT4 v = args.src_tensor_1.Read(0, 0, c);
)";
if (weights_are_buffer) {
c += R"(FLT16 w = args.weights1.Read(c * args.dst_tensor.Slices() + gid);
FLT4 partial = v.s0 * w.s0123;
partial = mad(v.s1, w.s4567, partial);
partial = mad(v.s2, w.s89ab, partial);
partial = mad(v.s3, w.scdef, partial);
s += TO_ACCUM_TYPE(partial);
)";
} else {
c += R"(FLT4 w0 = args.weights1.Read(c * 4 + 0, gid);
FLT4 w1 = args.weights1.Read(c * 4 + 1, gid);
FLT4 w2 = args.weights1.Read(c * 4 + 2, gid);
FLT4 w3 = args.weights1.Read(c * 4 + 3, gid);
FLT4 partial = v.s0 * w0;
partial = mad(v.s1, w1, partial);
partial = mad(v.s2, w2, partial);
partial = mad(v.s3, w3, partial);
s += TO_ACCUM_TYPE(partial);
)";
}
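  // The per-tid.y partial sums are staged in local memory and reduced below
  // by the threads with tid.y == 0 before the two biases are added.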
c += R"( }
}
__local ACCUM_FLT4 temp[WG_X][WG_Y];
temp[tid.x][tid.y] = s;
barrier(CLK_LOCAL_MEM_FENCE);
if (gid >= args.dst_tensor.Slices()) {
return;
}
if (tid.y == 0) {
)";
for (int i = 1; i < work_group_size_.y; ++i) {
c += " s += temp[tid.x][" + std::to_string(i) + "];\n";
}
c +=
R"( FLT4 r0 = TO_FLT4(s) + args.biases0.Read(gid) + args.biases1.Read(gid);
args.dst_tensor.Write(r0, 0, 0, gid);
}
})";
return c;
}
int3 FCFCAdd::GetGridSize() const { return int3(dst_[0]->Slices(), 1, 1); }
FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info,
const OperationDef& definition,
const FullyConnectedAttributes& attr0,
const FullyConnectedAttributes& attr1) {
FCFCAdd result(definition, device_info);
result.UploadWeights(attr0.weights, "weights0",
UseBufferForWeights(device_info));
result.UploadWeights(attr1.weights, "weights1",
UseBufferForWeights(device_info));
TensorLinearDescriptor desc0;
desc0.storage_type = LinearStorageType::TEXTURE_2D;
desc0.element_type = definition.GetDataType();
desc0.UploadLinearData(attr0.bias);
result.args_.AddObject(
"biases0", absl::make_unique<TensorLinearDescriptor>(std::move(desc0)));
TensorLinearDescriptor desc1;
desc1.storage_type = LinearStorageType::TEXTURE_2D;
desc1.element_type = definition.GetDataType();
desc1.UploadLinearData(attr1.bias);
result.args_.AddObject(
"biases1", absl::make_unique<TensorLinearDescriptor>(std::move(desc1)));
return result;
}
} // namespace cl
} // namespace gpu
} // namespace tflite
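The split-by-tid.y accumulation in the generated kernel can be restated in plain C++ as follows. This is an illustrative sketch only: kWgY stands in for WG_Y (work_group_size_.y), and the final loop mirrors the __local temp[WG_X][WG_Y] reduction performed by the tid.y == 0 threads.

#include <array>
#include <cstddef>
#include <vector>

// Illustrative only: each of the kWgY logical "threads" accumulates every
// kWgY-th element of the dot product, then the partial sums are reduced.
float SplitDotProduct(const std::vector<float>& v, const std::vector<float>& w) {
  constexpr int kWgY = 4;  // corresponds to WG_Y in the kernel
  std::array<float, kWgY> partial{};
  for (int tid_y = 0; tid_y < kWgY; ++tid_y) {
    for (std::size_t c = tid_y; c < v.size(); c += kWgY) {
      partial[tid_y] += v[c] * w[c];  // strided accumulation per tid.y
    }
  }
  float s = partial[0];
  for (int i = 1; i < kWgY; ++i) s += partial[i];  // reduction
  return s;
}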

@ -0,0 +1,189 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
namespace tflite {
namespace gpu {
namespace cl {
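// Rearranges fully connected weights for the buffer path: each consecutive
// group of 16 values is a 4 (input) x 4 (output) block, stored so the kernel
// can read block (src_slice, dst_slice) as one FLT16 at index
// src_slice * dst_slices + dst_slice. Out-of-range channels are zero-padded.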
template <DataType T, typename S>
void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor<OHWI, T>& weights,
S* dst) {
const int src_channels = weights.shape.i;
const int padded_src_channels = AlignByN(src_channels, 4);
const int dst_channels = weights.shape.o;
const int padded_dst_channels = AlignByN(dst_channels, 4);
for (int block_y = 0; 4 * block_y < padded_dst_channels; block_y++) {
for (int y_in_block = 0; y_in_block < 4; y_in_block++) {
for (int block_x = 0; 4 * block_x < padded_src_channels; block_x++) {
for (int x_in_block = 0; x_in_block < 4; x_in_block++) {
int y = 4 * block_y + y_in_block;
int x = 4 * block_x + x_in_block;
int dst_index = block_x * padded_dst_channels * 4 + block_y * 16 +
x_in_block * 4 + y_in_block;
if (x < src_channels && y < dst_channels) {
dst[dst_index] = weights.data[src_channels * y + x];
} else {
dst[dst_index] = 0.0f;
}
}
}
}
}
}
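// Rearranges fully connected weights for the texture2d path: texel
// (x = src_slice * 4 + i, y = dst_slice) holds the four outputs of that
// destination slice for input channel src_slice * 4 + i. Out-of-range
// channels are zero-padded.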
template <DataType T, typename S>
void RearrangeFCWeightsToOIO4I4(const tflite::gpu::Tensor<OHWI, T>& weights,
S* dst) {
const int src_channels = weights.shape.i;
const int src_depth = DivideRoundUp(src_channels, 4);
const int dst_channels = weights.shape.o;
const int dst_depth = DivideRoundUp(dst_channels, 4);
int counter = 0;
for (int d = 0; d < dst_depth; ++d) {
for (int s = 0; s < src_depth; ++s) {
for (int i = 0; i < 4; ++i) {
const int src_ch = s * 4 + i;
for (int j = 0; j < 4; ++j) {
const int dst_ch = d * 4 + j;
if (src_ch < src_channels && dst_ch < dst_channels) {
dst[counter++] = weights.data[dst_ch * src_channels + src_ch];
} else {
dst[counter++] = 0.0f;
}
}
}
}
}
}
class FCFCAdd : public GPUOperation {
public:
FCFCAdd() = default;
void GetPossibleKernelWorkGroups(
TuningType tuning_type, const DeviceInfo& device_info,
const KernelInfo& kernel_info,
std::vector<int3>* work_groups) const override {
work_groups->push_back(work_group_size_);
}
int3 GetGridSize() const override;
// Move only
FCFCAdd(FCFCAdd&& kernel);
FCFCAdd& operator=(FCFCAdd&& kernel);
FCFCAdd(const FCFCAdd&) = delete;
FCFCAdd& operator=(const FCFCAdd&) = delete;
private:
FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info);
friend FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info,
const OperationDef& definition,
const FullyConnectedAttributes& attr0,
const FullyConnectedAttributes& attr1);
template <DataType T>
void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
const std::string& name, bool weights_are_buffer);
std::string GetFCFCAddKernelCode(const OperationDef& op_def,
const DeviceInfo& device_info);
};
template <DataType T>
void FCFCAdd::UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
const std::string& name, bool weights_are_buffer) {
const int src_depth = DivideRoundUp(weights.shape.i, 4);
const int dst_depth = DivideRoundUp(weights.shape.o, 4);
const int elements_count = src_depth * dst_depth * 4;
const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
const int float4_size = f32_weights ? 16 : 8;
if (weights_are_buffer) {
BufferDescriptor desc;
desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
desc.element_size = 16;
desc.size = float4_size * elements_count;
desc.data.resize(desc.size);
if (f32_weights) {
float* ptr = reinterpret_cast<float*>(desc.data.data());
RearrangeFCWeightsToIOO4I4(weights, ptr);
} else {
half* ptr = reinterpret_cast<half*>(desc.data.data());
RearrangeFCWeightsToIOO4I4(weights, ptr);
}
args_.AddObject(name, absl::make_unique<BufferDescriptor>(std::move(desc)));
} else {
Texture2DDescriptor desc;
desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
// desc.element_type = DataType::UINT8;
// desc.normalized = true;
// desc.normalized_type = f32_weights ? DataType::FLOAT32 :
// DataType::FLOAT16;
desc.size = int2(src_depth * 4, dst_depth);
desc.data.resize(float4_size * elements_count);
if (f32_weights) {
float* ptr = reinterpret_cast<float*>(desc.data.data());
RearrangeFCWeightsToOIO4I4(weights, ptr);
} else {
half* ptr = reinterpret_cast<half*>(desc.data.data());
RearrangeFCWeightsToOIO4I4(weights, ptr);
}
args_.AddObject(name,
absl::make_unique<Texture2DDescriptor>(std::move(desc)));
}
}
FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info,
const OperationDef& definition,
const FullyConnectedAttributes& attr0,
const FullyConnectedAttributes& attr1);
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_
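To make the weight layout concrete, here is a standalone sketch of the packing performed by RearrangeFCWeightsToOIO4I4. It assumes a row-major O x I weight matrix and drops the tflite::gpu::Tensor<OHWI, T> wrapper; the loop structure otherwise mirrors the template above.

#include <vector>

// Standalone illustration of RearrangeFCWeightsToOIO4I4: pack a row-major
// O x I weight matrix into consecutive 4x4 blocks (destination-slice major),
// zero-padding channels that fall outside the matrix.
std::vector<float> PackOIO4I4(const std::vector<float>& w, int O, int I) {
  const int dst_depth = (O + 3) / 4;  // DivideRoundUp(O, 4)
  const int src_depth = (I + 3) / 4;  // DivideRoundUp(I, 4)
  std::vector<float> dst(dst_depth * src_depth * 16, 0.0f);
  int counter = 0;
  for (int d = 0; d < dst_depth; ++d) {
    for (int s = 0; s < src_depth; ++s) {
      for (int i = 0; i < 4; ++i) {    // input channel within the block
        const int src_ch = s * 4 + i;
        for (int j = 0; j < 4; ++j) {  // output channel within the block
          const int dst_ch = d * 4 + j;
          if (src_ch < I && dst_ch < O) {
            dst[counter] = w[dst_ch * I + src_ch];
          }
          ++counter;
        }
      }
    }
  }
  return dst;
}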