Improved selection of block size for Mali.
PiperOrigin-RevId: 300605347 Change-Id: If538c67ffff73e5b45c2baa6941d90863b8399db
This commit is contained in:
parent
6625cee185
commit
13367effdd
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
|
||||
|
||||
#include <array>
|
||||
#include <cfloat>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
@ -202,74 +201,6 @@ std::string GenerateConvBuffer1x1(
|
||||
return c;
|
||||
}
|
||||
|
||||
// task_size as amount of FLT4 processed elements.
|
||||
int GetRecommendedBlockSizeForConv(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
int task_size) {
|
||||
const float task_size_per_cu =
|
||||
task_size / static_cast<float>(device.GetInfo().compute_units_count);
|
||||
int block_size = 1;
|
||||
float threshold_1 = FLT_MAX;
|
||||
float threshold_2 = FLT_MAX;
|
||||
float threshold_4 = FLT_MAX;
|
||||
if (!device.IsMali()) {
|
||||
return 1;
|
||||
}
|
||||
MaliInfo mali_info = device.GetInfo().mali_info;
|
||||
switch (definition.precision) {
|
||||
case CalculationsPrecision::F16:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
threshold_4 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 256.0f * 2.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
threshold_4 = 256.0f * 16.0f;
|
||||
} else if (mali_info.IsBifrostGen3()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 6.0f;
|
||||
threshold_4 = 256.0f * 16.0f;
|
||||
}
|
||||
break;
|
||||
case CalculationsPrecision::F32_F16:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 3.0f;
|
||||
threshold_4 = 256.0f * 32.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 256.0f * 2.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsBifrostGen3()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
}
|
||||
break;
|
||||
case CalculationsPrecision::F32:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 128.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
} else if (mali_info.IsBifrostGen3()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 12.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (task_size_per_cu <= threshold_1) {
|
||||
block_size = 1;
|
||||
} else if (task_size_per_cu <= threshold_2) {
|
||||
block_size = 2;
|
||||
} else if (task_size_per_cu <= threshold_4) {
|
||||
block_size = 4;
|
||||
} else {
|
||||
block_size = 8;
|
||||
}
|
||||
return block_size;
|
||||
}
|
||||
|
||||
ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const BHWC& shape, int src_depth,
|
||||
@ -295,7 +226,7 @@ ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
|
||||
|
||||
int task_size = shape.w * shape.b * shape.h * dst_depth;
|
||||
int block_size =
|
||||
GetRecommendedBlockSizeForConv(device, definition, task_size);
|
||||
GetRecommendedBlockSizeForConv(device, definition.precision, task_size);
|
||||
|
||||
if (!can_use_flt8 && block_size > 4) {
|
||||
block_size = 4;
|
||||
|
@ -130,21 +130,21 @@ std::string GenerateBlockCoords(const int3& block_size,
|
||||
|
||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
const CLDevice& device)
|
||||
const CLDevice& device, const BHWC* dst_shape)
|
||||
: GPUOperation(definition),
|
||||
stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w,
|
||||
-attr.padding.prepended.h),
|
||||
kernel_dilation_(attr.weights.shape.w, attr.weights.shape.h,
|
||||
attr.dilations.w, attr.dilations.h),
|
||||
conv_params_(GuessBestParams(device, definition, attr)) {}
|
||||
conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
|
||||
|
||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
const CLDevice& device)
|
||||
const CLDevice& device, const BHWC* dst_shape)
|
||||
: GPUOperation(definition),
|
||||
stride_padding_(1, 1, 0, 0),
|
||||
kernel_dilation_(1, 1, 1, 1),
|
||||
conv_params_(GuessBestParams(device, definition, attr)) {}
|
||||
conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
|
||||
|
||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition)
|
||||
: GPUOperation(definition),
|
||||
@ -628,7 +628,7 @@ std::string GenerateConv(
|
||||
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
const CLDevice& device, const OperationDef& definition, int src_depth,
|
||||
int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
|
||||
bool different_weights_for_height) const {
|
||||
bool different_weights_for_height, const BHWC* dst_shape) const {
|
||||
ConvParams conv_params;
|
||||
conv_params.linear_hw = false;
|
||||
conv_params.weights_data_type =
|
||||
@ -741,17 +741,45 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
} else if (device.IsMali()) {
|
||||
conv_params.block_size = int3(2, 1, 1);
|
||||
int block_size = 2;
|
||||
if (dst_shape) {
|
||||
int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
|
||||
block_size = GetRecommendedBlockSizeForConv(device, definition.precision,
|
||||
task_size);
|
||||
}
|
||||
if (!x_kernel_is_1 || !y_kernel_is_1) {
|
||||
block_size = std::min(block_size, 4);
|
||||
}
|
||||
if (block_size == 8) {
|
||||
if (dst_depth == 1 || dst_depth == 3) {
|
||||
conv_params.block_size = int3(2, 2, 1);
|
||||
} else {
|
||||
conv_params.block_size = int3(2, 2, 2);
|
||||
}
|
||||
} else if (block_size == 4) {
|
||||
if (dst_depth == 1 || dst_depth == 3) {
|
||||
conv_params.block_size = int3(2, 2, 1);
|
||||
} else {
|
||||
conv_params.block_size = int3(2, 1, 2);
|
||||
}
|
||||
} else if (block_size == 2) {
|
||||
conv_params.block_size = int3(2, 1, 1);
|
||||
} else {
|
||||
conv_params.block_size = int3(1, 1, 1);
|
||||
}
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
MaliInfo mali_info = device.GetInfo().mali_info;
|
||||
if (src_depth % 2 == 0 && block_size <= 2 && !mali_info.IsMidgard()) {
|
||||
conv_params.src_depth_loop_size = 2;
|
||||
}
|
||||
if (src_depth % 4 == 0 && block_size == 1 && !mali_info.IsMidgard() &&
|
||||
definition.precision == CalculationsPrecision::F16) {
|
||||
conv_params.src_depth_loop_size = 4;
|
||||
}
|
||||
conv_params.work_group_size = int3(4, 4, 1);
|
||||
conv_params.work_group_launch_order = int3(0, 1, 2);
|
||||
conv_params.fixed_work_group_size = false;
|
||||
conv_params.src_depth_loop_size = 1;
|
||||
conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
|
||||
if (dst_depth % 2 == 0 || dst_depth >= 4) {
|
||||
conv_params.block_size.z = 2;
|
||||
} else {
|
||||
conv_params.block_size.z = 1;
|
||||
}
|
||||
} else {
|
||||
conv_params.block_size = int3(1, 1, 4);
|
||||
conv_params.work_group_size = int3(8, 2, 1);
|
||||
@ -779,7 +807,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
|
||||
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
const CLDevice& device, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr) const {
|
||||
const Convolution2DAttributes& attr, const BHWC* dst_shape) const {
|
||||
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
||||
const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
|
||||
const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 &&
|
||||
@ -791,16 +819,16 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
attr.padding.prepended.h == 0 &&
|
||||
attr.padding.appended.h == 0;
|
||||
return GuessBestParams(device, definition, src_depth, dst_depth,
|
||||
x_kernel_is_1, y_kernel_is_1, false);
|
||||
x_kernel_is_1, y_kernel_is_1, false, dst_shape);
|
||||
}
|
||||
|
||||
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
const CLDevice& device, const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr) const {
|
||||
const FullyConnectedAttributes& attr, const BHWC* dst_shape) const {
|
||||
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
||||
const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
|
||||
ConvPowerVR::ConvParams params = GuessBestParams(
|
||||
device, definition, src_depth, dst_depth, true, true, false);
|
||||
device, definition, src_depth, dst_depth, true, true, false, dst_shape);
|
||||
params.work_group_size.x *= params.work_group_size.y;
|
||||
params.work_group_size.y = 1;
|
||||
params.block_size.x *= params.block_size.y;
|
||||
@ -810,11 +838,11 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
|
||||
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(
|
||||
const CLDevice& device, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr) const {
|
||||
const Convolution2DAttributes& attr, const BHWC* dst_shape) const {
|
||||
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
||||
const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
|
||||
ConvPowerVR::ConvParams params = GuessBestParams(
|
||||
device, definition, src_depth, dst_depth, true, true, true);
|
||||
device, definition, src_depth, dst_depth, true, true, true, dst_shape);
|
||||
params.block_size.x *= params.block_size.y;
|
||||
params.block_size.y = 1;
|
||||
return params;
|
||||
@ -823,26 +851,27 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(
|
||||
Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
ConvPowerVR* result) {
|
||||
*result = ConvPowerVR(definition, attr, *creation_context.device);
|
||||
ConvPowerVR* result, const BHWC* dst_shape) {
|
||||
*result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape);
|
||||
return result->UploadData(attr.weights, attr.bias, creation_context.context);
|
||||
}
|
||||
|
||||
Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
ConvPowerVR* result) {
|
||||
*result = ConvPowerVR(definition, attr, *creation_context.device);
|
||||
ConvPowerVR* result, const BHWC* dst_shape) {
|
||||
*result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape);
|
||||
return result->UploadData(attr.weights, attr.bias, creation_context.context);
|
||||
}
|
||||
|
||||
Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
ConvPowerVR* result) {
|
||||
ConvPowerVR* result,
|
||||
const BHWC* dst_shape) {
|
||||
*result = ConvPowerVR(definition);
|
||||
result->conv_params_ = result->GuessBestParamsWinograd(
|
||||
*creation_context.device, definition, attr);
|
||||
*creation_context.device, definition, attr, dst_shape);
|
||||
return result->UploadDataForWinograd4x4To6x6(
|
||||
attr.weights, *creation_context.device, creation_context.context);
|
||||
}
|
||||
|
@ -79,9 +79,11 @@ class ConvPowerVR : public GPUOperation {
|
||||
};
|
||||
|
||||
ConvPowerVR(const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, const CLDevice& device);
|
||||
const Convolution2DAttributes& attr, const CLDevice& device,
|
||||
const BHWC* dst_shape = nullptr);
|
||||
ConvPowerVR(const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr, const CLDevice& device);
|
||||
const FullyConnectedAttributes& attr, const CLDevice& device,
|
||||
const BHWC* dst_shape = nullptr);
|
||||
explicit ConvPowerVR(const OperationDef& definition);
|
||||
|
||||
template <DataType T>
|
||||
@ -100,16 +102,17 @@ class ConvPowerVR : public GPUOperation {
|
||||
friend Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
ConvPowerVR* result);
|
||||
ConvPowerVR* result, const BHWC* dst_shape);
|
||||
|
||||
friend Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
ConvPowerVR* result);
|
||||
ConvPowerVR* result, const BHWC* dst_shape);
|
||||
|
||||
friend Status CreateConvPowerVRWino4x4To6x6(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, ConvPowerVR* result);
|
||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||
const BHWC* dst_shape);
|
||||
|
||||
friend std::string GenerateConv(
|
||||
const CLDevice& device, const OperationDef& op_def,
|
||||
@ -118,18 +121,22 @@ class ConvPowerVR : public GPUOperation {
|
||||
|
||||
ConvParams GuessBestParams(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr) const;
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC* dst_shape = nullptr) const;
|
||||
ConvParams GuessBestParams(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr) const;
|
||||
const FullyConnectedAttributes& attr,
|
||||
const BHWC* dst_shape = nullptr) const;
|
||||
ConvParams GuessBestParamsWinograd(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr) const;
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC* dst_shape = nullptr) const;
|
||||
ConvParams GuessBestParams(const CLDevice& device,
|
||||
const OperationDef& definition, int src_depth,
|
||||
int dst_depth, bool x_kernel_is_1,
|
||||
bool y_kernel_is_1,
|
||||
bool different_weights_for_height) const;
|
||||
bool different_weights_for_height,
|
||||
const BHWC* dst_shape = nullptr) const;
|
||||
|
||||
Status BindArguments();
|
||||
int3 GetGridSize() const;
|
||||
@ -206,17 +213,18 @@ Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
ConvPowerVR* result);
|
||||
ConvPowerVR* result, const BHWC* dst_shape = nullptr);
|
||||
|
||||
Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
ConvPowerVR* result);
|
||||
ConvPowerVR* result, const BHWC* dst_shape = nullptr);
|
||||
|
||||
Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
ConvPowerVR* result);
|
||||
ConvPowerVR* result,
|
||||
const BHWC* dst_shape = nullptr);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -721,6 +722,80 @@ int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size) {
|
||||
return {1, 1, 1};
|
||||
}
|
||||
|
||||
int GetRecommendedBlockSizeForConv(const CLDevice& device,
|
||||
CalculationsPrecision precision,
|
||||
int task_size) {
|
||||
const float task_size_per_cu =
|
||||
task_size / static_cast<float>(device.GetInfo().compute_units_count);
|
||||
int block_size = 1;
|
||||
float threshold_1 = FLT_MAX;
|
||||
float threshold_2 = FLT_MAX;
|
||||
float threshold_4 = FLT_MAX;
|
||||
if (!device.IsMali()) {
|
||||
return 1;
|
||||
}
|
||||
MaliInfo mali_info = device.GetInfo().mali_info;
|
||||
switch (precision) {
|
||||
case CalculationsPrecision::F16:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
threshold_4 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 256.0f * 2.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
threshold_4 = 256.0f * 16.0f;
|
||||
} else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 6.0f;
|
||||
threshold_4 = 256.0f * 16.0f;
|
||||
} else if (mali_info.IsMidgard()) {
|
||||
threshold_1 = 256.0f * 4.0f;
|
||||
threshold_2 = 256.0f * 16.0f;
|
||||
}
|
||||
break;
|
||||
case CalculationsPrecision::F32_F16:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 3.0f;
|
||||
threshold_4 = 256.0f * 32.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 256.0f * 2.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 8.0f;
|
||||
} else if (mali_info.IsMidgard()) {
|
||||
threshold_1 = 256.0f * 4.0f;
|
||||
}
|
||||
break;
|
||||
case CalculationsPrecision::F32:
|
||||
if (mali_info.IsBifrostGen1()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
} else if (mali_info.IsBifrostGen2()) {
|
||||
threshold_1 = 128.0f;
|
||||
threshold_2 = 256.0f * 4.0f;
|
||||
} else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
|
||||
threshold_1 = 256.0f;
|
||||
threshold_2 = 256.0f * 12.0f;
|
||||
} else if (mali_info.IsMidgard()) {
|
||||
threshold_1 = 256.0f * 16.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (task_size_per_cu <= threshold_1) {
|
||||
block_size = 1;
|
||||
} else if (task_size_per_cu <= threshold_2) {
|
||||
block_size = 2;
|
||||
} else if (task_size_per_cu <= threshold_4) {
|
||||
block_size = 4;
|
||||
} else {
|
||||
block_size = 8;
|
||||
}
|
||||
return block_size;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -305,6 +305,11 @@ float4 GetMaskForLastPlane(int channels);
|
||||
// returns first work group from wgs that has size not bigger than max_wg_size
|
||||
// if no suitable groups among wgs, returns {1, 1, 1}
|
||||
int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size);
|
||||
|
||||
// task_size as amount of FLT4 processed elements.
|
||||
int GetRecommendedBlockSizeForConv(const CLDevice& device,
|
||||
CalculationsPrecision precision,
|
||||
int task_size);
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -100,7 +100,8 @@ Status SelectConvolutionMali(const Convolution2DAttributes& attr,
|
||||
*ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv));
|
||||
} else {
|
||||
ConvPowerVR conv;
|
||||
RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv));
|
||||
RETURN_IF_ERROR(
|
||||
CreateConvPowerVR(creation_context, op_def, attr, &conv, &dst_shape));
|
||||
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
|
||||
}
|
||||
return OkStatus();
|
||||
@ -118,8 +119,8 @@ Status SelectConvolutionWinogradMali(const Convolution2DAttributes& attr,
|
||||
*ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv));
|
||||
} else {
|
||||
ConvPowerVR conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv));
|
||||
RETURN_IF_ERROR(CreateConvPowerVRWino4x4To6x6(creation_context, op_def,
|
||||
attr, &conv, &dst_shape));
|
||||
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,7 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
return OkStatus();
|
||||
} else {
|
||||
gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
|
||||
return SelectConvolution(attr, input_shape, creation_context, op_def,
|
||||
return SelectConvolution(attr, output_shape, creation_context, op_def,
|
||||
hints, gpu_op);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user