Added a new specialized kernel that replaces FullyConnected + FullyConnected + Add.
PiperOrigin-RevId: 334262389
Change-Id: I6b900ed0626e49df3761af5b8f828ba4b436aee9

parent e1b710dd04, commit 55956aced7
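For context, the pattern this kernel fuses is two fully connected layers whose outputs are added together. A minimal reference sketch in plain C++ is shown below; the function names, row-major weight layout, and shapes are illustrative assumptions and are not part of this commit.

#include <cassert>
#include <vector>

// Reference (unfused) computation: dst = FC(x0, W0, b0) + FC(x1, W1, b1).
// Weights are assumed row-major [out_ch x in_ch]; all names are illustrative.
std::vector<float> FullyConnected(const std::vector<float>& x,
                                  const std::vector<float>& w,
                                  const std::vector<float>& b,
                                  int in_ch, int out_ch) {
  assert(static_cast<int>(x.size()) == in_ch);
  assert(static_cast<int>(w.size()) == in_ch * out_ch);
  std::vector<float> y(out_ch);
  for (int o = 0; o < out_ch; ++o) {
    float acc = b[o];
    for (int i = 0; i < in_ch; ++i) {
      acc += w[o * in_ch + i] * x[i];
    }
    y[o] = acc;
  }
  return y;
}

std::vector<float> FcFcAddReference(const std::vector<float>& x0,
                                    const std::vector<float>& x1,
                                    const std::vector<float>& w0,
                                    const std::vector<float>& w1,
                                    const std::vector<float>& b0,
                                    const std::vector<float>& b1,
                                    int in_ch0, int in_ch1, int out_ch) {
  std::vector<float> y0 = FullyConnected(x0, w0, b0, in_ch0, out_ch);
  std::vector<float> y1 = FullyConnected(x1, w1, b1, in_ch1, out_ch);
  for (int o = 0; o < out_ch; ++o) y0[o] += y1[o];  // the Add that gets fused
  return y0;
}

The specialized kernel added by this commit computes the same result in a single dispatch instead of three separate operations.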
@@ -23,3 +23,30 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:types",
     ],
 )
+
+cc_library(
+    name = "fc_fc_add",
+    srcs = ["fc_fc_add.cc"],
+    hdrs = ["fc_fc_add.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/cl:arguments",
+        "//tensorflow/lite/delegates/gpu/cl:buffer",
+        "//tensorflow/lite/delegates/gpu/cl:cl_kernel",
+        "//tensorflow/lite/delegates/gpu/cl:device_info",
+        "//tensorflow/lite/delegates/gpu/cl:linear_storage",
+        "//tensorflow/lite/delegates/gpu/cl:precision",
+        "//tensorflow/lite/delegates/gpu/cl:tensor",
+        "//tensorflow/lite/delegates/gpu/cl:tensor_type",
+        "//tensorflow/lite/delegates/gpu/cl:texture2d",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:tuning_parameters",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:util",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "//tensorflow/lite/delegates/gpu/common:types",
+        "//tensorflow/lite/delegates/gpu/common:util",
+        "@com_google_absl//absl/memory",
+    ],
+)
tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc (new file, 207 lines)
@@ -0,0 +1,207 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {
namespace cl {
namespace {
bool UseBufferForWeights(const DeviceInfo& device_info) {
  return device_info.IsAdreno() || device_info.IsAMD() || device_info.IsMali();
}
}  // namespace

FCFCAdd::FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info)
    : GPUOperation(definition) {
  if (device_info.IsAdreno()) {
    if (device_info.IsAdreno3xx()) {
      work_group_size_ = int3(16, 4, 1);
    } else if (device_info.IsAdreno4xx()) {
      work_group_size_ = int3(32, 4, 1);
    } else {
      work_group_size_ = int3(32, 4, 1);
    }
  } else if (device_info.IsIntel()) {
    work_group_size_ = int3(8, 4, 1);
  } else if (device_info.IsNvidia()) {
    work_group_size_ = int3(8, 4, 1);
  } else if (device_info.IsPowerVR()) {
    work_group_size_ = int3(8, 4, 1);
  } else {
    work_group_size_ = int3(16, 4, 1);
  }
  code_ = GetFCFCAddKernelCode(definition_, device_info);
}

FCFCAdd::FCFCAdd(FCFCAdd&& kernel) : GPUOperation(std::move(kernel)) {}

FCFCAdd& FCFCAdd::operator=(FCFCAdd&& kernel) {
  if (this != &kernel) {
    GPUOperation::operator=(std::move(kernel));
  }
  return *this;
}
// We split the vec-vec dot product (each thread computes a vec-vec dot
// product of the basic vec-mat multiplication) into 4 parts to create more
// threads: the thread with a given tid.y processes every 4th element of the
// dot product. This gives good results for sizes around 1024 x 1024; more
// specialized shaders could be written for other sizes.

std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def,
                                          const DeviceInfo& device_info) {
  AddSrcTensor("src_tensor_0", op_def.src_tensors[0]);
  AddSrcTensor("src_tensor_1", op_def.src_tensors[1]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);

  const bool weights_are_buffer = UseBufferForWeights(device_info);

  std::string c = GetCommonDefines(op_def.precision);
  switch (op_def.precision) {
    case CalculationsPrecision::F32:
      c += "#define FLT16 float16\n";
      break;
    case CalculationsPrecision::F32_F16:
    case CalculationsPrecision::F16:
      c += "#define FLT16 half16\n";
      break;
  }

  c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n";
  c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n";

  c += R"(__kernel void main_function($0) {
  int gid = get_global_id(0);
  int2 tid = (int2)(get_local_id(0), get_local_id(1));
  ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f);
  if (gid < args.dst_tensor.Slices()) {
    for (int c = tid.y; c < args.src_tensor_0.Slices(); c += WG_Y) {
      FLT4 v = args.src_tensor_0.Read(0, 0, c);
)";
  if (weights_are_buffer) {
    c += R"(FLT16 w = args.weights0.Read(c * args.dst_tensor.Slices() + gid);
      FLT4 partial = v.s0 * w.s0123;
      partial = mad(v.s1, w.s4567, partial);
      partial = mad(v.s2, w.s89ab, partial);
      partial = mad(v.s3, w.scdef, partial);
      s += TO_ACCUM_TYPE(partial);
)";
  } else {
    c += R"(FLT4 w0 = args.weights0.Read(c * 4 + 0, gid);
      FLT4 w1 = args.weights0.Read(c * 4 + 1, gid);
      FLT4 w2 = args.weights0.Read(c * 4 + 2, gid);
      FLT4 w3 = args.weights0.Read(c * 4 + 3, gid);
      FLT4 partial = v.s0 * w0;
      partial = mad(v.s1, w1, partial);
      partial = mad(v.s2, w2, partial);
      partial = mad(v.s3, w3, partial);
      s += TO_ACCUM_TYPE(partial);
)";
  }
  c += R"(    }
    for (int c = tid.y; c < args.src_tensor_1.Slices(); c += WG_Y) {
      FLT4 v = args.src_tensor_1.Read(0, 0, c);
)";
  if (weights_are_buffer) {
    c += R"(FLT16 w = args.weights1.Read(c * args.dst_tensor.Slices() + gid);
      FLT4 partial = v.s0 * w.s0123;
      partial = mad(v.s1, w.s4567, partial);
      partial = mad(v.s2, w.s89ab, partial);
      partial = mad(v.s3, w.scdef, partial);
      s += TO_ACCUM_TYPE(partial);
)";
  } else {
    c += R"(FLT4 w0 = args.weights1.Read(c * 4 + 0, gid);
      FLT4 w1 = args.weights1.Read(c * 4 + 1, gid);
      FLT4 w2 = args.weights1.Read(c * 4 + 2, gid);
      FLT4 w3 = args.weights1.Read(c * 4 + 3, gid);
      FLT4 partial = v.s0 * w0;
      partial = mad(v.s1, w1, partial);
      partial = mad(v.s2, w2, partial);
      partial = mad(v.s3, w3, partial);
      s += TO_ACCUM_TYPE(partial);
)";
  }
  c += R"(    }
  }
  __local ACCUM_FLT4 temp[WG_X][WG_Y];
  temp[tid.x][tid.y] = s;
  barrier(CLK_LOCAL_MEM_FENCE);
  if (gid >= args.dst_tensor.Slices()) {
    return;
  }
  if (tid.y == 0) {
)";
  for (int i = 1; i < work_group_size_.y; ++i) {
    c += "    s += temp[tid.x][" + std::to_string(i) + "];\n";
  }
  c +=
      R"(    FLT4 r0 = TO_FLT4(s) + args.biases0.Read(gid) + args.biases1.Read(gid);
    args.dst_tensor.Write(r0, 0, 0, gid);
  }
})";

  return c;
}

int3 FCFCAdd::GetGridSize() const { return int3(dst_[0]->Slices(), 1, 1); }

FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info,
                      const OperationDef& definition,
                      const FullyConnectedAttributes& attr0,
                      const FullyConnectedAttributes& attr1) {
  FCFCAdd result(definition, device_info);
  result.UploadWeights(attr0.weights, "weights0",
                       UseBufferForWeights(device_info));
  result.UploadWeights(attr1.weights, "weights1",
                       UseBufferForWeights(device_info));

  TensorLinearDescriptor desc0;
  desc0.storage_type = LinearStorageType::TEXTURE_2D;
  desc0.element_type = definition.GetDataType();
  desc0.UploadLinearData(attr0.bias);
  result.args_.AddObject(
      "biases0", absl::make_unique<TensorLinearDescriptor>(std::move(desc0)));

  TensorLinearDescriptor desc1;
  desc1.storage_type = LinearStorageType::TEXTURE_2D;
  desc1.element_type = definition.GetDataType();
  desc1.UploadLinearData(attr1.bias);
  result.args_.AddObject(
      "biases1", absl::make_unique<TensorLinearDescriptor>(std::move(desc1)));

  return result;
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite
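The comment above GetFCFCAddKernelCode describes splitting each vec-vec dot product across tid.y, so that WG_Y threads accumulate partial sums that are later combined through local memory. A rough host-side illustration of that split is sketched below in plain C++, for exposition only; the names kWgY and SplitDotProduct are made up rather than taken from the commit.

#include <array>
#include <vector>

// Host-side illustration of the work split: kWgY "threads" each accumulate a
// strided subset of the dot product, then thread y == 0 sums the partials.
constexpr int kWgY = 4;  // illustrative value of WG_Y

float SplitDotProduct(const std::vector<float>& v, const std::vector<float>& w) {
  std::array<float, kWgY> partial{};  // one accumulator per "thread" tid.y
  for (int y = 0; y < kWgY; ++y) {
    // Each tid.y handles elements y, y + WG_Y, y + 2 * WG_Y, ...
    for (size_t i = y; i < v.size(); i += kWgY) {
      partial[y] += v[i] * w[i];
    }
  }
  // In the kernel this reduction happens in local memory after a barrier.
  float s = partial[0];
  for (int y = 1; y < kWgY; ++y) s += partial[y];
  return s;
}

In the actual kernel the partial sums live in the __local temp array and are combined by the tid.y == 0 thread after barrier(CLK_LOCAL_MEM_FENCE).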
tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h (new file, 189 lines)
@@ -0,0 +1,189 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_

#include <stdint.h>

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/cl/arguments.h"
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/device_info.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace cl {

template <DataType T, typename S>
void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor<OHWI, T>& weights,
                                S* dst) {
  const int src_channels = weights.shape.i;
  const int padded_src_channels = AlignByN(src_channels, 4);
  const int dst_channels = weights.shape.o;
  const int padded_dst_channels = AlignByN(dst_channels, 4);

  for (int block_y = 0; 4 * block_y < padded_dst_channels; block_y++) {
    for (int y_in_block = 0; y_in_block < 4; y_in_block++) {
      for (int block_x = 0; 4 * block_x < padded_src_channels; block_x++) {
        for (int x_in_block = 0; x_in_block < 4; x_in_block++) {
          int y = 4 * block_y + y_in_block;
          int x = 4 * block_x + x_in_block;
          int dst_index = block_x * padded_dst_channels * 4 + block_y * 16 +
                          x_in_block * 4 + y_in_block;
          if (x < src_channels && y < dst_channels) {
            dst[dst_index] = weights.data[src_channels * y + x];
          } else {
            dst[dst_index] = 0.0f;
          }
        }
      }
    }
  }
}

template <DataType T, typename S>
void RearrangeFCWeightsToOIO4I4(const tflite::gpu::Tensor<OHWI, T>& weights,
                                S* dst) {
  const int src_channels = weights.shape.i;
  const int src_depth = DivideRoundUp(src_channels, 4);
  const int dst_channels = weights.shape.o;
  const int dst_depth = DivideRoundUp(dst_channels, 4);

  int counter = 0;
  for (int d = 0; d < dst_depth; ++d) {
    for (int s = 0; s < src_depth; ++s) {
      for (int i = 0; i < 4; ++i) {
        const int src_ch = s * 4 + i;
        for (int j = 0; j < 4; ++j) {
          const int dst_ch = d * 4 + j;
          if (src_ch < src_channels && dst_ch < dst_channels) {
            dst[counter++] = weights.data[dst_ch * src_channels + src_ch];
          } else {
            dst[counter++] = 0.0f;
          }
        }
      }
    }
  }
}

class FCFCAdd : public GPUOperation {
 public:
  FCFCAdd() = default;
  void GetPossibleKernelWorkGroups(
      TuningType tuning_type, const DeviceInfo& device_info,
      const KernelInfo& kernel_info,
      std::vector<int3>* work_groups) const override {
    work_groups->push_back(work_group_size_);
  }
  int3 GetGridSize() const override;

  // Move only
  FCFCAdd(FCFCAdd&& kernel);
  FCFCAdd& operator=(FCFCAdd&& kernel);
  FCFCAdd(const FCFCAdd&) = delete;
  FCFCAdd& operator=(const FCFCAdd&) = delete;

 private:
  FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info);
  friend FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info,
                               const OperationDef& definition,
                               const FullyConnectedAttributes& attr0,
                               const FullyConnectedAttributes& attr1);

  template <DataType T>
  void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
                     const std::string& name, bool weights_are_buffer);

  std::string GetFCFCAddKernelCode(const OperationDef& op_def,
                                   const DeviceInfo& device_info);
};

template <DataType T>
void FCFCAdd::UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
                            const std::string& name, bool weights_are_buffer) {
  const int src_depth = DivideRoundUp(weights.shape.i, 4);
  const int dst_depth = DivideRoundUp(weights.shape.o, 4);

  const int elements_count = src_depth * dst_depth * 4;
  const bool f32_weights = definition_.precision == CalculationsPrecision::F32;

  const int float4_size = f32_weights ? 16 : 8;

  if (weights_are_buffer) {
    BufferDescriptor desc;
    desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
    desc.element_size = 16;
    desc.size = float4_size * elements_count;
    desc.data.resize(desc.size);

    if (f32_weights) {
      float* ptr = reinterpret_cast<float*>(desc.data.data());
      RearrangeFCWeightsToIOO4I4(weights, ptr);
    } else {
      half* ptr = reinterpret_cast<half*>(desc.data.data());
      RearrangeFCWeightsToIOO4I4(weights, ptr);
    }

    args_.AddObject(name,
                    absl::make_unique<BufferDescriptor>(std::move(desc)));
  } else {
    Texture2DDescriptor desc;
    desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
    // desc.element_type = DataType::UINT8;
    // desc.normalized = true;
    // desc.normalized_type = f32_weights ? DataType::FLOAT32 :
    // DataType::FLOAT16;
    desc.size = int2(src_depth * 4, dst_depth);
    desc.data.resize(float4_size * elements_count);

    if (f32_weights) {
      float* ptr = reinterpret_cast<float*>(desc.data.data());
      RearrangeFCWeightsToOIO4I4(weights, ptr);
    } else {
      half* ptr = reinterpret_cast<half*>(desc.data.data());
      RearrangeFCWeightsToOIO4I4(weights, ptr);
    }

    args_.AddObject(name,
                    absl::make_unique<Texture2DDescriptor>(std::move(desc)));
  }
}

FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info,
                      const OperationDef& definition,
                      const FullyConnectedAttributes& attr0,
                      const FullyConnectedAttributes& attr1);

}  // namespace cl
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_
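To make the weight packing concrete: RearrangeFCWeightsToOIO4I4 above groups weights into 4x4 blocks of (src channel, dst channel) pairs and zero-pads the ragged edges. The standalone sketch below mirrors that loop structure for a plain row-major float matrix; it avoids the tflite tensor types, and the helper name RearrangeToOIO4I4 is illustrative, not part of the commit.

#include <cstdio>
#include <vector>

// Standalone mirror of the OIO4I4 rearrangement: weights are row-major
// [dst_channels x src_channels]; the output groups 4 dst channels x 4 src
// channels together, zero-padding the tails.
std::vector<float> RearrangeToOIO4I4(const std::vector<float>& weights,
                                     int src_channels, int dst_channels) {
  const int src_depth = (src_channels + 3) / 4;
  const int dst_depth = (dst_channels + 3) / 4;
  std::vector<float> dst;
  dst.reserve(dst_depth * src_depth * 16);
  for (int d = 0; d < dst_depth; ++d) {
    for (int s = 0; s < src_depth; ++s) {
      for (int i = 0; i < 4; ++i) {
        const int src_ch = s * 4 + i;
        for (int j = 0; j < 4; ++j) {
          const int dst_ch = d * 4 + j;
          const bool in_range = src_ch < src_channels && dst_ch < dst_channels;
          dst.push_back(in_range ? weights[dst_ch * src_channels + src_ch]
                                 : 0.0f);
        }
      }
    }
  }
  return dst;
}

int main() {
  // 2 output channels, 3 input channels -> padded to one 4x4 block.
  std::vector<float> w = {1, 2, 3,   // weights of output channel 0
                          4, 5, 6};  // weights of output channel 1
  for (float v : RearrangeToOIO4I4(w, /*src_channels=*/3, /*dst_channels=*/2)) {
    std::printf("%g ", v);
  }
  std::printf("\n");  // prints: 1 4 0 0 2 5 0 0 3 6 0 0 0 0 0 0
  return 0;
}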