Added support for dynamic weights to ConvPowerVR.
PiperOrigin-RevId: 309335643 Change-Id: I7f2b0536eab2ed123e25b30de21d937a9a38204b
This commit is contained in:
parent
afa06187b9
commit
896a2700d1
@ -222,6 +222,7 @@ cc_library(
|
||||
srcs = ["conv_powervr.cc"],
|
||||
hdrs = ["conv_powervr.h"],
|
||||
deps = [
|
||||
":conv_common",
|
||||
":gpu_operation",
|
||||
":util",
|
||||
":work_group_picking",
|
||||
|
@ -138,6 +138,18 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
attr.dilations.w, attr.dilations.h),
|
||||
conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
|
||||
|
||||
// Constructs a ConvPowerVR when only the shape of the weights is known
// (the "dynamic weights" case from the commit description).
//
// definition     - operation definition forwarded to the GPUOperation base.
// attr           - supplies strides, prepended padding and dilations.
// weights_shape  - supplies the kernel width/height (w, h) used below.
// device         - forwarded to GuessBestParams.
// dst_shape      - optional; forwarded unchanged to GuessBestParams.
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC& weights_shape, const CLDevice& device,
|
||||
const BHWC* dst_shape)
|
||||
: GPUOperation(definition),
|
||||
// Strides first, then the prepended padding with its sign flipped.
stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w,
|
||||
-attr.padding.prepended.h),
|
||||
// Kernel extent is read from weights_shape; dilations still come from attr.
kernel_dilation_(weights_shape.w, weights_shape.h, attr.dilations.w,
|
||||
attr.dilations.h),
|
||||
// Uses the GuessBestParams overload that takes weights_shape.
conv_params_(GuessBestParams(device, definition, attr, weights_shape,
|
||||
dst_shape)) {}
|
||||
|
||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
const CLDevice& device, const BHWC* dst_shape)
|
||||
@ -192,7 +204,11 @@ absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) {
|
||||
absl::Status ConvPowerVR::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
if (definition_.src_tensors.size() == 1) {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
} else {
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||
}
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
@ -821,6 +837,22 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
x_kernel_is_1, y_kernel_is_1, false, dst_shape);
|
||||
}
|
||||
|
||||
// GuessBestParams overload for the case where the weights are described only
// by `weights_shape`.
//
// The destination depth is derived from weights_shape.b and the source depth
// from weights_shape.c, each divided by 4 and rounded up. An axis is treated
// as a "kernel is 1" axis only when the kernel extent, stride and dilation on
// that axis are all 1 and there is no prepended or appended padding on it.
// The result is produced by the depth-based GuessBestParams overload.
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
    const CLDevice& device, const OperationDef& definition,
    const Convolution2DAttributes& attr, const BHWC& weights_shape,
    const BHWC* dst_shape) const {
  const int src_slices = DivideRoundUp(weights_shape.c, 4);
  const int dst_slices = DivideRoundUp(weights_shape.b, 4);

  // Width axis.
  const bool unit_step_w = attr.strides.w == 1 && attr.dilations.w == 1;
  const bool no_padding_w =
      attr.padding.prepended.w == 0 && attr.padding.appended.w == 0;
  const bool width_is_trivial =
      weights_shape.w == 1 && unit_step_w && no_padding_w;

  // Height axis.
  const bool unit_step_h = attr.strides.h == 1 && attr.dilations.h == 1;
  const bool no_padding_h =
      attr.padding.prepended.h == 0 && attr.padding.appended.h == 0;
  const bool height_is_trivial =
      weights_shape.h == 1 && unit_step_h && no_padding_h;

  return GuessBestParams(device, definition, src_slices, dst_slices,
                         width_is_trivial, height_is_trivial, false,
                         dst_shape);
}
|
||||
|
||||
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||
const CLDevice& device, const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr, const BHWC* dst_shape) const {
|
||||
@ -863,6 +895,20 @@ absl::Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
return result->UploadData(attr.weights, attr.bias, creation_context.context);
|
||||
}
|
||||
|
||||
// Builds a ConvPowerVR into `*result` for weights that are described only by
// `weights_shape`.
//
// No weights are uploaded here; the only storage created is the bias buffer,
// filled from `attr.bias` and sized by `weights_shape.b`. The bias data type
// follows the weights data type chosen in the operation's conv params.
// Returns the status of the bias storage creation.
absl::Status CreateConvPowerVRDynamicWeights(
    const CreationContext& creation_context, const OperationDef& definition,
    const Convolution2DAttributes& attr, const BHWC& weights_shape,
    ConvPowerVR* result, const BHWC* dst_shape) {
  ConvPowerVR& op = *result;
  op = ConvPowerVR(definition, attr, weights_shape, *creation_context.device,
                   dst_shape);

  LinearStorageCreateInfo bias_info;
  bias_info.aligned_size = weights_shape.b;
  bias_info.data_type = op.conv_params_.weights_data_type;
  bias_info.storage_type = LinearStorageType::BUFFER;

  return CreateLinearStorage(bias_info, attr.bias, creation_context.context,
                             &op.biases_);
}
|
||||
|
||||
absl::Status CreateConvPowerVRWino4x4To6x6(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
||||
@ -44,6 +45,13 @@ class ConvPowerVR : public GPUOperation {
|
||||
absl::Status Tune(const TuningParameters& params) override;
|
||||
absl::Status Compile(const CreationContext& creation_context) override;
|
||||
|
||||
// Reports the weights layout this operation expects: the layout is always
// kOHWIOGroupI4O4 and the output group size equals the z component of the
// block size selected in conv_params_.
ConvWeightsDescription GetConvWeightsDescription() const {
  ConvWeightsDescription weights_desc;
  weights_desc.output_group_size = conv_params_.block_size.z;
  weights_desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
  return weights_desc;
}
|
||||
|
||||
// Move only
|
||||
ConvPowerVR(ConvPowerVR&& operation);
|
||||
ConvPowerVR& operator=(ConvPowerVR&& operation);
|
||||
@ -82,6 +90,9 @@ class ConvPowerVR : public GPUOperation {
|
||||
ConvPowerVR(const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, const CLDevice& device,
|
||||
const BHWC* dst_shape = nullptr);
|
||||
// Constructor used when the weights are described only by `weights_shape`:
// the kernel width/height are taken from it, while strides, padding and
// dilations come from `attr`. `dst_shape` is optional and is passed on to
// GuessBestParams.
ConvPowerVR(const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, const BHWC& weights_shape,
|
||||
const CLDevice& device, const BHWC* dst_shape = nullptr);
|
||||
ConvPowerVR(const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr, const CLDevice& device,
|
||||
const BHWC* dst_shape = nullptr);
|
||||
@ -112,6 +123,11 @@ class ConvPowerVR : public GPUOperation {
|
||||
ConvPowerVR* result,
|
||||
const BHWC* dst_shape);
|
||||
|
||||
// Declared as a friend because the factory's definition reads conv_params_
// and writes biases_ on the ConvPowerVR it creates.
friend absl::Status CreateConvPowerVRDynamicWeights(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, const BHWC& weights_shape,
|
||||
ConvPowerVR* result, const BHWC* dst_shape);
|
||||
|
||||
friend absl::Status CreateConvPowerVRWino4x4To6x6(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||
@ -126,6 +142,11 @@ class ConvPowerVR : public GPUOperation {
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC* dst_shape = nullptr) const;
|
||||
// Overload that takes the weights shape explicitly: source and destination
// depth are computed from weights_shape.c and weights_shape.b (divided by 4,
// rounded up), and the kernel width/height checks use weights_shape.w/h.
ConvParams GuessBestParams(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC& weights_shape,
|
||||
const BHWC* dst_shape = nullptr) const;
|
||||
ConvParams GuessBestParams(const CLDevice& device,
|
||||
const OperationDef& definition,
|
||||
const FullyConnectedAttributes& attr,
|
||||
@ -225,6 +246,11 @@ absl::Status CreateConvPowerVR(const CreationContext& creation_context,
|
||||
ConvPowerVR* result,
|
||||
const BHWC* dst_shape = nullptr);
|
||||
|
||||
// Creates a ConvPowerVR in `*result` for weights described only by
// `weights_shape`. The definition uploads no weights; it only creates the
// bias storage from `attr.bias`. `dst_shape` is optional and forwarded to the
// ConvPowerVR constructor. Returns the status of the bias storage creation.
// NOTE(review): with more than one source tensor, BindArguments binds
// src_[1] in place of the internal weights buffer - presumably that tensor
// carries the weights at run time; confirm against callers.
absl::Status CreateConvPowerVRDynamicWeights(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, const BHWC& weights_shape,
|
||||
ConvPowerVR* result, const BHWC* dst_shape = nullptr);
|
||||
|
||||
absl::Status CreateConvPowerVRWino4x4To6x6(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||
|
Loading…
Reference in New Issue
Block a user