Removed useless Status and CreationContext from depth wise convolution kernels.
PiperOrigin-RevId: 327371483 Change-Id: I2b3ae18022881f2975fb2a1d31df6b1a78c5d936
This commit is contained in:
parent
aeec5a20a9
commit
c82c43f658
@ -306,42 +306,38 @@ int3 DepthwiseConvolution::GetGridSize() const {
|
||||
return int3(grid_x, grid_y, grid_z);
|
||||
}
|
||||
|
||||
absl::Status CreateDepthwiseConvolution(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
DepthwiseConvolution* result) {
|
||||
bool weights_are_buffer = creation_context.device->IsMali();
|
||||
*result = DepthwiseConvolution(definition, attr, weights_are_buffer);
|
||||
RETURN_IF_ERROR(
|
||||
result->UploadWeights(attr.weights, creation_context.context));
|
||||
DepthwiseConvolution CreateDepthwiseConvolution(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr) {
|
||||
bool weights_are_buffer = device_info.IsMali();
|
||||
DepthwiseConvolution result(definition, attr, weights_are_buffer);
|
||||
result.UploadWeights(attr.weights);
|
||||
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type = weights_are_buffer ? LinearStorageType::BUFFER
|
||||
: LinearStorageType::TEXTURE_2D;
|
||||
desc.element_type = definition.GetDataType();
|
||||
desc.UploadLinearData(attr.bias);
|
||||
result->args_.AddObject(
|
||||
result.args_.AddObject(
|
||||
"biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status CreateDepthwiseConvolution(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution3DAttributes& attr,
|
||||
DepthwiseConvolution* result) {
|
||||
bool weights_are_buffer = creation_context.device->IsMali();
|
||||
*result = DepthwiseConvolution(definition, attr, weights_are_buffer);
|
||||
RETURN_IF_ERROR(
|
||||
result->UploadWeights(attr.weights, creation_context.context));
|
||||
DepthwiseConvolution CreateDepthwiseConvolution(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution3DAttributes& attr) {
|
||||
bool weights_are_buffer = device_info.IsMali();
|
||||
DepthwiseConvolution result(definition, attr, weights_are_buffer);
|
||||
result.UploadWeights(attr.weights);
|
||||
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type = weights_are_buffer ? LinearStorageType::BUFFER
|
||||
: LinearStorageType::TEXTURE_2D;
|
||||
desc.element_type = definition.GetDataType();
|
||||
desc.UploadLinearData(attr.bias);
|
||||
result->args_.AddObject(
|
||||
result.args_.AddObject(
|
||||
"biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
|
@ -48,14 +48,12 @@ class DepthwiseConvolution : public GPUOperation {
|
||||
DepthwiseConvolution& operator=(const DepthwiseConvolution&) = delete;
|
||||
|
||||
private:
|
||||
friend absl::Status CreateDepthwiseConvolution(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
DepthwiseConvolution* result);
|
||||
friend absl::Status CreateDepthwiseConvolution(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution3DAttributes& attr,
|
||||
DepthwiseConvolution* result);
|
||||
friend DepthwiseConvolution CreateDepthwiseConvolution(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr);
|
||||
friend DepthwiseConvolution CreateDepthwiseConvolution(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution3DAttributes& attr);
|
||||
DepthwiseConvolution(const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
bool weights_are_buffer);
|
||||
@ -64,16 +62,14 @@ class DepthwiseConvolution : public GPUOperation {
|
||||
bool weights_are_buffer);
|
||||
|
||||
template <DataType T>
|
||||
absl::Status UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
CLContext* context);
|
||||
void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights);
|
||||
|
||||
template <DataType S, typename T>
|
||||
void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
|
||||
absl::Span<T> dst);
|
||||
|
||||
template <DataType T>
|
||||
absl::Status UploadWeights(const tflite::gpu::Tensor<OHWDI, T>& weights,
|
||||
CLContext* context);
|
||||
void UploadWeights(const tflite::gpu::Tensor<OHWDI, T>& weights);
|
||||
|
||||
template <DataType S, typename T>
|
||||
void RearrangeWeightsData(const tflite::gpu::Tensor<OHWDI, S>& weights,
|
||||
@ -94,8 +90,8 @@ class DepthwiseConvolution : public GPUOperation {
|
||||
};
|
||||
|
||||
template <DataType T>
|
||||
absl::Status DepthwiseConvolution::UploadWeights(
|
||||
const tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
||||
void DepthwiseConvolution::UploadWeights(
|
||||
const tflite::gpu::Tensor<OHWI, T>& weights) {
|
||||
const int dst_channels = weights.shape.i * weights.shape.o;
|
||||
const int dst_slices = DivideRoundUp(dst_channels, 4);
|
||||
const int kernel_x = weights.shape.w;
|
||||
@ -130,8 +126,6 @@ absl::Status DepthwiseConvolution::UploadWeights(
|
||||
desc.data = std::move(data);
|
||||
args_.AddObject("weights", absl::make_unique<Texture2DDescriptor>(desc));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <DataType S, typename T>
|
||||
@ -164,8 +158,8 @@ void DepthwiseConvolution::RearrangeWeightsData(
|
||||
}
|
||||
|
||||
template <DataType T>
|
||||
absl::Status DepthwiseConvolution::UploadWeights(
|
||||
const tflite::gpu::Tensor<OHWDI, T>& weights, CLContext* context) {
|
||||
void DepthwiseConvolution::UploadWeights(
|
||||
const tflite::gpu::Tensor<OHWDI, T>& weights) {
|
||||
const int dst_channels = weights.shape.i * weights.shape.o;
|
||||
const int dst_slices = DivideRoundUp(dst_channels, 4);
|
||||
const int kernel_x = weights.shape.w;
|
||||
@ -203,8 +197,6 @@ absl::Status DepthwiseConvolution::UploadWeights(
|
||||
args_.AddObject("weights",
|
||||
absl::make_unique<Texture2DDescriptor>(std::move(desc)));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <DataType S, typename T>
|
||||
@ -239,9 +231,13 @@ void DepthwiseConvolution::RearrangeWeightsData(
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status CreateDepthwiseConvolution(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr, DepthwiseConvolution* result);
|
||||
DepthwiseConvolution CreateDepthwiseConvolution(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr);
|
||||
|
||||
DepthwiseConvolution CreateDepthwiseConvolution(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution3DAttributes& attr);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -313,21 +313,15 @@ bool IsDepthwiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) {
|
||||
attr.padding.appended.h == 1;
|
||||
}
|
||||
|
||||
absl::Status CreateDepthwiseConv3x3(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr, DepthwiseConv3x3* result) {
|
||||
if (!IsDepthwiseConv3x3Supported(attr)) {
|
||||
return absl::InvalidArgumentError(
|
||||
"DepthwiseConv3x3 doesn't support this attributes");
|
||||
}
|
||||
bool weights_are_buffer =
|
||||
creation_context.device->IsPowerVR() || creation_context.device->IsMali();
|
||||
bool local_mem_uploads =
|
||||
weights_are_buffer && creation_context.device->IsPowerVR();
|
||||
*result = DepthwiseConv3x3(definition, weights_are_buffer, local_mem_uploads,
|
||||
creation_context.device->info_);
|
||||
return result->UploadWeightsAndBiases(
|
||||
attr.weights, attr.bias, weights_are_buffer, creation_context.context);
|
||||
DepthwiseConv3x3 CreateDepthwiseConv3x3(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr) {
|
||||
bool weights_are_buffer = device_info.IsPowerVR() || device_info.IsMali();
|
||||
bool local_mem_uploads = weights_are_buffer && device_info.IsPowerVR();
|
||||
DepthwiseConv3x3 result(definition, weights_are_buffer, local_mem_uploads,
|
||||
device_info);
|
||||
result.UploadWeightsAndBiases(attr.weights, attr.bias, weights_are_buffer);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
|
@ -55,14 +55,13 @@ class DepthwiseConv3x3 : public GPUOperation {
|
||||
bool weights_are_buffer, bool local_mem_uploads,
|
||||
const DeviceInfo& device_info);
|
||||
template <DataType T>
|
||||
absl::Status UploadWeightsAndBiases(
|
||||
const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
const tflite::gpu::Tensor<Linear, T>& biases, bool weights_are_buffer,
|
||||
CLContext* context);
|
||||
void UploadWeightsAndBiases(const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
const tflite::gpu::Tensor<Linear, T>& biases,
|
||||
bool weights_are_buffer);
|
||||
|
||||
friend absl::Status CreateDepthwiseConv3x3(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr, DepthwiseConv3x3* result);
|
||||
friend DepthwiseConv3x3 CreateDepthwiseConv3x3(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr);
|
||||
|
||||
template <DataType S, typename T>
|
||||
void RearrangeWeightsAndBiasesData(
|
||||
@ -77,10 +76,9 @@ class DepthwiseConv3x3 : public GPUOperation {
|
||||
};
|
||||
|
||||
template <DataType T>
|
||||
absl::Status DepthwiseConv3x3::UploadWeightsAndBiases(
|
||||
void DepthwiseConv3x3::UploadWeightsAndBiases(
|
||||
const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
const tflite::gpu::Tensor<Linear, T>& biases, bool weights_are_buffer,
|
||||
CLContext* context) {
|
||||
const tflite::gpu::Tensor<Linear, T>& biases, bool weights_are_buffer) {
|
||||
const int src_depth = DivideRoundUp(weights.shape.i, 4);
|
||||
int texture_width = 10; // 3x3 kernel + 1 bias
|
||||
int texture_height = src_depth;
|
||||
@ -115,8 +113,6 @@ absl::Status DepthwiseConv3x3::UploadWeightsAndBiases(
|
||||
args_.AddObject("weights",
|
||||
absl::make_unique<Texture2DDescriptor>(std::move(desc)));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <DataType S, typename T>
|
||||
@ -154,9 +150,9 @@ void DepthwiseConv3x3::RearrangeWeightsAndBiasesData(
|
||||
|
||||
bool IsDepthwiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr);
|
||||
|
||||
absl::Status CreateDepthwiseConv3x3(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr, DepthwiseConv3x3* result);
|
||||
DepthwiseConv3x3 CreateDepthwiseConv3x3(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const DepthwiseConvolution2DAttributes& attr);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -56,9 +56,8 @@ TEST_F(OpenCLOperationTest, DepthwiseConv3x3SimpleWeights) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
DepthwiseConv3x3 operation;
|
||||
ASSERT_OK(
|
||||
CreateDepthwiseConv3x3(creation_context_, op_def, attr, &operation));
|
||||
DepthwiseConv3x3 operation = CreateDepthwiseConv3x3(
|
||||
creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -93,9 +92,8 @@ TEST_F(OpenCLOperationTest, DepthwiseConv3x3) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
DepthwiseConv3x3 operation;
|
||||
ASSERT_OK(
|
||||
CreateDepthwiseConv3x3(creation_context_, op_def, attr, &operation));
|
||||
DepthwiseConv3x3 operation = CreateDepthwiseConv3x3(
|
||||
creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
|
@ -55,9 +55,8 @@ TEST_F(OpenCLOperationTest, DepthwiseConvSimpleWeights) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
DepthwiseConvolution operation;
|
||||
ASSERT_OK(CreateDepthwiseConvolution(creation_context_, op_def, attr,
|
||||
&operation));
|
||||
DepthwiseConvolution operation = CreateDepthwiseConvolution(
|
||||
creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -91,9 +90,8 @@ TEST_F(OpenCLOperationTest, DepthwiseConvNoMultiplier) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
DepthwiseConvolution operation;
|
||||
ASSERT_OK(CreateDepthwiseConvolution(creation_context_, op_def, attr,
|
||||
&operation));
|
||||
DepthwiseConvolution operation = CreateDepthwiseConvolution(
|
||||
creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -128,9 +126,8 @@ TEST_F(OpenCLOperationTest, DepthwiseConvMultiplier2) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
DepthwiseConvolution operation;
|
||||
ASSERT_OK(CreateDepthwiseConvolution(creation_context_, op_def, attr,
|
||||
&operation));
|
||||
DepthwiseConvolution operation = CreateDepthwiseConvolution(
|
||||
creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 4), &dst_tensor));
|
||||
EXPECT_THAT(
|
||||
|
@ -26,79 +26,59 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
absl::Status SelectDWConvolutionAdreno(
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
const CreationContext& creation_context, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
std::unique_ptr<GPUOperation> SelectDWConvolutionAdreno(
|
||||
const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const OperationDef& op_def) {
|
||||
if (IsDepthwiseConv3x3Supported(attr)) {
|
||||
DepthwiseConv3x3 dw_conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateDepthwiseConv3x3(creation_context, op_def, attr, &dw_conv));
|
||||
*ptr = absl::make_unique<DepthwiseConv3x3>(std::move(dw_conv));
|
||||
return absl::make_unique<DepthwiseConv3x3>(
|
||||
CreateDepthwiseConv3x3(device_info, op_def, attr));
|
||||
} else {
|
||||
DepthwiseConvolution dw_conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateDepthwiseConvolution(creation_context, op_def, attr, &dw_conv));
|
||||
*ptr = absl::make_unique<DepthwiseConvolution>(std::move(dw_conv));
|
||||
return absl::make_unique<DepthwiseConvolution>(
|
||||
CreateDepthwiseConvolution(device_info, op_def, attr));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SelectDWConvolutionPowerVR(
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
const CreationContext& creation_context, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
std::unique_ptr<GPUOperation> SelectDWConvolutionPowerVR(
|
||||
const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const OperationDef& op_def) {
|
||||
if (IsDepthwiseConv3x3Supported(attr)) {
|
||||
DepthwiseConv3x3 dw_conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateDepthwiseConv3x3(creation_context, op_def, attr, &dw_conv));
|
||||
*ptr = absl::make_unique<DepthwiseConv3x3>(std::move(dw_conv));
|
||||
return absl::make_unique<DepthwiseConv3x3>(
|
||||
CreateDepthwiseConv3x3(device_info, op_def, attr));
|
||||
} else {
|
||||
DepthwiseConvolution dw_conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateDepthwiseConvolution(creation_context, op_def, attr, &dw_conv));
|
||||
*ptr = absl::make_unique<DepthwiseConvolution>(std::move(dw_conv));
|
||||
return absl::make_unique<DepthwiseConvolution>(
|
||||
CreateDepthwiseConvolution(device_info, op_def, attr));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SelectDWConvolutionMali(
|
||||
const DepthwiseConvolution2DAttributes& attr,
|
||||
const CreationContext& creation_context, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
std::unique_ptr<GPUOperation> SelectDWConvolutionMali(
|
||||
const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const OperationDef& op_def) {
|
||||
const auto storage_type = op_def.src_tensors[0].storage_type;
|
||||
bool buffer_type = storage_type == TensorStorageType::BUFFER ||
|
||||
storage_type == TensorStorageType::IMAGE_BUFFER;
|
||||
MaliInfo mali_info = creation_context.device->info_.mali_info;
|
||||
const MaliInfo mali_info = device_info.mali_info;
|
||||
if (IsDepthwiseConv3x3Supported(attr) && !mali_info.IsMidgard() &&
|
||||
!buffer_type && op_def.precision != CalculationsPrecision::F32) {
|
||||
DepthwiseConv3x3 dw_conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateDepthwiseConv3x3(creation_context, op_def, attr, &dw_conv));
|
||||
*ptr = absl::make_unique<DepthwiseConv3x3>(std::move(dw_conv));
|
||||
return absl::make_unique<DepthwiseConv3x3>(
|
||||
CreateDepthwiseConv3x3(device_info, op_def, attr));
|
||||
} else {
|
||||
DepthwiseConvolution dw_conv;
|
||||
RETURN_IF_ERROR(
|
||||
CreateDepthwiseConvolution(creation_context, op_def, attr, &dw_conv));
|
||||
*ptr = absl::make_unique<DepthwiseConvolution>(std::move(dw_conv));
|
||||
return absl::make_unique<DepthwiseConvolution>(
|
||||
CreateDepthwiseConvolution(device_info, op_def, attr));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
const auto& device_info = creation_context.device->info_;
|
||||
std::unique_ptr<GPUOperation> SelectDWConvolution(
|
||||
const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const OperationDef& op_def) {
|
||||
if (device_info.IsAdreno()) {
|
||||
return SelectDWConvolutionAdreno(attr, creation_context, op_def, ptr);
|
||||
return SelectDWConvolutionAdreno(attr, device_info, op_def);
|
||||
} else if (device_info.IsPowerVR()) {
|
||||
return SelectDWConvolutionPowerVR(attr, creation_context, op_def, ptr);
|
||||
return SelectDWConvolutionPowerVR(attr, device_info, op_def);
|
||||
} else if (device_info.IsMali()) {
|
||||
return SelectDWConvolutionMali(attr, creation_context, op_def, ptr);
|
||||
return SelectDWConvolutionMali(attr, device_info, op_def);
|
||||
} else {
|
||||
return SelectDWConvolutionAdreno(attr, creation_context, op_def, ptr);
|
||||
return SelectDWConvolutionAdreno(attr, device_info, op_def);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,10 +26,9 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
absl::Status SelectDWConvolution(const DepthwiseConvolution2DAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
std::unique_ptr<GPUOperation> SelectDWConvolution(
|
||||
const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const OperationDef& op_def);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -237,7 +237,9 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
case OperationType::DEPTHWISE_CONVOLUTION: {
|
||||
auto attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
|
||||
node.operation.attributes);
|
||||
return SelectDWConvolution(attr, creation_context, op_def, gpu_op);
|
||||
*gpu_op =
|
||||
SelectDWConvolution(attr, creation_context.GetDeviceInfo(), op_def);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::FULLY_CONNECTED: {
|
||||
auto attr =
|
||||
|
Loading…
x
Reference in New Issue
Block a user