Added BatchedMatMul parser for model reader.
Added OpenCL transformation to execute Batched mat mul with batch = 1. PiperOrigin-RevId: 330033938 Change-Id: Ieabbe49a063373efe4d6d1647cdd8a100db1c38c
This commit is contained in:
parent
7f2e2bb276
commit
6a00b43334
@ -111,6 +111,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:mean_stddev_normalization",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:reduce",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
|
||||
"//tensorflow/lite/delegates/gpu/cl/selectors:default_selector",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h"
|
||||
@ -165,6 +166,80 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
|
||||
return absl::UnimplementedError(absl::StrCat(
|
||||
"No support of ", node.operation.type, " with this parameters"));
|
||||
}
|
||||
case OperationType::BATCHED_MATMUL: {
|
||||
// Currently only batch = 1 is supported.
|
||||
// Matmul replaced with this sequence:
|
||||
// 1) Transpose second tensor(weights). (1x1xHxW)->(Wx1x1xH)
|
||||
// 2) Convert second tensor(weights) from 1) to Convolution weights
|
||||
// 3) Run usual convolution
|
||||
auto second_shape = inputs[1]->tensor.shape;
|
||||
auto dst_shape = outputs[0]->tensor.shape;
|
||||
if (dst_shape.b != 1) {
|
||||
return absl::UnimplementedError(
|
||||
"Currently only batch = 1 supported for BATCHED_MATMUL.");
|
||||
}
|
||||
BHWC weights_shape(second_shape.c, 1, 1, second_shape.w);
|
||||
Convolution2DAttributes attr;
|
||||
attr.strides = HW(1, 1);
|
||||
attr.dilations = HW(1, 1);
|
||||
attr.padding.appended = HW(0, 0);
|
||||
attr.padding.prepended = HW(0, 0);
|
||||
attr.bias.shape = Linear(weights_shape.b);
|
||||
attr.bias.data.resize(weights_shape.b, 0.0f);
|
||||
|
||||
TensorDescriptor transposed_desc = {op_def.src_tensors[1].data_type,
|
||||
op_def.src_tensors[1].storage_type,
|
||||
Layout::BHWC};
|
||||
transposed_desc.storage_type = SelectBestStorageType(
|
||||
device_info, weights_shape, transposed_desc.storage_type,
|
||||
transposed_desc.data_type, transposed_desc.layout);
|
||||
TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
|
||||
TensorStorageType::BUFFER, Layout::BHWC};
|
||||
gpu_subgraph->operations.clear();
|
||||
gpu_subgraph->operations.resize(3);
|
||||
auto& transpose_op = gpu_subgraph->operations[0];
|
||||
auto& converter_op = gpu_subgraph->operations[1];
|
||||
auto& conv_op = gpu_subgraph->operations[2];
|
||||
conv_op.input_ids = {static_cast<int>(inputs[0]->id), -1};
|
||||
conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
|
||||
OperationDef conv_def = op_def;
|
||||
conv_def.src_tensors[1] = weights_desc;
|
||||
ConvWeightsDescription conv_weights_desc;
|
||||
conv_op.operation = SelectConvolutionWithDynamicWeights(
|
||||
attr, weights_shape, dst_shape, device_info, conv_def, hints,
|
||||
&conv_weights_desc);
|
||||
|
||||
int aligned_output =
|
||||
AlignByN(weights_shape.b, conv_weights_desc.output_group_size * 4);
|
||||
int aligned_input = AlignByN(weights_shape.c, 4);
|
||||
gpu_subgraph->new_tensors = {{BHWC(1, 1, 1,
|
||||
aligned_output * aligned_input *
|
||||
weights_shape.h * weights_shape.w),
|
||||
weights_desc},
|
||||
{weights_shape, transposed_desc}};
|
||||
OperationDef converter_def;
|
||||
converter_def.precision = op_def.precision;
|
||||
converter_def.src_tensors.push_back(transposed_desc);
|
||||
converter_def.dst_tensors.push_back(weights_desc);
|
||||
|
||||
converter_op.input_ids = {-2};
|
||||
converter_op.output_ids = {-1};
|
||||
converter_op.operation =
|
||||
SelectConverterToConvWeights(conv_weights_desc, converter_def, hints);
|
||||
|
||||
OperationDef transpose_def;
|
||||
transpose_def.precision = op_def.precision;
|
||||
transpose_def.src_tensors.push_back(op_def.src_tensors[1]);
|
||||
transpose_def.dst_tensors.push_back(transposed_desc);
|
||||
|
||||
transpose_op.input_ids = {static_cast<int>(inputs[1]->id)};
|
||||
transpose_op.output_ids = {-2};
|
||||
TransposeAttributes transpose_attr;
|
||||
transpose_attr.perm = BHWC(3, 0, 1, 2);
|
||||
transpose_op.operation = absl::make_unique<GPUOperation>(
|
||||
CreateTranspose(transpose_def, transpose_attr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::CONCAT: {
|
||||
auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
|
||||
std::vector<int> channels(inputs.size());
|
||||
|
@ -299,6 +299,27 @@ class AddOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
};
|
||||
|
||||
class BatchedMatMulOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
return CheckInputsOutputs(context, tflite_node,
|
||||
/*runtime_inputs=*/2, /*outputs=*/1);
|
||||
}
|
||||
|
||||
absl::Status Parse(const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration,
|
||||
GraphFloat32* graph, ObjectReader* reader) final {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::BATCHED_MATMUL);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class ConcatenationOperationParser : public TFLiteOperationParser {
|
||||
public:
|
||||
absl::Status IsSupported(const TfLiteContext* context,
|
||||
@ -2563,6 +2584,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
|
||||
return std::make_unique<AddOperationParser>();
|
||||
case kTfLiteBuiltinAveragePool2d:
|
||||
return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
|
||||
case kTfLiteBuiltinBatchMatmul:
|
||||
return std::make_unique<BatchedMatMulOperationParser>();
|
||||
case kTfLiteBuiltinConcatenation:
|
||||
return std::make_unique<ConcatenationOperationParser>();
|
||||
case kTfLiteBuiltinConv2d:
|
||||
|
Loading…
x
Reference in New Issue
Block a user