Added more restriction for DepthWise3x3 selection on Mali.

PiperOrigin-RevId: 300614404
Change-Id: I35a3f07fcc0092f6af3ace7015006991d9d2b273
This commit is contained in:
Raman Sarokin 2020-03-12 13:36:08 -07:00 committed by TensorFlower Gardener
parent 0d3b9d540f
commit 24e4a95157
2 changed files with 6 additions and 2 deletions

View File

@ -62,6 +62,7 @@ cc_library(
hdrs = ["dw_convolution_selector.h"],
deps = [
"//tensorflow/lite/delegates/gpu/cl:cl_device",
"//tensorflow/lite/delegates/gpu/cl:precision",
"//tensorflow/lite/delegates/gpu/cl/kernels:depth_wise_conv",
"//tensorflow/lite/delegates/gpu/cl/kernels:depth_wise_conv_3x3",
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h"
#include "tensorflow/lite/delegates/gpu/cl/precision.h"
namespace tflite {
namespace gpu {
@ -68,8 +69,10 @@ Status SelectDWConvolutionMali(const DepthwiseConvolution2DAttributes& attr,
const auto storage_type = op_def.src_tensors[0].storage_type;
bool buffer_type = storage_type == TensorStorageType::BUFFER ||
storage_type == TensorStorageType::IMAGE_BUFFER;
if (!buffer_type && !op_def.IsBatchSupported() &&
IsDepthWiseConv3x3Supported(attr)) {
MaliInfo mali_info = creation_context.device->GetInfo().mali_info;
if (IsDepthWiseConv3x3Supported(attr) && !mali_info.IsMidgard() &&
!buffer_type && !op_def.IsBatchSupported() &&
op_def.precision != CalculationsPrecision::F32) {
DepthWiseConv3x3 dw_conv;
RETURN_IF_ERROR(
CreateDepthWiseConv3x3(creation_context, op_def, attr, &dw_conv));