Added support for dynamic weights to ConvPowerVR.
PiperOrigin-RevId: 309335643 Change-Id: I7f2b0536eab2ed123e25b30de21d937a9a38204b
This commit is contained in:
parent
afa06187b9
commit
896a2700d1
@ -222,6 +222,7 @@ cc_library(
|
|||||||
srcs = ["conv_powervr.cc"],
|
srcs = ["conv_powervr.cc"],
|
||||||
hdrs = ["conv_powervr.h"],
|
hdrs = ["conv_powervr.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":conv_common",
|
||||||
":gpu_operation",
|
":gpu_operation",
|
||||||
":util",
|
":util",
|
||||||
":work_group_picking",
|
":work_group_picking",
|
||||||
|
@ -138,6 +138,18 @@ ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
|||||||
attr.dilations.w, attr.dilations.h),
|
attr.dilations.w, attr.dilations.h),
|
||||||
conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
|
conv_params_(GuessBestParams(device, definition, attr, dst_shape)) {}
|
||||||
|
|
||||||
|
// Constructs the operation for a convolution whose kernel extents are taken
// from `weights_shape` (its w and h fields).
//   stride_padding_  - strides (w, h) followed by the negated prepended
//                      padding (w, h).
//   kernel_dilation_ - kernel size (w, h) from `weights_shape` followed by the
//                      dilations (w, h) from `attr`.
//   conv_params_     - result of the `weights_shape` overload of
//                      GuessBestParams; `dst_shape` is only forwarded there.
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||||
|
const Convolution2DAttributes& attr,
|
||||||
|
const BHWC& weights_shape, const CLDevice& device,
|
||||||
|
const BHWC* dst_shape)
|
||||||
|
: GPUOperation(definition),
|
||||||
|
stride_padding_(attr.strides.w, attr.strides.h, -attr.padding.prepended.w,
|
||||||
|
-attr.padding.prepended.h),
|
||||||
|
kernel_dilation_(weights_shape.w, weights_shape.h, attr.dilations.w,
|
||||||
|
attr.dilations.h),
|
||||||
|
conv_params_(GuessBestParams(device, definition, attr, weights_shape,
|
||||||
|
dst_shape)) {}
|
||||||
|
|
||||||
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
ConvPowerVR::ConvPowerVR(const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr,
|
const FullyConnectedAttributes& attr,
|
||||||
const CLDevice& device, const BHWC* dst_shape)
|
const CLDevice& device, const BHWC* dst_shape)
|
||||||
@ -192,7 +204,11 @@ absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) {
|
|||||||
absl::Status ConvPowerVR::BindArguments() {
|
absl::Status ConvPowerVR::BindArguments() {
|
||||||
kernel_.ResetBindingCounter();
|
kernel_.ResetBindingCounter();
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
if (definition_.src_tensors.size() == 1) {
|
||||||
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||||
|
} else {
|
||||||
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||||
|
}
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||||
@ -821,6 +837,22 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
|||||||
x_kernel_is_1, y_kernel_is_1, false, dst_shape);
|
x_kernel_is_1, y_kernel_is_1, false, dst_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Selects convolution parameters when the weights shape is supplied
// separately from the attributes.
//
// The destination depth is derived from `weights_shape.b` and the source
// depth from `weights_shape.c`, each divided by 4 and rounded up. An axis is
// flagged as "kernel is 1" only when the kernel extent, stride and dilation
// on that axis are all 1 and there is no prepended or appended padding on it.
// The actual choice is delegated to the depth-based GuessBestParams overload.
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
    const CLDevice& device, const OperationDef& definition,
    const Convolution2DAttributes& attr, const BHWC& weights_shape,
    const BHWC* dst_shape) const {
  // Padding checks per axis, kept separate for readability.
  const bool width_unpadded =
      attr.padding.prepended.w == 0 && attr.padding.appended.w == 0;
  const bool height_unpadded =
      attr.padding.prepended.h == 0 && attr.padding.appended.h == 0;

  const bool kernel_w_is_trivial = weights_shape.w == 1 &&
                                   attr.strides.w == 1 &&
                                   attr.dilations.w == 1 && width_unpadded;
  const bool kernel_h_is_trivial = weights_shape.h == 1 &&
                                   attr.strides.h == 1 &&
                                   attr.dilations.h == 1 && height_unpadded;

  const int input_depth = DivideRoundUp(weights_shape.c, 4);
  const int output_depth = DivideRoundUp(weights_shape.b, 4);

  return GuessBestParams(device, definition, input_depth, output_depth,
                         kernel_w_is_trivial, kernel_h_is_trivial, false,
                         dst_shape);
}
|
||||||
|
|
||||||
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
|
||||||
const CLDevice& device, const OperationDef& definition,
|
const CLDevice& device, const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr, const BHWC* dst_shape) const {
|
const FullyConnectedAttributes& attr, const BHWC* dst_shape) const {
|
||||||
@ -863,6 +895,20 @@ absl::Status CreateConvPowerVR(const CreationContext& creation_context,
|
|||||||
return result->UploadData(attr.weights, attr.bias, creation_context.context);
|
return result->UploadData(attr.weights, attr.bias, creation_context.context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a ConvPowerVR into `*result` for a convolution whose kernel
// dimensions come from `weights_shape`, then creates the bias storage.
//
// Only the biases are created here: a BUFFER linear storage filled from
// `attr.bias`, using the weights data type chosen in `conv_params_` and an
// aligned size of `weights_shape.b`. Returns the status of that creation.
absl::Status CreateConvPowerVRDynamicWeights(
    const CreationContext& creation_context, const OperationDef& definition,
    const Convolution2DAttributes& attr, const BHWC& weights_shape,
    ConvPowerVR* result, const BHWC* dst_shape) {
  const CLDevice& device = *creation_context.device;
  *result = ConvPowerVR(definition, attr, weights_shape, device, dst_shape);

  // Must run after the assignment above: the data type is read from the
  // freshly constructed operation.
  LinearStorageCreateInfo bias_info;
  bias_info.aligned_size = weights_shape.b;
  bias_info.data_type = result->conv_params_.weights_data_type;
  bias_info.storage_type = LinearStorageType::BUFFER;

  return CreateLinearStorage(bias_info, attr.bias, creation_context.context,
                             &result->biases_);
}
|
||||||
|
|
||||||
absl::Status CreateConvPowerVRWino4x4To6x6(
|
absl::Status CreateConvPowerVRWino4x4To6x6(
|
||||||
const CreationContext& creation_context, const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
|
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
||||||
@ -44,6 +45,13 @@ class ConvPowerVR : public GPUOperation {
|
|||||||
absl::Status Tune(const TuningParameters& params) override;
|
absl::Status Tune(const TuningParameters& params) override;
|
||||||
absl::Status Compile(const CreationContext& creation_context) override;
|
absl::Status Compile(const CreationContext& creation_context) override;
|
||||||
|
|
||||||
|
// Reports the weights format of this operation: layout kOHWIOGroupI4O4 with
// an output group size equal to conv_params_.block_size.z.
// NOTE(review): presumably consumed by code that prepares runtime-supplied
// weights - confirm against callers.
ConvWeightsDescription GetConvWeightsDescription() const {
  ConvWeightsDescription weights_desc;
  weights_desc.output_group_size = conv_params_.block_size.z;
  weights_desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
  return weights_desc;
}
|
||||||
|
|
||||||
// Move only
|
// Move only
|
||||||
ConvPowerVR(ConvPowerVR&& operation);
|
ConvPowerVR(ConvPowerVR&& operation);
|
||||||
ConvPowerVR& operator=(ConvPowerVR&& operation);
|
ConvPowerVR& operator=(ConvPowerVR&& operation);
|
||||||
@ -82,6 +90,9 @@ class ConvPowerVR : public GPUOperation {
|
|||||||
ConvPowerVR(const OperationDef& definition,
|
ConvPowerVR(const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr, const CLDevice& device,
|
const Convolution2DAttributes& attr, const CLDevice& device,
|
||||||
const BHWC* dst_shape = nullptr);
|
const BHWC* dst_shape = nullptr);
|
||||||
|
// Constructor for the case where the weights shape is passed explicitly via
// `weights_shape`. `dst_shape` is optional and defaults to nullptr.
ConvPowerVR(const OperationDef& definition,
|
||||||
|
const Convolution2DAttributes& attr, const BHWC& weights_shape,
|
||||||
|
const CLDevice& device, const BHWC* dst_shape = nullptr);
|
||||||
ConvPowerVR(const OperationDef& definition,
|
ConvPowerVR(const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr, const CLDevice& device,
|
const FullyConnectedAttributes& attr, const CLDevice& device,
|
||||||
const BHWC* dst_shape = nullptr);
|
const BHWC* dst_shape = nullptr);
|
||||||
@ -112,6 +123,11 @@ class ConvPowerVR : public GPUOperation {
|
|||||||
ConvPowerVR* result,
|
ConvPowerVR* result,
|
||||||
const BHWC* dst_shape);
|
const BHWC* dst_shape);
|
||||||
|
|
||||||
|
// Factory for the dynamic-weights variant; it reads `conv_params_` and fills
// `biases_` on the object it creates.
friend absl::Status CreateConvPowerVRDynamicWeights(
|
||||||
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
|
const Convolution2DAttributes& attr, const BHWC& weights_shape,
|
||||||
|
ConvPowerVR* result, const BHWC* dst_shape);
|
||||||
|
|
||||||
friend absl::Status CreateConvPowerVRWino4x4To6x6(
|
friend absl::Status CreateConvPowerVRWino4x4To6x6(
|
||||||
const CreationContext& creation_context, const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||||
@ -126,6 +142,11 @@ class ConvPowerVR : public GPUOperation {
|
|||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr,
|
const Convolution2DAttributes& attr,
|
||||||
const BHWC* dst_shape = nullptr) const;
|
const BHWC* dst_shape = nullptr) const;
|
||||||
|
// Overload used when the weights shape is passed separately: depths are
// derived from `weights_shape.b` / `weights_shape.c` and the kernel extents
// from `weights_shape.w` / `weights_shape.h`. `dst_shape` is optional.
ConvParams GuessBestParams(const CLDevice& device,
|
||||||
|
const OperationDef& definition,
|
||||||
|
const Convolution2DAttributes& attr,
|
||||||
|
const BHWC& weights_shape,
|
||||||
|
const BHWC* dst_shape = nullptr) const;
|
||||||
ConvParams GuessBestParams(const CLDevice& device,
|
ConvParams GuessBestParams(const CLDevice& device,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr,
|
const FullyConnectedAttributes& attr,
|
||||||
@ -225,6 +246,11 @@ absl::Status CreateConvPowerVR(const CreationContext& creation_context,
|
|||||||
ConvPowerVR* result,
|
ConvPowerVR* result,
|
||||||
const BHWC* dst_shape = nullptr);
|
const BHWC* dst_shape = nullptr);
|
||||||
|
|
||||||
|
// Creates a ConvPowerVR in `*result` whose kernel dimensions come from
// `weights_shape`. Only the bias storage is created (from `attr.bias`); no
// weights are uploaded here. When the operation definition has more than one
// source tensor, BindArguments binds src_[1] in place of the stored weights.
// Returns the status of the bias storage creation. `dst_shape` is optional.
absl::Status CreateConvPowerVRDynamicWeights(
|
||||||
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
|
const Convolution2DAttributes& attr, const BHWC& weights_shape,
|
||||||
|
ConvPowerVR* result, const BHWC* dst_shape = nullptr);
|
||||||
|
|
||||||
absl::Status CreateConvPowerVRWino4x4To6x6(
|
absl::Status CreateConvPowerVRWino4x4To6x6(
|
||||||
const CreationContext& creation_context, const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
const Convolution2DAttributes& attr, ConvPowerVR* result,
|
||||||
|
Loading…
Reference in New Issue
Block a user