Preparing ConvPowerVR to support 3D convolutions.

PiperOrigin-RevId: 331182880
Change-Id: I811603ff3677349752c6200d6e1cafa62ab3a2ee
This commit is contained in:
parent
3a366ec295
commit
721a91c7e9
@ -71,32 +71,33 @@ std::string GenerateAsyncUpload(const std::string& local_ptr_name,
|
||||
return c;
|
||||
}
|
||||
|
||||
std::string GenerateBlockCoords(const int3& block_size,
|
||||
std::string GenerateBlockCoords(const int4& block_size,
|
||||
const int3& work_group_launch_order,
|
||||
bool linear_hw) {
|
||||
bool linear_spatial) {
|
||||
std::string c;
|
||||
int3 launch_remap;
|
||||
launch_remap[work_group_launch_order.x] = 0;
|
||||
launch_remap[work_group_launch_order.y] = 1;
|
||||
launch_remap[work_group_launch_order.z] = 2;
|
||||
if (linear_hw) {
|
||||
if (linear_spatial) {
|
||||
if (work_group_launch_order[0] == 0) {
|
||||
c += " int linear_hw = get_global_id(0);\n";
|
||||
c += " int linear_spatial = get_global_id(0);\n";
|
||||
} else {
|
||||
c += " int linear_hw = get_group_id(" + std::to_string(launch_remap[0]) +
|
||||
c += " int linear_spatial = get_group_id(" +
|
||||
std::to_string(launch_remap[0]) +
|
||||
") * get_local_size(0) + get_local_id(0);\n";
|
||||
}
|
||||
c += " int DST_Y = (linear_hw / args.task_size_x) * " +
|
||||
c += " int DST_Y = (linear_spatial / args.task_size_x) * " +
|
||||
std::to_string(block_size.y) + ";\n";
|
||||
c += " int DST_X = (linear_hw % args.task_size_x) * " +
|
||||
c += " int DST_X = (linear_spatial % args.task_size_x) * " +
|
||||
std::to_string(block_size.x) + ";\n";
|
||||
if (work_group_launch_order[1] == 1) {
|
||||
c += " int DST_S = get_global_id(1) * " + std::to_string(block_size.z) +
|
||||
c += " int DST_S = get_global_id(1) * " + std::to_string(block_size.w) +
|
||||
";\n";
|
||||
} else {
|
||||
c += " int DST_S = (get_group_id(" + std::to_string(launch_remap[1]) +
|
||||
") * get_local_size(1) + get_local_id(1)) * " +
|
||||
std::to_string(block_size.z) + ";\n";
|
||||
std::to_string(block_size.w) + ";\n";
|
||||
}
|
||||
} else {
|
||||
if (work_group_launch_order[0] == 0) {
|
||||
@ -116,12 +117,12 @@ std::string GenerateBlockCoords(const int3& block_size,
|
||||
std::to_string(block_size.y) + ";\n";
|
||||
}
|
||||
if (work_group_launch_order[2] == 2) {
|
||||
c += " int DST_S = get_global_id(2) * " + std::to_string(block_size.z) +
|
||||
c += " int DST_S = get_global_id(2) * " + std::to_string(block_size.w) +
|
||||
";\n";
|
||||
} else {
|
||||
c += " int DST_S = (get_group_id(" + std::to_string(launch_remap[2]) +
|
||||
") * get_local_size(2) + get_local_id(2)) * " +
|
||||
std::to_string(block_size.z) + ";\n";
|
||||
std::to_string(block_size.w) + ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,10 +134,10 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
const DeviceInfo& device_info, const BHWC* dst_shape)
|
||||
: GPUOperation(definition),
|
||||
stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w,
|
||||
-attr.padding.prepended.h),
|
||||
kernel_dilation_(attr.weights.shape.w, attr.weights.shape.h,
|
||||
attr.dilations.w, attr.dilations.h),
|
||||
stride_(attr.strides.w, attr.strides.h, 1, 1),
|
||||
padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
|
||||
kernel_size_(attr.weights.shape.w, attr.weights.shape.h, 1, 1),
|
||||
dilation_(attr.dilations.w, attr.dilations.h, 1, 1),
|
||||
conv_params_(GuessBestParams(device_info, definition, attr, dst_shape)) {}
|
||||
|
||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
@ -144,10 +145,10 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const BHWC& weights_shape,
|
||||
const DeviceInfo& device_info, const BHWC* dst_shape)
|
||||
: GPUOperation(definition),
|
||||
stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w,
|
||||
-attr.padding.prepended.h),
|
||||
kernel_dilation_(weights_shape.w, weights_shape.h, attr.dilations.w,
|
||||
attr.dilations.h),
|
||||
stride_(attr.strides.w, attr.strides.h, 1, 1),
|
||||
padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
|
||||
kernel_size_(weights_shape.w, weights_shape.h, 1, 1),
|
||||
dilation_(attr.dilations.w, attr.dilations.h, 1, 1),
|
||||
conv_params_(GuessBestParams(device_info, definition, attr, weights_shape,
|
||||
dst_shape)) {}
|
||||
|
||||
@ -155,25 +156,33 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
const DeviceInfo& device_info, const BHWC* dst_shape)
|
||||
: GPUOperation(definition),
|
||||
stride_padding_(1, 1, 0, 0),
|
||||
kernel_dilation_(1, 1, 1, 1),
|
||||
stride_(1, 1, 1, 1),
|
||||
padding_(0, 0, 0, 0),
|
||||
kernel_size_(1, 1, 1, 1),
|
||||
dilation_(1, 1, 1, 1),
|
||||
conv_params_(GuessBestParams(device_info, definition, attr, dst_shape)) {}
|
||||
|
||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition)
|
||||
: GPUOperation(definition),
|
||||
stride_padding_(1, 1, 0, 0),
|
||||
kernel_dilation_(1, 1, 1, 1) {}
|
||||
stride_(1, 1, 1, 1),
|
||||
padding_(0, 0, 0, 0),
|
||||
kernel_size_(1, 1, 1, 1),
|
||||
dilation_(1, 1, 1, 1) {}
|
||||
|
||||
ConvPowerVR::ConvPowerVR(ConvPowerVR&& operation)
|
||||
: GPUOperation(std::move(operation)),
|
||||
stride_padding_(operation.stride_padding_),
|
||||
kernel_dilation_(operation.kernel_dilation_),
|
||||
stride_(operation.stride_),
|
||||
padding_(operation.padding_),
|
||||
kernel_size_(operation.kernel_size_),
|
||||
dilation_(operation.dilation_),
|
||||
conv_params_(operation.conv_params_) {}
|
||||
|
||||
ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) {
|
||||
if (this != &operation) {
|
||||
std::swap(stride_padding_, operation.stride_padding_);
|
||||
std::swap(kernel_dilation_, operation.kernel_dilation_);
|
||||
std::swap(stride_, operation.stride_);
|
||||
std::swap(padding_, operation.padding_);
|
||||
std::swap(kernel_size_, operation.kernel_size_);
|
||||
std::swap(dilation_, operation.dilation_);
|
||||
std::swap(conv_params_, operation.conv_params_);
|
||||
GPUOperation::operator=(std::move(operation));
|
||||
}
|
||||
@ -182,7 +191,7 @@ ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) {
|
||||
|
||||
void ConvPowerVR::GenerateCode(const DeviceInfo& device_info) {
|
||||
const bool stride_correction =
|
||||
definition_.IsBatchSupported() && stride_padding_.x != 1;
|
||||
definition_.IsBatchSupported() && stride_.x != 1;
|
||||
code_ =
|
||||
GenerateConv(device_info, definition_, stride_correction, conv_params_);
|
||||
if (definition_.precision == CalculationsPrecision::F16 &&
|
||||
@ -196,18 +205,16 @@ void ConvPowerVR::GenerateCode(const DeviceInfo& device_info) {
|
||||
|
||||
absl::Status ConvPowerVR::BindArguments() {
|
||||
if (!conv_params_.x_kernel_is_1 || !conv_params_.y_kernel_is_1) {
|
||||
RETURN_IF_ERROR(args_.SetInt("stride_x", stride_padding_.x));
|
||||
RETURN_IF_ERROR(args_.SetInt("stride_y", stride_padding_.y));
|
||||
RETURN_IF_ERROR(
|
||||
args_.SetInt("padding_x", stride_padding_.z * src_[0]->Batch()));
|
||||
RETURN_IF_ERROR(args_.SetInt("padding_y", stride_padding_.w));
|
||||
RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_dilation_.x));
|
||||
RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_dilation_.y));
|
||||
RETURN_IF_ERROR(
|
||||
args_.SetInt("dilation_x", kernel_dilation_.z * src_[0]->Batch()));
|
||||
RETURN_IF_ERROR(args_.SetInt("dilation_y", kernel_dilation_.w));
|
||||
RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x));
|
||||
RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y));
|
||||
RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch()));
|
||||
RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y));
|
||||
RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x));
|
||||
RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y));
|
||||
RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
|
||||
RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y));
|
||||
}
|
||||
if (conv_params_.linear_hw) {
|
||||
if (conv_params_.linear_spatial) {
|
||||
const int grid_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(),
|
||||
conv_params_.block_size.x);
|
||||
RETURN_IF_ERROR(args_.SetInt("task_size_x", grid_x));
|
||||
@ -221,10 +228,10 @@ int3 ConvPowerVR::GetGridSize() const {
|
||||
const int grid_y =
|
||||
DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
|
||||
const int grid_z =
|
||||
DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.z);
|
||||
DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w);
|
||||
int3 wg;
|
||||
|
||||
if (conv_params_.linear_hw) {
|
||||
if (conv_params_.linear_spatial) {
|
||||
wg.x = DivideRoundUp(grid_x * grid_y, work_group_size_.x);
|
||||
wg.y = DivideRoundUp(grid_z, work_group_size_.y);
|
||||
return int3(
|
||||
@ -285,6 +292,28 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
AddSrcBuffer("weights", desc);
|
||||
}
|
||||
|
||||
const auto& src_def = op_def.src_tensors[0];
|
||||
|
||||
auto generate_id = [&](const std::string& x, const std::string& y,
|
||||
const std::string& z) {
|
||||
std::string id;
|
||||
if (src_def.HasAxis(Axis::WIDTH)) {
|
||||
id += "_w" + x;
|
||||
}
|
||||
if (src_def.HasAxis(Axis::HEIGHT)) {
|
||||
id += "_h" + y;
|
||||
}
|
||||
if (src_def.HasAxis(Axis::DEPTH)) {
|
||||
id += "_d" + z;
|
||||
}
|
||||
return id;
|
||||
};
|
||||
|
||||
auto generate_id_full = [&](const std::string& x, const std::string& y,
|
||||
const std::string& z, const std::string& s) {
|
||||
return generate_id(x, y, z) + "_s" + s;
|
||||
};
|
||||
|
||||
auto dst_desc = op_def.dst_tensors[0];
|
||||
if (op_def.IsBatchSupported()) {
|
||||
dst_desc.SetStateVar("BatchedWidth", "true");
|
||||
@ -302,7 +331,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
args_.AddInt("dilation_x");
|
||||
args_.AddInt("dilation_y");
|
||||
}
|
||||
if (conv_params_.linear_hw) {
|
||||
if (conv_params_.linear_spatial) {
|
||||
args_.AddInt("task_size_x");
|
||||
}
|
||||
|
||||
@ -318,7 +347,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
||||
|
||||
const int local_mem_size =
|
||||
conv_params.block_size.z * 4 * conv_params.src_depth_loop_size;
|
||||
conv_params.block_size.w * 4 * conv_params.src_depth_loop_size;
|
||||
|
||||
const bool use_simd_broadcast = conv_params.IsPrivateMemBroadcast();
|
||||
const int simd_size = conv_params.simd_size;
|
||||
@ -343,7 +372,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
|
||||
}
|
||||
}
|
||||
const int3 block_size = conv_params.block_size;
|
||||
const int4 block_size = conv_params.block_size;
|
||||
if (conv_params.fixed_work_group_size) {
|
||||
c += "__attribute__((reqd_work_group_size(" +
|
||||
std::to_string(work_group_size_.x) + ", " +
|
||||
@ -358,7 +387,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
c += "$0) {\n";
|
||||
c += GenerateBlockCoords(conv_params.block_size,
|
||||
conv_params.work_group_launch_order,
|
||||
conv_params.linear_hw);
|
||||
conv_params.linear_spatial);
|
||||
std::vector<std::string> dst_x(conv_params.block_size.x);
|
||||
for (int x = 0; x < conv_params.block_size.x; ++x) {
|
||||
dst_x[x] = "(DST_X + " + std::to_string(x) + ")";
|
||||
@ -376,7 +405,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
}
|
||||
if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
|
||||
if (conv_params.linear_hw) {
|
||||
if (conv_params.linear_spatial) {
|
||||
c += " int lid = get_local_id(0);\n";
|
||||
} else {
|
||||
c += " int lid = get_local_id(1) * " +
|
||||
@ -386,11 +415,17 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
if (use_simd_broadcast) {
|
||||
c += " int simd_id = get_sub_group_local_id();\n";
|
||||
}
|
||||
for (int z = 0; z < block_size.z; ++z) {
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
c += " ACCUM_FLT4 r" + std::to_string(z) + std::to_string(y) +
|
||||
std::to_string(x) + " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
|
||||
for (int s = 0; s < block_size.w; ++s) {
|
||||
const std::string sind = std::to_string(s);
|
||||
for (int z = 0; z < block_size.z; ++z) {
|
||||
const std::string zind = std::to_string(z);
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yind = std::to_string(y);
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
const std::string xind = std::to_string(x);
|
||||
c += " ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) +
|
||||
" = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -427,7 +462,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
c += " " + weights_global_ptr +
|
||||
" filters_loc = args.weights.GetPtr() + (DST_S * "
|
||||
"args.src_tensor.Height() + DST_Y * " +
|
||||
std::to_string(block_size.z) +
|
||||
std::to_string(block_size.w) +
|
||||
") * 4 * args.src_tensor.Slices();\n";
|
||||
} else {
|
||||
c += " " + weights_global_ptr +
|
||||
@ -472,24 +507,28 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
}
|
||||
if (buffer_type) {
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yck = "yck" + std::to_string(y);
|
||||
const std::string yind = std::to_string(y);
|
||||
const std::string yck = "yck" + yind;
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
const std::string xck = "xck" + std::to_string(x);
|
||||
const std::string xind = std::to_string(x);
|
||||
const std::string xck = "xck" + xind;
|
||||
std::string xc =
|
||||
is1x1 ? "min(" + dst_x[x] + ", args.src_tensor.Width() - 1)" : xck;
|
||||
std::string yc =
|
||||
is1x1 ? "min(" + dst_y[y] + ", args.src_tensor.Height() - 1)" : yck;
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
c += " int src_a_" + id + " = " + yc +
|
||||
" * args.src_tensor.Width() + " + xc + ";\n";
|
||||
std::string id = generate_id(xind, yind, "");
|
||||
c += " int src_a" + id + " = " + yc + " * args.src_tensor.Width() + " +
|
||||
xc + ";\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto declare_src = [&]() {
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yind = std::to_string(y);
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
const std::string id = std::to_string(y) + std::to_string(x);
|
||||
const std::string xind = std::to_string(x);
|
||||
const std::string id = generate_id(xind, yind, "");
|
||||
c += " " + weights_data_type + " src" + id + ";\n";
|
||||
}
|
||||
}
|
||||
@ -498,27 +537,28 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
auto read_src = [&]() {
|
||||
const std::string cl_type = ToCLDataType(conv_params.weights_data_type);
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yind = std::to_string(y);
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
const std::string xind = std::to_string(x);
|
||||
std::string id = generate_id(xind, yind, "");
|
||||
if (buffer_type) {
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
if (is1x1) {
|
||||
c += " src" + id + " = args.src_tensor.Read<" + cl_type +
|
||||
">(src_a_" + id + ");\n";
|
||||
">(src_a" + id + ");\n";
|
||||
} else {
|
||||
std::string condition =
|
||||
"mx" + std::to_string(x) + " && my" + std::to_string(y);
|
||||
if (conditional_read) {
|
||||
c += " src" + id + " = " + condition +
|
||||
" ? args.src_tensor.Read<" + cl_type + ">(src_a_" + id +
|
||||
" ? args.src_tensor.Read<" + cl_type + ">(src_a" + id +
|
||||
") : (FLT4)(0.0f);\n";
|
||||
} else {
|
||||
c += " src" + id + " = args.src_tensor.Read<" + cl_type +
|
||||
">(src_a_" + id + ") * (FLT)(" + condition + ");\n";
|
||||
">(src_a" + id + ") * (FLT)(" + condition + ");\n";
|
||||
}
|
||||
}
|
||||
c += " src_a_" + id + " += src_layer_offset;\n";
|
||||
c += " src_a" + id + " += src_layer_offset;\n";
|
||||
} else {
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
const std::string xc = is1x1 ? dst_x[x] : "xck" + std::to_string(x);
|
||||
const std::string yc = is1x1 ? dst_y[y] : "yck" + std::to_string(y);
|
||||
c += " src" + id + " = args.src_tensor.Read<" + cl_type + ">(" +
|
||||
@ -532,15 +572,19 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
conv_params.weights_data_type == DataType::FLOAT16);
|
||||
auto conv_core = [&](int shared_offset) {
|
||||
const std::string channels[] = {"x", "y", "z", "w"};
|
||||
for (int z = 0; z < block_size.z; ++z) {
|
||||
for (int s = 0; s < block_size.w; ++s) {
|
||||
const std::string sind = std::to_string(s);
|
||||
if (weights_type_as_accum_type) {
|
||||
for (int ch = 0; ch < 4; ++ch) {
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yind = std::to_string(y);
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
const std::string xind = std::to_string(x);
|
||||
std::string R = "r" + generate_id_full(xind, yind, "", sind);
|
||||
std::string S = "src" + generate_id(xind, yind, "");
|
||||
if (use_simd_broadcast) {
|
||||
int simd_id = (z * 4 + ch + shared_offset) / simd_size;
|
||||
int thread_id = (z * 4 + ch + shared_offset) % simd_size;
|
||||
int simd_id = (s * 4 + ch + shared_offset) / simd_size;
|
||||
int thread_id = (s * 4 + ch + shared_offset) % simd_size;
|
||||
std::string w_val_x = "sub_group_broadcast(simd_w" +
|
||||
std::to_string(simd_id) + ".x, " +
|
||||
std::to_string(thread_id) + "u)";
|
||||
@ -553,38 +597,39 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
std::string w_val_w = "sub_group_broadcast(simd_w" +
|
||||
std::to_string(simd_id) + ".w, " +
|
||||
std::to_string(thread_id) + "u)";
|
||||
c += " r" + std::to_string(z) + id + ".x += " + w_val_x +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " r" + std::to_string(z) + id + ".y += " + w_val_y +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " r" + std::to_string(z) + id + ".z += " + w_val_z +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " r" + std::to_string(z) + id + ".w += " + w_val_w +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " " + R + ".x += " + w_val_x + " * " + S + "." +
|
||||
channels[ch] + ";\n";
|
||||
c += " " + R + ".y += " + w_val_y + " * " + S + "." +
|
||||
channels[ch] + ";\n";
|
||||
c += " " + R + ".z += " + w_val_z + " * " + S + "." +
|
||||
channels[ch] + ";\n";
|
||||
c += " " + R + ".w += " + w_val_w + " * " + S + "." +
|
||||
channels[ch] + ";\n";
|
||||
} else {
|
||||
const std::string weight_id =
|
||||
std::to_string(z * 4 + ch + shared_offset);
|
||||
std::to_string(s * 4 + ch + shared_offset);
|
||||
std::string w_val;
|
||||
if (conv_params.AreWeightsBuffer()) {
|
||||
w_val = "weights_cache[" + weight_id + "]";
|
||||
} else {
|
||||
w_val = "f" + weight_id;
|
||||
}
|
||||
c += " r" + std::to_string(z) + id + " += " + w_val +
|
||||
" * src" + id + "." + channels[ch] + ";\n";
|
||||
c += " " + R + " += " + w_val + " * " + S + "." +
|
||||
channels[ch] + ";\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // F32_F16 precision and weights type is float16
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yind = std::to_string(y);
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
std::string id = std::to_string(y) + std::to_string(x);
|
||||
std::string R = "r" + std::to_string(z) + id;
|
||||
std::string S = "src" + id;
|
||||
const std::string xind = std::to_string(x);
|
||||
std::string R = "r" + generate_id_full(xind, yind, "", sind);
|
||||
std::string S = "src" + generate_id(xind, yind, "");
|
||||
std::vector<std::string> F(4);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
std::string weight_id = std::to_string(z * 4 + i + shared_offset);
|
||||
std::string weight_id = std::to_string(s * 4 + i + shared_offset);
|
||||
if (conv_params.AreWeightsBuffer()) {
|
||||
F[i] = "weights_cache[" + weight_id + "]";
|
||||
} else {
|
||||
@ -633,7 +678,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
} else if (conv_params.AreWeightsBuffer()) { // GLOBAL_MEM/CONSTANT_MEM
|
||||
c += " weights_cache = filters_loc;\n";
|
||||
} else { // TEXTURES_MEM
|
||||
for (int dst_s = 0; dst_s < block_size.z; ++dst_s) {
|
||||
for (int dst_s = 0; dst_s < block_size.w; ++dst_s) {
|
||||
std::string f_y = is1x1 ? "s" : "filter_offset";
|
||||
if (conv_params.different_weights_for_height) {
|
||||
f_y = "DST_Y * args.src_tensor.Slices() + s";
|
||||
@ -660,7 +705,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
conv_core(0);
|
||||
for (int i = 1; i < conv_params.src_depth_loop_size; ++i) {
|
||||
read_src();
|
||||
conv_core(i * block_size.z * 4);
|
||||
conv_core(i * block_size.w * 4);
|
||||
c += " s += 1;\n";
|
||||
}
|
||||
if (conv_params.AreWeightsBuffer()) {
|
||||
@ -675,13 +720,13 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) {
|
||||
c += GenerateAsyncUpload("weights_cache", "args.biases.GetPtr()", "DST_S",
|
||||
block_size.z);
|
||||
block_size.w);
|
||||
} else if (conv_params.weights_upload_type ==
|
||||
ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
|
||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||
c += GenerateUploadByThreads("weights_cache", "args.biases.GetPtr()",
|
||||
"DST_S", "lid", total_work_items,
|
||||
block_size.z);
|
||||
block_size.w);
|
||||
c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
|
||||
} else {
|
||||
c += " weights_cache = args.biases.GetPtr() + DST_S;\n";
|
||||
@ -694,21 +739,23 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
c += " return;\n";
|
||||
c += " }\n";
|
||||
}
|
||||
for (int z = 0; z < block_size.z; ++z) {
|
||||
const std::string sz = std::to_string(z);
|
||||
c += " if (DST_S + " + sz + " >= args.dst_tensor.Slices()) return;\n";
|
||||
for (int s = 0; s < block_size.w; ++s) {
|
||||
const std::string sind = std::to_string(s);
|
||||
c += " if (DST_S + " + sind + " >= args.dst_tensor.Slices()) return;\n";
|
||||
c += " {\n";
|
||||
if (conv_params.AreWeightsBuffer()) {
|
||||
c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sz + "]);\n";
|
||||
c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sind + "]);\n";
|
||||
} else {
|
||||
c += " FLT4 bias_val = args.biases.Read(DST_S + " + sz + ");\n";
|
||||
c += " FLT4 bias_val = args.biases.Read(DST_S + " + sind + ");\n";
|
||||
}
|
||||
for (int y = 0; y < block_size.y; ++y) {
|
||||
const std::string yind = std::to_string(y);
|
||||
for (int x = 0; x < block_size.x; ++x) {
|
||||
const std::string xind = std::to_string(x);
|
||||
const std::string xs = dst_x[x];
|
||||
const std::string ys = dst_y[y];
|
||||
const std::string zs = "DST_S + " + sz;
|
||||
const std::string r_id = sz + std::to_string(y) + std::to_string(x);
|
||||
const std::string zs = "DST_S + " + sind;
|
||||
const std::string id = generate_id_full(xind, yind, "", sind);
|
||||
bool need_x_check = x != 0;
|
||||
bool need_y_check = y != 0;
|
||||
if (need_x_check && need_y_check) {
|
||||
@ -721,7 +768,7 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info,
|
||||
} else {
|
||||
c += " {\n";
|
||||
}
|
||||
c += " FLT4 res = TO_FLT4(r" + r_id + ") + bias_val;\n";
|
||||
c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
|
||||
c += " args.dst_tensor.Write(res, " + xs + ", " + ys + ", " + zs +
|
||||
");\n";
|
||||
c += " }\n";
|
||||
@ -738,7 +785,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
int src_depth, int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
|
||||
bool different_weights_for_height, const BHWC* dst_shape) {
|
||||
ConvParams conv_params;
|
||||
conv_params.linear_hw = false;
|
||||
conv_params.linear_spatial = false;
|
||||
conv_params.weights_data_type =
|
||||
DeduceDataTypeFromPrecision(definition.precision);
|
||||
conv_params.x_kernel_is_1 = x_kernel_is_1;
|
||||
@ -750,43 +797,43 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.work_group_launch_order = int3(2, 0, 1);
|
||||
conv_params.fixed_work_group_size = true;
|
||||
} else {
|
||||
conv_params.linear_hw = true;
|
||||
conv_params.linear_spatial = true;
|
||||
work_group_size_ = int3(32, 1, 1);
|
||||
conv_params.work_group_launch_order = int3(1, 0, 2);
|
||||
conv_params.fixed_work_group_size = true;
|
||||
}
|
||||
conv_params.block_size = int3(2, 1, 4);
|
||||
conv_params.block_size = int4(2, 1, 1, 4);
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
|
||||
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||
conv_params.block_size.z = 4;
|
||||
conv_params.block_size.w = 4;
|
||||
} else if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||
conv_params.block_size.z = 2;
|
||||
conv_params.block_size.w = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = dst_depth;
|
||||
conv_params.block_size.w = dst_depth;
|
||||
}
|
||||
if (dst_shape) {
|
||||
int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
|
||||
float task_size_per_cu =
|
||||
static_cast<float>(task_size) / device_info.compute_units_count;
|
||||
int block_size = conv_params.block_size.x * conv_params.block_size.y *
|
||||
conv_params.block_size.z;
|
||||
conv_params.block_size.w;
|
||||
float threads_per_cu = task_size_per_cu / block_size;
|
||||
float warps_per_cu = threads_per_cu / 32 /*warp_size*/;
|
||||
if (warps_per_cu < 8.0f) {
|
||||
conv_params.block_size.x = 1;
|
||||
}
|
||||
if (warps_per_cu < 4.0f && conv_params.block_size.z >= 4) {
|
||||
conv_params.block_size.z /= 2;
|
||||
if (warps_per_cu < 4.0f && conv_params.block_size.w >= 4) {
|
||||
conv_params.block_size.w /= 2;
|
||||
}
|
||||
if (warps_per_cu < 2.0f && conv_params.block_size.z >= 2) {
|
||||
conv_params.block_size.z /= 2;
|
||||
if (warps_per_cu < 2.0f && conv_params.block_size.w >= 2) {
|
||||
conv_params.block_size.w /= 2;
|
||||
}
|
||||
}
|
||||
if (src_depth % 2 == 0) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) {
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
|
||||
conv_params.src_depth_loop_size = 4;
|
||||
}
|
||||
} else if (device_info.IsPowerVR()) {
|
||||
@ -795,7 +842,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.work_group_launch_order = int3(2, 0, 1);
|
||||
conv_params.fixed_work_group_size = true;
|
||||
} else {
|
||||
conv_params.linear_hw = true;
|
||||
conv_params.linear_spatial = true;
|
||||
work_group_size_ = int3(32, 1, 1);
|
||||
conv_params.work_group_launch_order = int3(1, 0, 2);
|
||||
conv_params.fixed_work_group_size = true;
|
||||
@ -803,28 +850,28 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.weights_data_type =
|
||||
definition.precision == CalculationsPrecision::F16 ? DataType::FLOAT16
|
||||
: DataType::FLOAT32;
|
||||
conv_params.block_size = int3(1, 1, 4);
|
||||
conv_params.block_size = int4(1, 1, 1, 4);
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
conv_params.weights_upload_type =
|
||||
WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
|
||||
if (dst_depth % 8 == 0 || dst_depth >= 32) {
|
||||
conv_params.block_size.z = 8;
|
||||
conv_params.block_size.w = 8;
|
||||
} else if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||
conv_params.block_size.z = 4;
|
||||
conv_params.block_size.w = 4;
|
||||
} else if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||
conv_params.block_size.z = 2;
|
||||
conv_params.block_size.w = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = dst_depth;
|
||||
conv_params.block_size.w = dst_depth;
|
||||
}
|
||||
if (definition.precision == CalculationsPrecision::F16) {
|
||||
conv_params.block_size.z = std::min(4, conv_params.block_size.z);
|
||||
conv_params.block_size.w = std::min(4, conv_params.block_size.w);
|
||||
if (src_depth % 2 == 0) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) {
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
|
||||
conv_params.src_depth_loop_size = 4;
|
||||
}
|
||||
if (conv_params.block_size.z == 1) {
|
||||
if (conv_params.block_size.w == 1) {
|
||||
if (src_depth % 2 == 0) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
@ -848,20 +895,20 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.fixed_work_group_size = true;
|
||||
}
|
||||
|
||||
conv_params.block_size = int3(2, 1, 1);
|
||||
conv_params.block_size = int4(2, 1, 1, 1);
|
||||
if (x_kernel_is_1 && y_kernel_is_1) {
|
||||
conv_params.block_size.y = 2;
|
||||
}
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM;
|
||||
if (dst_depth % 8 == 0 || dst_depth >= 32) {
|
||||
conv_params.block_size.z = 8;
|
||||
conv_params.block_size.w = 8;
|
||||
} else if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||
conv_params.block_size.z = 4;
|
||||
conv_params.block_size.w = 4;
|
||||
} else if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||
conv_params.block_size.z = 2;
|
||||
conv_params.block_size.w = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = 1;
|
||||
conv_params.block_size.w = 1;
|
||||
}
|
||||
if (src_depth % 2 == 0 && src_depth >= 16) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
@ -878,20 +925,20 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
}
|
||||
if (block_size == 8) {
|
||||
if (dst_depth == 1 || dst_depth == 3) {
|
||||
conv_params.block_size = int3(2, 2, 1);
|
||||
conv_params.block_size = int4(2, 2, 1, 1);
|
||||
} else {
|
||||
conv_params.block_size = int3(2, 2, 2);
|
||||
conv_params.block_size = int4(2, 2, 1, 2);
|
||||
}
|
||||
} else if (block_size == 4) {
|
||||
if (dst_depth == 1 || dst_depth == 3) {
|
||||
conv_params.block_size = int3(2, 2, 1);
|
||||
conv_params.block_size = int4(2, 2, 1, 1);
|
||||
} else {
|
||||
conv_params.block_size = int3(2, 1, 2);
|
||||
conv_params.block_size = int4(2, 1, 1, 2);
|
||||
}
|
||||
} else if (block_size == 2) {
|
||||
conv_params.block_size = int3(2, 1, 1);
|
||||
conv_params.block_size = int4(2, 1, 1, 1);
|
||||
} else {
|
||||
conv_params.block_size = int3(1, 1, 1);
|
||||
conv_params.block_size = int4(1, 1, 1, 1);
|
||||
}
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
MaliInfo mali_info = device_info.mali_info;
|
||||
@ -907,7 +954,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.fixed_work_group_size = false;
|
||||
conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
|
||||
} else if (device_info.IsAdreno()) {
|
||||
conv_params.block_size = int3(2, 2, 1);
|
||||
conv_params.block_size = int4(2, 2, 1, 1);
|
||||
work_group_size_ = int3(8, 2, 1);
|
||||
conv_params.work_group_launch_order = int3(0, 1, 2);
|
||||
conv_params.fixed_work_group_size = false;
|
||||
@ -924,12 +971,12 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.work_group_launch_order = int3(0, 1, 2);
|
||||
conv_params.fixed_work_group_size = true;
|
||||
} else {
|
||||
conv_params.linear_hw = true;
|
||||
conv_params.linear_spatial = true;
|
||||
work_group_size_ = int3(16, 1, 1);
|
||||
conv_params.work_group_launch_order = int3(0, 1, 2);
|
||||
conv_params.fixed_work_group_size = true;
|
||||
}
|
||||
conv_params.block_size = int3(1, 1, 4);
|
||||
conv_params.block_size = int4(1, 1, 1, 4);
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
int sub_group_size = 16;
|
||||
if (definition.precision != CalculationsPrecision::F32_F16 &&
|
||||
@ -944,36 +991,36 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
|
||||
}
|
||||
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||
conv_params.block_size.z = 4;
|
||||
conv_params.block_size.w = 4;
|
||||
} else if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||
conv_params.block_size.z = 2;
|
||||
conv_params.block_size.w = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = dst_depth;
|
||||
conv_params.block_size.w = dst_depth;
|
||||
}
|
||||
if (src_depth % 2 == 0) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) {
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
|
||||
conv_params.src_depth_loop_size = 4;
|
||||
}
|
||||
} else {
|
||||
conv_params.block_size = int3(1, 1, 4);
|
||||
conv_params.block_size = int4(1, 1, 1, 4);
|
||||
work_group_size_ = int3(8, 2, 1);
|
||||
conv_params.work_group_launch_order = int3(0, 1, 2);
|
||||
conv_params.fixed_work_group_size = false;
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
|
||||
if (dst_depth % 4 == 0 || dst_depth >= 8) {
|
||||
conv_params.block_size.z = 4;
|
||||
conv_params.block_size.w = 4;
|
||||
} else if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||
conv_params.block_size.z = 2;
|
||||
conv_params.block_size.w = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = dst_depth;
|
||||
conv_params.block_size.w = dst_depth;
|
||||
}
|
||||
if (src_depth % 2 == 0) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.z <= 2) {
|
||||
if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
|
||||
conv_params.src_depth_loop_size = 4;
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,7 +53,7 @@ class ConvPowerVR : public GPUOperation {
|
||||
// Returns the description of the weights arrangement used by this
// convolution: the layout is ConvWeightsLayout::kOHWIOGroupI4O4 and the
// output group size is read from conv_params_.block_size.
ConvWeightsDescription GetConvWeightsDescription() const {
|
||||
ConvWeightsDescription desc;
|
||||
desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
|
||||
// NOTE(review): two consecutive writes to output_group_size follow. This
// looks like a rendered diff in which the `.z` line is the pre-change
// version and the `.w` line is its replacement (block_size is declared
// elsewhere in this text as both `int3` and `int4 // WHDS`) — confirm
// against the real header. Read as plain code, the `.w` value is the one
// that is returned.
desc.output_group_size = conv_params_.block_size.z;
|
||||
desc.output_group_size = conv_params_.block_size.w;
|
||||
return desc;
|
||||
}
|
||||
|
||||
@ -82,10 +82,10 @@ class ConvPowerVR : public GPUOperation {
|
||||
// weights, so for PowerVR in this kernel we have F32 weights for
|
||||
// F32_F16 precision mode
|
||||
DataType weights_data_type; // used for weights and biases
|
||||
int3 block_size;
|
||||
int4 block_size; // WHDS
|
||||
int3 work_group_launch_order;
|
||||
bool fixed_work_group_size;
|
||||
bool linear_hw;
|
||||
bool linear_spatial; // spatial dimensions are Width/Height/Depth
|
||||
bool different_weights_for_height;
|
||||
int src_depth_loop_size;
|
||||
WeightsUploadType weights_upload_type;
|
||||
@ -178,8 +178,10 @@ class ConvPowerVR : public GPUOperation {
|
||||
const OperationDef& op_def, bool stride_correction,
|
||||
const ConvParams& conv_params);
|
||||
|
||||
int4 stride_padding_;
|
||||
int4 kernel_dilation_;
|
||||
int4 stride_;
|
||||
int4 padding_;
|
||||
int4 kernel_size_;
|
||||
int4 dilation_;
|
||||
ConvParams conv_params_;
|
||||
};
|
||||
|
||||
@ -214,7 +216,7 @@ void ConvPowerVR::UploadBias(const tflite::gpu::Tensor<Linear, T>& bias) {
|
||||
const int float_size = conv_params_.weights_data_type == DataType::FLOAT32
|
||||
? sizeof(float)
|
||||
: sizeof(half);
|
||||
int aligned_channels = AlignByN(bias.shape.v, 4 * conv_params_.block_size.z);
|
||||
int aligned_channels = AlignByN(bias.shape.v, 4 * conv_params_.block_size.w);
|
||||
desc.size = float_size * aligned_channels;
|
||||
desc.data.resize(desc.size);
|
||||
if (conv_params_.weights_data_type == DataType::FLOAT32) {
|
||||
@ -235,7 +237,7 @@ void ConvPowerVR::UploadBias(const tflite::gpu::Tensor<Linear, T>& bias) {
|
||||
template <DataType T>
|
||||
void ConvPowerVR::UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights) {
|
||||
const int dst_slices =
|
||||
AlignByN(DivideRoundUp(weights.shape.o, 4), conv_params_.block_size.z);
|
||||
AlignByN(DivideRoundUp(weights.shape.o, 4), conv_params_.block_size.w);
|
||||
const int src_slices = DivideRoundUp(weights.shape.i, 4);
|
||||
|
||||
const bool f32_weights = conv_params_.weights_data_type == DataType::FLOAT32;
|
||||
@ -249,19 +251,19 @@ void ConvPowerVR::UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights) {
|
||||
if (f32_weights) {
|
||||
float4* ptr = reinterpret_cast<float4*>(data.data());
|
||||
if (conv_params_.AreWeightsBuffer()) {
|
||||
RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.z,
|
||||
RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w,
|
||||
absl::MakeSpan(ptr, elements_count));
|
||||
} else {
|
||||
RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.z,
|
||||
RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w,
|
||||
absl::MakeSpan(ptr, elements_count));
|
||||
}
|
||||
} else {
|
||||
half4* ptr = reinterpret_cast<half4*>(data.data());
|
||||
if (conv_params_.AreWeightsBuffer()) {
|
||||
RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.z,
|
||||
RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w,
|
||||
absl::MakeSpan(ptr, elements_count));
|
||||
} else {
|
||||
RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.z,
|
||||
RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w,
|
||||
absl::MakeSpan(ptr, elements_count));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user