diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
index c77ba39b9f3..520145d7d5b 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
@@ -111,6 +111,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/kernels:mean_stddev_normalization",
         "//tensorflow/lite/delegates/gpu/cl/kernels:reduce",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
         "//tensorflow/lite/delegates/gpu/cl/selectors:default_selector",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:model",
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
index 0bd78103409..9b0a169e5b4 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h"
 #include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h"
+#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h"
@@ -165,6 +166,80 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
       return absl::UnimplementedError(absl::StrCat(
           "No support of ", node.operation.type, " with this parameters"));
     }
+    case OperationType::BATCHED_MATMUL: {
+      // Currently only batch = 1 is supported.
+      // The matmul is replaced with this sequence:
+      //   1) Transpose the second tensor (weights): (1x1xHxW) -> (Wx1x1xH).
+      //   2) Convert the transposed tensor from 1) into convolution weights.
+      //   3) Run a usual 1x1 convolution with these weights.
+      auto second_shape = inputs[1]->tensor.shape;
+      auto dst_shape = outputs[0]->tensor.shape;
+      if (dst_shape.b != 1) {
+        return absl::UnimplementedError(
+            "Currently only batch = 1 is supported for BATCHED_MATMUL.");
+      }
+      BHWC weights_shape(second_shape.c, 1, 1, second_shape.w);
+      Convolution2DAttributes attr;
+      attr.strides = HW(1, 1);
+      attr.dilations = HW(1, 1);
+      attr.padding.appended = HW(0, 0);
+      attr.padding.prepended = HW(0, 0);
+      attr.bias.shape = Linear(weights_shape.b);
+      attr.bias.data.resize(weights_shape.b, 0.0f);
+
+      TensorDescriptor transposed_desc = {op_def.src_tensors[1].data_type,
+                                          op_def.src_tensors[1].storage_type,
+                                          Layout::BHWC};
+      transposed_desc.storage_type = SelectBestStorageType(
+          device_info, weights_shape, transposed_desc.storage_type,
+          transposed_desc.data_type, transposed_desc.layout);
+      TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
+                                       TensorStorageType::BUFFER, Layout::BHWC};
+      gpu_subgraph->operations.clear();
+      gpu_subgraph->operations.resize(3);
+      auto& transpose_op = gpu_subgraph->operations[0];
+      auto& converter_op = gpu_subgraph->operations[1];
+      auto& conv_op = gpu_subgraph->operations[2];
+      conv_op.input_ids = {static_cast<int>(inputs[0]->id), -1};
+      conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
+      OperationDef conv_def = op_def;
+      conv_def.src_tensors[1] = weights_desc;
+      ConvWeightsDescription conv_weights_desc;
+      conv_op.operation = SelectConvolutionWithDynamicWeights(
+          attr, weights_shape, dst_shape, device_info, conv_def, hints,
+          &conv_weights_desc);
+
+      int aligned_output =
+          AlignByN(weights_shape.b, conv_weights_desc.output_group_size * 4);
+      int aligned_input = AlignByN(weights_shape.c, 4);
+      gpu_subgraph->new_tensors = {{BHWC(1, 1, 1,
+                                         aligned_output * aligned_input *
+                                             weights_shape.h * weights_shape.w),
+                                    weights_desc},
+                                   {weights_shape, transposed_desc}};
+      OperationDef converter_def;
+      converter_def.precision = op_def.precision;
+      converter_def.src_tensors.push_back(transposed_desc);
+      converter_def.dst_tensors.push_back(weights_desc);
+
+      converter_op.input_ids = {-2};
+      converter_op.output_ids = {-1};
+      converter_op.operation =
+          SelectConverterToConvWeights(conv_weights_desc, converter_def, hints);
+
+      OperationDef transpose_def;
+      transpose_def.precision = op_def.precision;
+      transpose_def.src_tensors.push_back(op_def.src_tensors[1]);
+      transpose_def.dst_tensors.push_back(transposed_desc);
+
+      transpose_op.input_ids = {static_cast<int>(inputs[1]->id)};
+      transpose_op.output_ids = {-2};
+      TransposeAttributes transpose_attr;
+      transpose_attr.perm = BHWC(3, 0, 1, 2);
+      transpose_op.operation = absl::make_unique<GPUOperation>(
+          CreateTranspose(transpose_def, transpose_attr));
+      return absl::OkStatus();
+    }
     case OperationType::CONCAT: {
       auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
       std::vector<int> channels(inputs.size());
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index c116db1f039..bf904a03e78 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -299,6 +299,27 @@ class AddOperationParser : public TFLiteOperationParser {
   }
 };
 
+class BatchedMatMulOperationParser : public TFLiteOperationParser {
+ public:
+  absl::Status IsSupported(const TfLiteContext* context,
+                           const TfLiteNode* tflite_node,
+                           const TfLiteRegistration* registration) final {
+    return CheckInputsOutputs(context, tflite_node,
+                              /*runtime_inputs=*/2, /*outputs=*/1);
+  }
+
+  absl::Status Parse(const TfLiteNode* tflite_node,
+                     const TfLiteRegistration* registration,
+                     GraphFloat32* graph, ObjectReader* reader) final {
+    Node* node = graph->NewNode();
+    node->operation.type = ToString(OperationType::BATCHED_MATMUL);
+    RETURN_IF_ERROR(reader->AddInput(node, 0));
+    RETURN_IF_ERROR(reader->AddInput(node, 1));
+    RETURN_IF_ERROR(reader->AddOutputs(node));
+    return absl::OkStatus();
+  }
+};
+
 class ConcatenationOperationParser : public TFLiteOperationParser {
  public:
   absl::Status IsSupported(const TfLiteContext* context,
@@ -2563,6 +2584,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
       return std::make_unique<AddOperationParser>();
     case kTfLiteBuiltinAveragePool2d:
       return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
+    case kTfLiteBuiltinBatchMatmul:
+      return std::make_unique<BatchedMatMulOperationParser>();
     case kTfLiteBuiltinConcatenation:
       return std::make_unique<ConcatenationOperationParser>();
     case kTfLiteBuiltinConv2d:
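
Note on the operation_selector.cc change above: for batch = 1, a BatchMatMul of A (MxK) by B (KxN) is the same computation as a 1x1 convolution over an "image" with M pixels and K channels, where filter n holds column n of B, i.e. the transposed weights. That is the equivalence the Transpose -> weights-converter -> convolution subgraph relies on. Below is a minimal standalone sketch of that equivalence (plain C++ with made-up sizes, no delegate APIs); it is an illustration only and not part of the patch:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int M = 3, K = 4, N = 2;  // arbitrary small sizes, purely illustrative
  std::vector<float> A(M * K), B(K * N);
  for (int i = 0; i < M * K; ++i) A[i] = 0.1f * i;
  for (int i = 0; i < K * N; ++i) B[i] = 0.2f * i - 0.5f;

  // Reference matmul: C[m][n] = sum_k A[m][k] * B[k][n].
  std::vector<float> matmul(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        matmul[m * N + n] += A[m * K + k] * B[k * N + n];

  // The same computation as a 1x1 convolution: A is an "image" with M pixels
  // and K input channels; filter n takes its K weights from column n of B,
  // which is exactly the transposed weights tensor built in step 1.
  std::vector<float> filters(N * K);
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k) filters[n * K + k] = B[k * N + n];

  std::vector<float> conv(M * N, 0.0f);
  for (int m = 0; m < M; ++m)      // over pixels
    for (int n = 0; n < N; ++n)    // over output channels (filters)
      for (int k = 0; k < K; ++k)  // over input channels
        conv[m * N + n] += A[m * K + k] * filters[n * K + k];

  for (int i = 0; i < M * N; ++i)
    assert(std::fabs(matmul[i] - conv[i]) < 1e-5f);
  std::printf("1x1 convolution with transposed weights matches matmul\n");
  return 0;
}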