diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD index faaf34040ac..b633cb4311a 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD @@ -62,6 +62,7 @@ cc_library( hdrs = ["dw_convolution_selector.h"], deps = [ "//tensorflow/lite/delegates/gpu/cl:cl_device", + "//tensorflow/lite/delegates/gpu/cl:precision", "//tensorflow/lite/delegates/gpu/cl/kernels:depth_wise_conv", "//tensorflow/lite/delegates/gpu/cl/kernels:depth_wise_conv_3x3", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc index 4235cd58ba4..85afa3fff43 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/dw_convolution_selector.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/depth_wise_conv_3x3.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" namespace tflite { namespace gpu { @@ -68,8 +69,10 @@ Status SelectDWConvolutionMali(const DepthwiseConvolution2DAttributes& attr, const auto storage_type = op_def.src_tensors[0].storage_type; bool buffer_type = storage_type == TensorStorageType::BUFFER || storage_type == TensorStorageType::IMAGE_BUFFER; - if (!buffer_type && !op_def.IsBatchSupported() && - IsDepthWiseConv3x3Supported(attr)) { + MaliInfo mali_info = creation_context.device->GetInfo().mali_info; + if (IsDepthWiseConv3x3Supported(attr) && !mali_info.IsMidgard() && + !buffer_type && !op_def.IsBatchSupported() && + op_def.precision != CalculationsPrecision::F32) { DepthWiseConv3x3 dw_conv; RETURN_IF_ERROR( CreateDepthWiseConv3x3(creation_context, op_def, attr, &dw_conv));