Added BatchedMatMul parser for model reader.

Added OpenCL transformation to execute Batched mat mul with batch = 1.

PiperOrigin-RevId: 330033938
Change-Id: Ieabbe49a063373efe4d6d1647cdd8a100db1c38c
This commit is contained in:
Raman Sarokin 2020-09-03 18:42:52 -07:00 committed by TensorFlower Gardener
parent 7f2e2bb276
commit 6a00b43334
3 changed files with 99 additions and 0 deletions

View File

@ -111,6 +111,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
"//tensorflow/lite/delegates/gpu/cl/kernels:mean_stddev_normalization",
"//tensorflow/lite/delegates/gpu/cl/kernels:reduce",
"//tensorflow/lite/delegates/gpu/cl/kernels:transpose",
"//tensorflow/lite/delegates/gpu/cl/selectors:default_selector",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:model",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/elementwise.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/mean_stddev_normalization.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/reduce.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/transpose.h"
#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.h"
#include "tensorflow/lite/delegates/gpu/cl/selectors/convolution_transposed_selector.h"
#include "tensorflow/lite/delegates/gpu/cl/selectors/default_selector.h"
@ -165,6 +166,80 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info,
return absl::UnimplementedError(absl::StrCat(
"No support of ", node.operation.type, " with this parameters"));
}
case OperationType::BATCHED_MATMUL: {
// Currently only batch = 1 is supported.
// Matmul replaced with this sequence:
// 1) Transpose second tensor(weights). (1x1xHxW)->(Wx1x1xH)
// 2) Convert second tensor(weights) from 1) to Convolution weights
// 3) Run usual convolution
auto second_shape = inputs[1]->tensor.shape;
auto dst_shape = outputs[0]->tensor.shape;
if (dst_shape.b != 1) {
return absl::UnimplementedError(
"Currently only batch = 1 supported for BATCHED_MATMUL.");
}
BHWC weights_shape(second_shape.c, 1, 1, second_shape.w);
Convolution2DAttributes attr;
attr.strides = HW(1, 1);
attr.dilations = HW(1, 1);
attr.padding.appended = HW(0, 0);
attr.padding.prepended = HW(0, 0);
attr.bias.shape = Linear(weights_shape.b);
attr.bias.data.resize(weights_shape.b, 0.0f);
TensorDescriptor transposed_desc = {op_def.src_tensors[1].data_type,
op_def.src_tensors[1].storage_type,
Layout::BHWC};
transposed_desc.storage_type = SelectBestStorageType(
device_info, weights_shape, transposed_desc.storage_type,
transposed_desc.data_type, transposed_desc.layout);
TensorDescriptor weights_desc = {op_def.src_tensors[1].data_type,
TensorStorageType::BUFFER, Layout::BHWC};
gpu_subgraph->operations.clear();
gpu_subgraph->operations.resize(3);
auto& transpose_op = gpu_subgraph->operations[0];
auto& converter_op = gpu_subgraph->operations[1];
auto& conv_op = gpu_subgraph->operations[2];
conv_op.input_ids = {static_cast<int>(inputs[0]->id), -1};
conv_op.output_ids = {static_cast<int>(outputs[0]->id)};
OperationDef conv_def = op_def;
conv_def.src_tensors[1] = weights_desc;
ConvWeightsDescription conv_weights_desc;
conv_op.operation = SelectConvolutionWithDynamicWeights(
attr, weights_shape, dst_shape, device_info, conv_def, hints,
&conv_weights_desc);
int aligned_output =
AlignByN(weights_shape.b, conv_weights_desc.output_group_size * 4);
int aligned_input = AlignByN(weights_shape.c, 4);
gpu_subgraph->new_tensors = {{BHWC(1, 1, 1,
aligned_output * aligned_input *
weights_shape.h * weights_shape.w),
weights_desc},
{weights_shape, transposed_desc}};
OperationDef converter_def;
converter_def.precision = op_def.precision;
converter_def.src_tensors.push_back(transposed_desc);
converter_def.dst_tensors.push_back(weights_desc);
converter_op.input_ids = {-2};
converter_op.output_ids = {-1};
converter_op.operation =
SelectConverterToConvWeights(conv_weights_desc, converter_def, hints);
OperationDef transpose_def;
transpose_def.precision = op_def.precision;
transpose_def.src_tensors.push_back(op_def.src_tensors[1]);
transpose_def.dst_tensors.push_back(transposed_desc);
transpose_op.input_ids = {static_cast<int>(inputs[1]->id)};
transpose_op.output_ids = {-2};
TransposeAttributes transpose_attr;
transpose_attr.perm = BHWC(3, 0, 1, 2);
transpose_op.operation = absl::make_unique<GPUOperation>(
CreateTranspose(transpose_def, transpose_attr));
return absl::OkStatus();
}
case OperationType::CONCAT: {
auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
std::vector<int> channels(inputs.size());

View File

@ -299,6 +299,27 @@ class AddOperationParser : public TFLiteOperationParser {
}
};
class BatchedMatMulOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1);
}
absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::BATCHED_MATMUL);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddInput(node, 1));
RETURN_IF_ERROR(reader->AddOutputs(node));
return absl::OkStatus();
}
};
class ConcatenationOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
@ -2563,6 +2584,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return std::make_unique<AddOperationParser>();
case kTfLiteBuiltinAveragePool2d:
return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
case kTfLiteBuiltinBatchMatmul:
return std::make_unique<BatchedMatMulOperationParser>();
case kTfLiteBuiltinConcatenation:
return std::make_unique<ConcatenationOperationParser>();
case kTfLiteBuiltinConv2d: