Improved selection of block size for Mali.

PiperOrigin-RevId: 300605347
Change-Id: If538c67ffff73e5b45c2baa6941d90863b8399db
This commit is contained in:
Raman Sarokin 2020-03-12 12:49:59 -07:00 committed by TensorFlower Gardener
parent 6625cee185
commit 13367effdd
7 changed files with 159 additions and 110 deletions

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h"
#include <array> #include <array>
#include <cfloat>
#include <string> #include <string>
#include <utility> #include <utility>
@ -202,74 +201,6 @@ std::string GenerateConvBuffer1x1(
return c; return c;
} }
// task_size as amount of FLT4 processed elements.
int GetRecommendedBlockSizeForConv(const CLDevice& device,
const OperationDef& definition,
int task_size) {
const float task_size_per_cu =
task_size / static_cast<float>(device.GetInfo().compute_units_count);
int block_size = 1;
float threshold_1 = FLT_MAX;
float threshold_2 = FLT_MAX;
float threshold_4 = FLT_MAX;
if (!device.IsMali()) {
return 1;
}
MaliInfo mali_info = device.GetInfo().mali_info;
switch (definition.precision) {
case CalculationsPrecision::F16:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 4.0f;
threshold_4 = 256.0f * 8.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 256.0f * 2.0f;
threshold_2 = 256.0f * 8.0f;
threshold_4 = 256.0f * 16.0f;
} else if (mali_info.IsBifrostGen3()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 6.0f;
threshold_4 = 256.0f * 16.0f;
}
break;
case CalculationsPrecision::F32_F16:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 3.0f;
threshold_4 = 256.0f * 32.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 256.0f * 2.0f;
threshold_2 = 256.0f * 8.0f;
} else if (mali_info.IsBifrostGen3()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 8.0f;
}
break;
case CalculationsPrecision::F32:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 4.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 128.0f;
threshold_2 = 256.0f * 4.0f;
} else if (mali_info.IsBifrostGen3()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 12.0f;
}
break;
}
if (task_size_per_cu <= threshold_1) {
block_size = 1;
} else if (task_size_per_cu <= threshold_2) {
block_size = 2;
} else if (task_size_per_cu <= threshold_4) {
block_size = 4;
} else {
block_size = 8;
}
return block_size;
}
ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device, ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
const OperationDef& definition, const OperationDef& definition,
const BHWC& shape, int src_depth, const BHWC& shape, int src_depth,
@ -295,7 +226,7 @@ ConvBuffer1x1::ConvParams GetBestParams(const CLDevice& device,
int task_size = shape.w * shape.b * shape.h * dst_depth; int task_size = shape.w * shape.b * shape.h * dst_depth;
int block_size = int block_size =
GetRecommendedBlockSizeForConv(device, definition, task_size); GetRecommendedBlockSizeForConv(device, definition.precision, task_size);
if (!can_use_flt8 && block_size > 4) { if (!can_use_flt8 && block_size > 4) {
block_size = 4; block_size = 4;

View File

@ -130,21 +130,21 @@ std::string GenerateBlockCoords(const int3& block_size,
ConvPowerVR::ConvPowerVR(const OperationDef& definition, ConvPowerVR::ConvPowerVR(const OperationDef& definition,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
const CLDevice& device) const CLDevice& device, const BHWC* dst_shape)
: GPUOperation(definition), : GPUOperation(definition),
stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w, stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w,
-attr.padding.prepended.h), -attr.padding.prepended.h),
kernel_dilation_(attr.weights.shape.w, attr.weights.shape.h, kernel_dilation_(attr.weights.shape.w, attr.weights.shape.h,
attr.dilations.w, attr.dilations.h), attr.dilations.w, attr.dilations.h),
conv_params_(GuessBestParams(device, definition, attr)) {} conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
ConvPowerVR::ConvPowerVR(const OperationDef& definition, ConvPowerVR::ConvPowerVR(const OperationDef& definition,
const FullyConnectedAttributes& attr, const FullyConnectedAttributes& attr,
const CLDevice& device) const CLDevice& device, const BHWC* dst_shape)
: GPUOperation(definition), : GPUOperation(definition),
stride_padding_(1, 1, 0, 0), stride_padding_(1, 1, 0, 0),
kernel_dilation_(1, 1, 1, 1), kernel_dilation_(1, 1, 1, 1),
conv_params_(GuessBestParams(device, definition, attr)) {} conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
ConvPowerVR::ConvPowerVR(const OperationDef& definition) ConvPowerVR::ConvPowerVR(const OperationDef& definition)
: GPUOperation(definition), : GPUOperation(definition),
@ -628,7 +628,7 @@ std::string GenerateConv(
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
const CLDevice& device, const OperationDef& definition, int src_depth, const CLDevice& device, const OperationDef& definition, int src_depth,
int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1, int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
bool different_weights_for_height) const { bool different_weights_for_height, const BHWC* dst_shape) const {
ConvParams conv_params; ConvParams conv_params;
conv_params.linear_hw = false; conv_params.linear_hw = false;
conv_params.weights_data_type = conv_params.weights_data_type =
@ -741,17 +741,45 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
conv_params.src_depth_loop_size = 2; conv_params.src_depth_loop_size = 2;
} }
} else if (device.IsMali()) { } else if (device.IsMali()) {
conv_params.block_size = int3(2, 1, 1); int block_size = 2;
if (dst_shape) {
int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
block_size = GetRecommendedBlockSizeForConv(device, definition.precision,
task_size);
}
if (!x_kernel_is_1 || !y_kernel_is_1) {
block_size = std::min(block_size, 4);
}
if (block_size == 8) {
if (dst_depth == 1 || dst_depth == 3) {
conv_params.block_size = int3(2, 2, 1);
} else {
conv_params.block_size = int3(2, 2, 2);
}
} else if (block_size == 4) {
if (dst_depth == 1 || dst_depth == 3) {
conv_params.block_size = int3(2, 2, 1);
} else {
conv_params.block_size = int3(2, 1, 2);
}
} else if (block_size == 2) {
conv_params.block_size = int3(2, 1, 1);
} else {
conv_params.block_size = int3(1, 1, 1);
}
conv_params.src_depth_loop_size = 1;
MaliInfo mali_info = device.GetInfo().mali_info;
if (src_depth % 2 == 0 && block_size <= 2 && !mali_info.IsMidgard()) {
conv_params.src_depth_loop_size = 2;
}
if (src_depth % 4 == 0 && block_size == 1 && !mali_info.IsMidgard() &&
definition.precision == CalculationsPrecision::F16) {
conv_params.src_depth_loop_size = 4;
}
conv_params.work_group_size = int3(4, 4, 1); conv_params.work_group_size = int3(4, 4, 1);
conv_params.work_group_launch_order = int3(0, 1, 2); conv_params.work_group_launch_order = int3(0, 1, 2);
conv_params.fixed_work_group_size = false; conv_params.fixed_work_group_size = false;
conv_params.src_depth_loop_size = 1;
conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
if (dst_depth % 2 == 0 || dst_depth >= 4) {
conv_params.block_size.z = 2;
} else {
conv_params.block_size.z = 1;
}
} else { } else {
conv_params.block_size = int3(1, 1, 4); conv_params.block_size = int3(1, 1, 4);
conv_params.work_group_size = int3(8, 2, 1); conv_params.work_group_size = int3(8, 2, 1);
@ -779,7 +807,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
const CLDevice& device, const OperationDef& definition, const CLDevice& device, const OperationDef& definition,
const Convolution2DAttributes& attr) const { const Convolution2DAttributes& attr, const BHWC* dst_shape) const {
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 && const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 &&
@ -791,16 +819,16 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
attr.padding.prepended.h == 0 && attr.padding.prepended.h == 0 &&
attr.padding.appended.h == 0; attr.padding.appended.h == 0;
return GuessBestParams(device, definition, src_depth, dst_depth, return GuessBestParams(device, definition, src_depth, dst_depth,
x_kernel_is_1, y_kernel_is_1, false); x_kernel_is_1, y_kernel_is_1, false, dst_shape);
} }
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
const CLDevice& device, const OperationDef& definition, const CLDevice& device, const OperationDef& definition,
const FullyConnectedAttributes& attr) const { const FullyConnectedAttributes& attr, const BHWC* dst_shape) const {
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
ConvPowerVR::ConvParams params = GuessBestParams( ConvPowerVR::ConvParams params = GuessBestParams(
device, definition, src_depth, dst_depth, true, true, false); device, definition, src_depth, dst_depth, true, true, false, dst_shape);
params.work_group_size.x *= params.work_group_size.y; params.work_group_size.x *= params.work_group_size.y;
params.work_group_size.y = 1; params.work_group_size.y = 1;
params.block_size.x *= params.block_size.y; params.block_size.x *= params.block_size.y;
@ -810,11 +838,11 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd( ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(
const CLDevice& device, const OperationDef& definition, const CLDevice& device, const OperationDef& definition,
const Convolution2DAttributes& attr) const { const Convolution2DAttributes& attr, const BHWC* dst_shape) const {
const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4); const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4); const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
ConvPowerVR::ConvParams params = GuessBestParams( ConvPowerVR::ConvParams params = GuessBestParams(
device, definition, src_depth, dst_depth, true, true, true); device, definition, src_depth, dst_depth, true, true, true, dst_shape);
params.block_size.x *= params.block_size.y; params.block_size.x *= params.block_size.y;
params.block_size.y = 1; params.block_size.y = 1;
return params; return params;
@ -823,26 +851,27 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(
Status CreateConvPowerVR(const CreationContext& creation_context, Status CreateConvPowerVR(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
ConvPowerVR* result) { ConvPowerVR* result, const BHWC* dst_shape) {
*result = ConvPowerVR(definition, attr, *creation_context.device); *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape);
return result->UploadData(attr.weights, attr.bias, creation_context.context); return result->UploadData(attr.weights, attr.bias, creation_context.context);
} }
Status CreateConvPowerVR(const CreationContext& creation_context, Status CreateConvPowerVR(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const FullyConnectedAttributes& attr, const FullyConnectedAttributes& attr,
ConvPowerVR* result) { ConvPowerVR* result, const BHWC* dst_shape) {
*result = ConvPowerVR(definition, attr, *creation_context.device); *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape);
return result->UploadData(attr.weights, attr.bias, creation_context.context); return result->UploadData(attr.weights, attr.bias, creation_context.context);
} }
Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context, Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
ConvPowerVR* result) { ConvPowerVR* result,
const BHWC* dst_shape) {
*result = ConvPowerVR(definition); *result = ConvPowerVR(definition);
result->conv_params_ = result->GuessBestParamsWinograd( result->conv_params_ = result->GuessBestParamsWinograd(
*creation_context.device, definition, attr); *creation_context.device, definition, attr, dst_shape);
return result->UploadDataForWinograd4x4To6x6( return result->UploadDataForWinograd4x4To6x6(
attr.weights, *creation_context.device, creation_context.context); attr.weights, *creation_context.device, creation_context.context);
} }

View File

@ -79,9 +79,11 @@ class ConvPowerVR : public GPUOperation {
}; };
ConvPowerVR(const OperationDef& definition, ConvPowerVR(const OperationDef& definition,
const Convolution2DAttributes& attr, const CLDevice& device); const Convolution2DAttributes& attr, const CLDevice& device,
const BHWC* dst_shape = nullptr);
ConvPowerVR(const OperationDef& definition, ConvPowerVR(const OperationDef& definition,
const FullyConnectedAttributes& attr, const CLDevice& device); const FullyConnectedAttributes& attr, const CLDevice& device,
const BHWC* dst_shape = nullptr);
explicit ConvPowerVR(const OperationDef& definition); explicit ConvPowerVR(const OperationDef& definition);
template <DataType T> template <DataType T>
@ -100,16 +102,17 @@ class ConvPowerVR : public GPUOperation {
friend Status CreateConvPowerVR(const CreationContext& creation_context, friend Status CreateConvPowerVR(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
ConvPowerVR* result); ConvPowerVR* result, const BHWC* dst_shape);
friend Status CreateConvPowerVR(const CreationContext& creation_context, friend Status CreateConvPowerVR(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const FullyConnectedAttributes& attr, const FullyConnectedAttributes& attr,
ConvPowerVR* result); ConvPowerVR* result, const BHWC* dst_shape);
friend Status CreateConvPowerVRWino4x4To6x6( friend Status CreateConvPowerVRWino4x4To6x6(
const CreationContext& creation_context, const OperationDef& definition, const CreationContext& creation_context, const OperationDef& definition,
const Convolution2DAttributes& attr, ConvPowerVR* result); const Convolution2DAttributes& attr, ConvPowerVR* result,
const BHWC* dst_shape);
friend std::string GenerateConv( friend std::string GenerateConv(
const CLDevice& device, const OperationDef& op_def, const CLDevice& device, const OperationDef& op_def,
@ -118,18 +121,22 @@ class ConvPowerVR : public GPUOperation {
ConvParams GuessBestParams(const CLDevice& device, ConvParams GuessBestParams(const CLDevice& device,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr) const; const Convolution2DAttributes& attr,
const BHWC* dst_shape = nullptr) const;
ConvParams GuessBestParams(const CLDevice& device, ConvParams GuessBestParams(const CLDevice& device,
const OperationDef& definition, const OperationDef& definition,
const FullyConnectedAttributes& attr) const; const FullyConnectedAttributes& attr,
const BHWC* dst_shape = nullptr) const;
ConvParams GuessBestParamsWinograd(const CLDevice& device, ConvParams GuessBestParamsWinograd(const CLDevice& device,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr) const; const Convolution2DAttributes& attr,
const BHWC* dst_shape = nullptr) const;
ConvParams GuessBestParams(const CLDevice& device, ConvParams GuessBestParams(const CLDevice& device,
const OperationDef& definition, int src_depth, const OperationDef& definition, int src_depth,
int dst_depth, bool x_kernel_is_1, int dst_depth, bool x_kernel_is_1,
bool y_kernel_is_1, bool y_kernel_is_1,
bool different_weights_for_height) const; bool different_weights_for_height,
const BHWC* dst_shape = nullptr) const;
Status BindArguments(); Status BindArguments();
int3 GetGridSize() const; int3 GetGridSize() const;
@ -206,17 +213,18 @@ Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
Status CreateConvPowerVR(const CreationContext& creation_context, Status CreateConvPowerVR(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
ConvPowerVR* result); ConvPowerVR* result, const BHWC* dst_shape = nullptr);
Status CreateConvPowerVR(const CreationContext& creation_context, Status CreateConvPowerVR(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const FullyConnectedAttributes& attr, const FullyConnectedAttributes& attr,
ConvPowerVR* result); ConvPowerVR* result, const BHWC* dst_shape = nullptr);
Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context, Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context,
const OperationDef& definition, const OperationDef& definition,
const Convolution2DAttributes& attr, const Convolution2DAttributes& attr,
ConvPowerVR* result); ConvPowerVR* result,
const BHWC* dst_shape = nullptr);
} // namespace cl } // namespace cl
} // namespace gpu } // namespace gpu

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include <cfloat>
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
@ -721,6 +722,80 @@ int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size) {
return {1, 1, 1}; return {1, 1, 1};
} }
int GetRecommendedBlockSizeForConv(const CLDevice& device,
CalculationsPrecision precision,
int task_size) {
const float task_size_per_cu =
task_size / static_cast<float>(device.GetInfo().compute_units_count);
int block_size = 1;
float threshold_1 = FLT_MAX;
float threshold_2 = FLT_MAX;
float threshold_4 = FLT_MAX;
if (!device.IsMali()) {
return 1;
}
MaliInfo mali_info = device.GetInfo().mali_info;
switch (precision) {
case CalculationsPrecision::F16:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 4.0f;
threshold_4 = 256.0f * 8.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 256.0f * 2.0f;
threshold_2 = 256.0f * 8.0f;
threshold_4 = 256.0f * 16.0f;
} else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 6.0f;
threshold_4 = 256.0f * 16.0f;
} else if (mali_info.IsMidgard()) {
threshold_1 = 256.0f * 4.0f;
threshold_2 = 256.0f * 16.0f;
}
break;
case CalculationsPrecision::F32_F16:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 3.0f;
threshold_4 = 256.0f * 32.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 256.0f * 2.0f;
threshold_2 = 256.0f * 8.0f;
} else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 8.0f;
} else if (mali_info.IsMidgard()) {
threshold_1 = 256.0f * 4.0f;
}
break;
case CalculationsPrecision::F32:
if (mali_info.IsBifrostGen1()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 4.0f;
} else if (mali_info.IsBifrostGen2()) {
threshold_1 = 128.0f;
threshold_2 = 256.0f * 4.0f;
} else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
threshold_1 = 256.0f;
threshold_2 = 256.0f * 12.0f;
} else if (mali_info.IsMidgard()) {
threshold_1 = 256.0f * 16.0f;
}
break;
}
if (task_size_per_cu <= threshold_1) {
block_size = 1;
} else if (task_size_per_cu <= threshold_2) {
block_size = 2;
} else if (task_size_per_cu <= threshold_4) {
block_size = 4;
} else {
block_size = 8;
}
return block_size;
}
} // namespace cl } // namespace cl
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -305,6 +305,11 @@ float4 GetMaskForLastPlane(int channels);
// returns first work group from wgs that has size not bigger than max_wg_size // returns first work group from wgs that has size not bigger than max_wg_size
// if no suitable groups among wgs, returns {1, 1, 1} // if no suitable groups among wgs, returns {1, 1, 1}
int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size); int3 GetFirstSuitableWorkGroup(const std::vector<int3>& wgs, int max_wg_size);
// task_size as amount of FLT4 processed elements.
int GetRecommendedBlockSizeForConv(const CLDevice& device,
CalculationsPrecision precision,
int task_size);
} // namespace cl } // namespace cl
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -100,7 +100,8 @@ Status SelectConvolutionMali(const Convolution2DAttributes& attr,
*ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv)); *ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv));
} else { } else {
ConvPowerVR conv; ConvPowerVR conv;
RETURN_IF_ERROR(CreateConvPowerVR(creation_context, op_def, attr, &conv)); RETURN_IF_ERROR(
CreateConvPowerVR(creation_context, op_def, attr, &conv, &dst_shape));
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv)); *ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
} }
return OkStatus(); return OkStatus();
@ -118,8 +119,8 @@ Status SelectConvolutionWinogradMali(const Convolution2DAttributes& attr,
*ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv)); *ptr = absl::make_unique<ConvBuffer1x1>(std::move(conv));
} else { } else {
ConvPowerVR conv; ConvPowerVR conv;
RETURN_IF_ERROR( RETURN_IF_ERROR(CreateConvPowerVRWino4x4To6x6(creation_context, op_def,
CreateConvPowerVRWino4x4To6x6(creation_context, op_def, attr, &conv)); attr, &conv, &dst_shape));
*ptr = absl::make_unique<ConvPowerVR>(std::move(conv)); *ptr = absl::make_unique<ConvPowerVR>(std::move(conv));
} }

View File

@ -205,7 +205,7 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
return OkStatus(); return OkStatus();
} else { } else {
gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); gpu_op = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
return SelectConvolution(attr, input_shape, creation_context, op_def, return SelectConvolution(attr, output_shape, creation_context, op_def,
hints, gpu_op); hints, gpu_op);
} }
} }