Added Create method that supports runtime weights for ConvolutionTransposed.
PiperOrigin-RevId: 342354367 Change-Id: I5325c8c8c5258a74b16ebb7a696aa9e0090b4fd7
This commit is contained in:
parent
134db25ce3
commit
aa56651979
@ -275,6 +275,7 @@ cc_library(
|
|||||||
srcs = ["convolution_transposed.cc"],
|
srcs = ["convolution_transposed.cc"],
|
||||||
hdrs = ["convolution_transposed.h"],
|
hdrs = ["convolution_transposed.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":conv_common",
|
||||||
":gpu_operation",
|
":gpu_operation",
|
||||||
":util",
|
":util",
|
||||||
":work_group_picking",
|
":work_group_picking",
|
||||||
|
@ -31,11 +31,10 @@ namespace cl {
|
|||||||
|
|
||||||
ConvolutionTransposed::ConvolutionTransposed(
|
ConvolutionTransposed::ConvolutionTransposed(
|
||||||
const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
|
const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
|
||||||
const GpuInfo& gpu_info)
|
const GpuInfo& gpu_info, bool weights_are_buffer)
|
||||||
: GPUOperation(definition),
|
: GPUOperation(definition),
|
||||||
stride_(attr.stride.w, attr.stride.h, 1, 1),
|
stride_(attr.stride.w, attr.stride.h, 1, 1),
|
||||||
block_size_(2, 2, 1, 2) {
|
block_size_(2, 2, 1, 2) {
|
||||||
const bool weights_are_buffer = gpu_info.IsMali();
|
|
||||||
const bool is_f16 = definition.precision == CalculationsPrecision::F16;
|
const bool is_f16 = definition.precision == CalculationsPrecision::F16;
|
||||||
if (gpu_info.IsMali()) {
|
if (gpu_info.IsMali()) {
|
||||||
if (gpu_info.mali_info.IsMidgard()) {
|
if (gpu_info.mali_info.IsMidgard()) {
|
||||||
@ -60,16 +59,15 @@ ConvolutionTransposed::ConvolutionTransposed(
|
|||||||
args_.AddInt("kernel_size_y", attr.weights.shape.h);
|
args_.AddInt("kernel_size_y", attr.weights.shape.h);
|
||||||
code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
|
code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
|
||||||
weights_are_buffer, block_size_);
|
weights_are_buffer, block_size_);
|
||||||
UploadWeights(attr.weights, weights_are_buffer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ConvolutionTransposed::ConvolutionTransposed(
|
ConvolutionTransposed::ConvolutionTransposed(
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const ConvolutionTransposed3DAttributes& attr, const GpuInfo& gpu_info)
|
const ConvolutionTransposed3DAttributes& attr, const GpuInfo& gpu_info,
|
||||||
|
bool weights_are_buffer)
|
||||||
: GPUOperation(definition),
|
: GPUOperation(definition),
|
||||||
stride_(attr.stride.w, attr.stride.h, attr.stride.d, 1),
|
stride_(attr.stride.w, attr.stride.h, attr.stride.d, 1),
|
||||||
block_size_(2, 2, 1, 2) {
|
block_size_(2, 2, 1, 2) {
|
||||||
const bool weights_are_buffer = gpu_info.IsMali();
|
|
||||||
const bool is_f16 = definition.precision == CalculationsPrecision::F16;
|
const bool is_f16 = definition.precision == CalculationsPrecision::F16;
|
||||||
if (gpu_info.IsMali()) {
|
if (gpu_info.IsMali()) {
|
||||||
if (gpu_info.mali_info.IsMidgard()) {
|
if (gpu_info.mali_info.IsMidgard()) {
|
||||||
@ -98,7 +96,6 @@ ConvolutionTransposed::ConvolutionTransposed(
|
|||||||
args_.AddInt("grid_size_y");
|
args_.AddInt("grid_size_y");
|
||||||
code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
|
code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
|
||||||
weights_are_buffer, block_size_);
|
weights_are_buffer, block_size_);
|
||||||
UploadWeights(attr.weights, weights_are_buffer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ConvolutionTransposed::ConvolutionTransposed(ConvolutionTransposed&& operation)
|
ConvolutionTransposed::ConvolutionTransposed(ConvolutionTransposed&& operation)
|
||||||
@ -124,6 +121,15 @@ std::string ConvolutionTransposed::GenerateConvolutionTransposedCode(
|
|||||||
AddSrcTensor("src_tensor", src_desc);
|
AddSrcTensor("src_tensor", src_desc);
|
||||||
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
|
AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
|
||||||
|
|
||||||
|
if (op_def.src_tensors.size() == 2) {
|
||||||
|
// dynamic weights
|
||||||
|
BufferDescriptor desc;
|
||||||
|
desc.element_type = op_def.src_tensors[1].data_type;
|
||||||
|
desc.element_size = 16;
|
||||||
|
desc.memory_type = MemoryType::GLOBAL;
|
||||||
|
AddSrcBuffer("weights", desc);
|
||||||
|
}
|
||||||
|
|
||||||
const auto& src_def = op_def.src_tensors[0];
|
const auto& src_def = op_def.src_tensors[0];
|
||||||
|
|
||||||
std::string c = GetCommonDefines(op_def.precision);
|
std::string c = GetCommonDefines(op_def.precision);
|
||||||
@ -544,7 +550,9 @@ void ConvolutionTransposed::GetPossibleKernelWorkGroups(
|
|||||||
ConvolutionTransposed CreateConvolutionTransposed(
|
ConvolutionTransposed CreateConvolutionTransposed(
|
||||||
const GpuInfo& gpu_info, const OperationDef& definition,
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
const ConvolutionTransposedAttributes& attr) {
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
ConvolutionTransposed result(definition, attr, gpu_info);
|
const bool weights_are_buffer = gpu_info.IsMali();
|
||||||
|
ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
|
||||||
|
result.UploadWeights(attr.weights, weights_are_buffer);
|
||||||
|
|
||||||
TensorLinearDescriptor desc;
|
TensorLinearDescriptor desc;
|
||||||
desc.storage_type =
|
desc.storage_type =
|
||||||
@ -559,7 +567,25 @@ ConvolutionTransposed CreateConvolutionTransposed(
|
|||||||
ConvolutionTransposed CreateConvolutionTransposed3D(
|
ConvolutionTransposed CreateConvolutionTransposed3D(
|
||||||
const GpuInfo& gpu_info, const OperationDef& definition,
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
const ConvolutionTransposed3DAttributes& attr) {
|
const ConvolutionTransposed3DAttributes& attr) {
|
||||||
ConvolutionTransposed result(definition, attr, gpu_info);
|
const bool weights_are_buffer = gpu_info.IsMali();
|
||||||
|
ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
|
||||||
|
result.UploadWeights(attr.weights, weights_are_buffer);
|
||||||
|
|
||||||
|
TensorLinearDescriptor desc;
|
||||||
|
desc.storage_type =
|
||||||
|
DeduceLinearStorageType(definition.GetPrimaryStorageType());
|
||||||
|
desc.element_type = definition.GetDataType();
|
||||||
|
desc.UploadLinearData(attr.bias);
|
||||||
|
result.args_.AddObject(
|
||||||
|
"biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConvolutionTransposed CreateConvolutionTransposedDynamicWeights(
|
||||||
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
|
const ConvolutionTransposedAttributes& attr) {
|
||||||
|
const bool weights_are_buffer = true;
|
||||||
|
ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
|
||||||
|
|
||||||
TensorLinearDescriptor desc;
|
TensorLinearDescriptor desc;
|
||||||
desc.storage_type =
|
desc.storage_type =
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_common.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
|
||||||
@ -53,6 +54,13 @@ class ConvolutionTransposed : public GPUOperation {
|
|||||||
ConvolutionTransposed(const ConvolutionTransposed&) = delete;
|
ConvolutionTransposed(const ConvolutionTransposed&) = delete;
|
||||||
ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete;
|
ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete;
|
||||||
|
|
||||||
|
// Returns a ConvWeightsDescription for this operation: the layout is set to
// ConvWeightsLayout::kOHWIOGroupI4O4 and output_group_size is copied from
// block_size_.w. Does not modify the operation (const member).
// NOTE(review): presumably intended for callers that supply weights at
// runtime (cf. CreateConvolutionTransposedDynamicWeights) so they can
// rearrange weights to match what the generated kernel reads — confirm
// against the call sites, which are not visible here.
ConvWeightsDescription GetConvWeightsDescription() const {
|
||||||
|
ConvWeightsDescription desc;
|
||||||
|
desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
|
||||||
|
desc.output_group_size = block_size_.w;
|
||||||
|
return desc;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend ConvolutionTransposed CreateConvolutionTransposed(
|
friend ConvolutionTransposed CreateConvolutionTransposed(
|
||||||
const GpuInfo& gpu_info, const OperationDef& definition,
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
@ -60,12 +68,16 @@ class ConvolutionTransposed : public GPUOperation {
|
|||||||
friend ConvolutionTransposed CreateConvolutionTransposed3D(
|
friend ConvolutionTransposed CreateConvolutionTransposed3D(
|
||||||
const GpuInfo& gpu_info, const OperationDef& definition,
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
const ConvolutionTransposed3DAttributes& attr);
|
const ConvolutionTransposed3DAttributes& attr);
|
||||||
|
friend ConvolutionTransposed CreateConvolutionTransposedDynamicWeights(
|
||||||
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
|
const ConvolutionTransposedAttributes& attr);
|
||||||
|
|
||||||
ConvolutionTransposed(const OperationDef& definition,
|
ConvolutionTransposed(const OperationDef& definition,
|
||||||
const ConvolutionTransposedAttributes& attr,
|
const ConvolutionTransposedAttributes& attr,
|
||||||
const GpuInfo& gpu_info);
|
const GpuInfo& gpu_info, bool weights_are_buffer);
|
||||||
ConvolutionTransposed(const OperationDef& definition,
|
ConvolutionTransposed(const OperationDef& definition,
|
||||||
const ConvolutionTransposed3DAttributes& attr,
|
const ConvolutionTransposed3DAttributes& attr,
|
||||||
const GpuInfo& gpu_info);
|
const GpuInfo& gpu_info, bool weights_are_buffer);
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
|
void UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||||
@ -213,6 +225,10 @@ ConvolutionTransposed CreateConvolutionTransposed3D(
|
|||||||
const GpuInfo& gpu_info, const OperationDef& definition,
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
const ConvolutionTransposed3DAttributes& attr);
|
const ConvolutionTransposed3DAttributes& attr);
|
||||||
|
|
||||||
|
ConvolutionTransposed CreateConvolutionTransposedDynamicWeights(
|
||||||
|
const GpuInfo& gpu_info, const OperationDef& definition,
|
||||||
|
const ConvolutionTransposedAttributes& attr);
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
x
Reference in New Issue
Block a user