Add BatchMatMul built-in op for TF Lite
PiperOrigin-RevId: 302489633
Change-Id: Ie4ad2abad069b1e5bc654fc51caf0bcbc99b714f
Parent: c6ec2565db
Commit: 828fe43cf3
@@ -152,6 +152,7 @@ typedef enum {
   kTfLiteBuiltinSelectV2 = 123,
   kTfLiteBuiltinDensify = 124,
   kTfLiteBuiltinSegmentSum = 125,
+  kTfLiteBuiltinBatchMatmul = 126,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
@@ -840,6 +840,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_SCATTER_ND:
     case BuiltinOperator_DENSIFY:
     case BuiltinOperator_SEGMENT_SUM:
+    case BuiltinOperator_BATCH_MATMUL:
       break;
   }
   return kTfLiteOk;
@@ -426,6 +426,7 @@ cc_library(
         "arg_min_max.cc",
        "audio_spectrogram.cc",
         "basic_rnn.cc",
+        "batch_matmul.cc",
         "batch_to_space_nd.cc",
         "bidirectional_sequence_lstm.cc",
         "bidirectional_sequence_rnn.cc",
@@ -849,6 +850,19 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "batch_matmul_test",
+    size = "small",
+    srcs = ["batch_matmul_test.cc"],
+    deps = [
+        ":builtin_ops",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite:framework",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "cast_test",
     size = "small",
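Assuming the BUILD file above is tensorflow/lite/kernels/BUILD (the file path is not shown in this diff), the new test target should be runnable with something like:

    bazel test //tensorflow/lite/kernels:batch_matmul_test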
@@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace batch_matmul {
+
+static const int kInputLHSTensor = 0;
+static const int kInputRHSTensor = 1;
+static const int kOutputTensor = 0;
+
+// This file has two implementations of BatchMatMul.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+                                const RuntimeShape& extended_lhs_shape,
+                                const RuntimeShape& extended_rhs_shape,
+                                int output_rank, TfLiteTensor* output) {
+  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+  // Fill in any broadcast dimensions.
+  for (int i = 0; i < output_rank - 2; ++i) {
+    const int lhs_dim = extended_lhs_shape.Dims(i);
+    const int rhs_dim = extended_rhs_shape.Dims(i);
+    int broadcast_dim = lhs_dim;
+    if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
+      broadcast_dim = rhs_dim;
+    }
+    output_shape->data[i] = broadcast_dim;
+  }
+  // Fill in the matmul dimensions.
+  output_shape->data[output_rank - 2] =
+      extended_lhs_shape.Dims(output_rank - 2);
+  output_shape->data[output_rank - 1] =
+      extended_rhs_shape.Dims(output_rank - 1);
+  TfLiteStatus stat = context->ResizeTensor(context, output, output_shape);
+  return stat;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
+  const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, lhs_data->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, rhs_data->type, kTfLiteFloat32);
+  // Support dimensions between 2 and 5, inclusive.
+  TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
+  TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
+  TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2);
+  TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5);
+
+  const int lhs_rank = NumDimensions(lhs_data);
+  const int rhs_rank = NumDimensions(rhs_data);
+  const int output_rank = std::max(lhs_rank, rhs_rank);
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data));
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data));
+
+  // Ensure any batch dimensions obey broadcasting rules.
+  for (int i = 0; i < output_rank - 2; ++i) {
+    const int lhs_dim = extended_lhs_shape.Dims(i);
+    const int rhs_dim = extended_rhs_shape.Dims(i);
+    if (lhs_dim != rhs_dim) {
+      if (lhs_dim != 1) {
+        TF_LITE_ENSURE_EQ(context, rhs_dim, 1);
+      }
+    }
+  }
+  // Ensure other dimensions work for matrix multiplication.
+  TF_LITE_ENSURE_EQ(context, extended_lhs_shape.Dims(output_rank - 1),
+                    extended_rhs_shape.Dims(output_rank - 2));
+  return ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape,
+                            output_rank, output);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
+  const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  switch (lhs->type) {
+    case kTfLiteFloat32:
+      if (kernel_type == kGenericOptimized) {
+        optimized_ops::BatchMatMul(
+            GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
+            GetTensorData<float>(rhs), GetTensorShape(output),
+            GetTensorData<float>(output),
+            CpuBackendContext::GetFromContext(context));
+      } else {
+        reference_ops::BatchMatMul(
+            GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
+            GetTensorData<float>(rhs), GetTensorShape(output),
+            GetTensorData<float>(output));
+      }
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context,
+                         "Currently BatchMatMul doesn't support type: %s",
+                         TfLiteTypeGetName(lhs->type));
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace batch_matmul
+
+TfLiteRegistration* Register_BATCH_MATMUL_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, batch_matmul::Prepare,
+                                 batch_matmul::Eval<batch_matmul::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, batch_matmul::Prepare,
+      batch_matmul::Eval<batch_matmul::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_BATCH_MATMUL() {
+  return Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
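The broadcast handling in Prepare() above reduces to a small amount of shape arithmetic. A minimal standalone sketch (the function name and vector-based shapes are illustrative, not part of the commit):

#include <algorithm>
#include <cassert>
#include <vector>

// Mirrors Prepare()/ResizeOutputTensor(): left-pad both shapes to a common
// rank, broadcast the batch dims, take rows from the LHS, cols from the RHS.
std::vector<int> BatchMatMulOutputShape(std::vector<int> lhs,
                                        std::vector<int> rhs) {
  const size_t rank = std::max(lhs.size(), rhs.size());
  lhs.insert(lhs.begin(), rank - lhs.size(), 1);  // ExtendedShape analogue.
  rhs.insert(rhs.begin(), rank - rhs.size(), 1);
  std::vector<int> out(rank);
  for (size_t i = 0; i + 2 < rank; ++i) {  // Batch dimensions broadcast.
    assert(lhs[i] == rhs[i] || lhs[i] == 1 || rhs[i] == 1);
    out[i] = lhs[i] == 1 ? rhs[i] : lhs[i];
  }
  out[rank - 2] = lhs[rank - 2];  // Output rows come from the LHS.
  out[rank - 1] = rhs[rank - 1];  // Output cols come from the RHS.
  return out;
}
// e.g. BatchMatMulOutputShape({2, 1, 3, 2}, {3, 2, 4}) == {2, 3, 3, 4},
// the shape checked by Float32Test_Broadcast2 in the test file below.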
@@ -0,0 +1,169 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class BatchMatMulOpModel : public SingleOpModel {
+ public:
+  BatchMatMulOpModel(const TensorData& lhs, const TensorData& rhs) {
+    lhs_id_ = AddInput(lhs);
+    rhs_id_ = AddInput(rhs);
+    output_id_ = AddOutput(lhs.type);
+    SetBuiltinOp(BuiltinOperator_BATCH_MATMUL, BuiltinOptions_NONE, 0);
+    BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
+  }
+
+  int lhs() const { return lhs_id_; }
+  int rhs() const { return rhs_id_; }
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
+  std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
+
+ protected:
+  int lhs_id_;
+  int rhs_id_;
+  int output_id_;
+};
+
+TEST(BatchMatMulOpModelTest, Float32Test_Simple) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}},
+                                  {TensorType_FLOAT32, {1, 3, 4}});
+  model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
+                                  {TensorType_FLOAT32, {2, 3, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f, 482.0f, 662.0f, 554.0f, 761.0f,
+                                626.0f, 860.0f, 698.0f, 959.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
+                                  {TensorType_FLOAT32, {3, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
+
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f, 194.0f, 266.0f, 266.0f, 365.0f,
+                                338.0f, 464.0f, 410.0f, 563.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}},
+                                  {TensorType_FLOAT32, {3, 2, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {23.0f, 53.0f, 83.0f, 29.0f, 67.0f, 105.0f, 35.0f, 81.0f,
+           127.0f, 41.0f, 95.0f, 149.0f, 47.0f, 109.0f, 171.0f, 53.0f,
+           123.0f, 193.0f, 59.0f, 137.0f, 215.0f, 65.0f, 151.0f, 237.0f,
+           71.0f, 165.0f, 259.0f, 77.0f, 179.0f, 281.0f, 83.0f, 193.0f,
+           303.0f, 89.0f, 207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
+           181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
+           233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
+           449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
+           485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFiveD) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 1, 3, 2}},
+                                  {TensorType_FLOAT32, {3, 2, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {23.0f, 53.0f, 83.0f, 29.0f, 67.0f, 105.0f, 35.0f, 81.0f,
+           127.0f, 41.0f, 95.0f, 149.0f, 47.0f, 109.0f, 171.0f, 53.0f,
+           123.0f, 193.0f, 59.0f, 137.0f, 215.0f, 65.0f, 151.0f, 237.0f,
+           71.0f, 165.0f, 259.0f, 77.0f, 179.0f, 281.0f, 83.0f, 193.0f,
+           303.0f, 89.0f, 207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
+           181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
+           233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
+           449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
+           485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 3, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {4, 5}},
+                                  {TensorType_FLOAT32, {3, 1, 5, 2}});
+  model.PopulateTensor<float>(
+      model.lhs(),
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+  model.PopulateTensor<float>(
+      model.rhs(),
+      {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+       22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray({145.0f, 370.0f, 595.0f, 820.0f, 220.0f, 570.0f,
+                        920.0f, 1270.0f, 295.0f, 770.0f, 1245.0f, 1720.0f,
+                        370.0f, 970.0f, 1570.0f, 2170.0f, 445.0f, 1170.0f,
+                        1895.0f, 2620.0f, 520.0f, 1370.0f, 2220.0f, 3070.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2}));
+}
+
+}  // namespace
+}  // namespace tflite
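The expected values above follow from the kernel's column-major convention (see reference/batch_matmul.h further down in this commit): the RHS is read as rhs[j * accum_depth + k] and each output matrix is written as out[lhs_rows * j + i]. In Float32Test_Simple, column 0 of the RHS is therefore (7, 8, 9), giving the first two outputs 1*7 + 2*8 + 3*9 = 50 and 4*7 + 5*8 + 6*9 = 122, which matches the leading {50, 122, ...} in the expectation.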
@@ -36,6 +36,7 @@ TfLiteRegistration* Register_ARG_MAX();
 TfLiteRegistration* Register_ARG_MIN();
 TfLiteRegistration* Register_AVERAGE_POOL_2D();
 TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
+TfLiteRegistration* Register_BATCH_MATMUL();
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
 TfLiteRegistration* Register_CAST();
@@ -213,6 +213,7 @@ cc_library(
     name = "optimized_base",
     srcs = [],
     hdrs = [
+        "optimized/batch_matmul.h",
         "optimized/depthwiseconv_3x3_filter_common.h",
         "optimized/depthwiseconv_float.h",
         "optimized/depthwiseconv_multithread.h",
@@ -416,6 +417,7 @@ cc_library(
     hdrs = [
         "reference/add.h",
         "reference/arg_min_max.h",
+        "reference/batch_matmul.h",
         "reference/binary_function.h",
         "reference/ceil.h",
         "reference/comparisons.h",
@@ -0,0 +1,118 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
+                        const RuntimeShape& rhs_shape, const float* rhs_data,
+                        const RuntimeShape& output_shape, float* output_data,
+                        CpuBackendContext* context) {
+  using ::tflite::cpu_backend_gemm::Gemm;
+  using ::tflite::cpu_backend_gemm::GemmParams;
+  using ::tflite::cpu_backend_gemm::MatrixParams;
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  MatrixParams<float> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = lhs_rows;
+  lhs_params.cols = accum_depth;
+
+  MatrixParams<float> rhs_params;
+  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.rows = accum_depth;
+  rhs_params.cols = rhs_cols;
+
+  MatrixParams<float> dst_params;
+  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.rows = lhs_rows;
+  dst_params.cols = rhs_cols;
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                        b1 * batch_dim2 + b2) *
+                                           lhs_rows * rhs_cols;
+        GemmParams<float, float> gemm_params;
+        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
+                               dst_params, out_ptr, gemm_params, context);
+      }
+    }
+  }
+}
+
+}  // namespace optimized_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
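The stride-0 trick in extent() is what makes broadcasting fall out of plain pointer arithmetic: a batch dimension of size 1 contributes a step of 0 elements, so the same slice is reused across that loop. Values worked out by hand from the code above:

    // lhs shape [2, 1, 3, 2] extended to rank 5 -> [1, 2, 1, 3, 2]:
    //   extent(shape, 0) == 0  (dim is 1: broadcast, pointer does not advance)
    //   extent(shape, 1) == 6  (1 * 3 * 2 elements per step along dim 1)
    //   extent(shape, 2) == 0  (dim is 1: broadcast again)
    // The matmul dims (3 and 4) are handled inside the per-slice GEMM call.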
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
+                        const RuntimeShape& rhs_shape, const float* rhs_data,
+                        const RuntimeShape& output_shape, float* output_data) {
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                        b1 * batch_dim2 + b2) *
+                                           lhs_rows * rhs_cols;
+        for (int j = 0; j < rhs_cols; ++j) {
+          for (int i = 0; i < lhs_rows; ++i) {
+            float total = 0.f;
+            for (int k = 0; k < accum_depth; ++k) {
+              total +=
+                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
+            }
+            int idx = lhs_rows * j + i;
+            out_ptr[idx] = total;
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
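A minimal driver for the reference kernel, assuming only the header above (the data reproduces Float32Test_Simple; RuntimeShape's initializer-list constructor is assumed):

#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/types.h"

void ReferenceBatchMatMulExample() {
  const float lhs[6] = {1, 2, 3, 4, 5, 6};                 // shape [1, 2, 3]
  const float rhs[12] = {7,  8,  9,  10, 11, 12,
                         13, 14, 15, 16, 17, 18};          // shape [1, 3, 4]
  float out[8] = {};                                       // shape [1, 2, 4]
  tflite::reference_ops::BatchMatMul(
      tflite::RuntimeShape({1, 2, 3}), lhs,
      tflite::RuntimeShape({1, 3, 4}), rhs,
      tflite::RuntimeShape({1, 2, 4}), out);
  // out == {50, 122, 68, 167, 86, 212, 104, 257}, under the column-major
  // RHS/output convention noted alongside the tests above.
}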
@@ -278,6 +278,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
   AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
+  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL());
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
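Once registered, the op resolves like any other builtin. A hypothetical smoke check (FindOp comes from MutableOpResolver, which BuiltinOpResolver derives from):

    tflite::ops::builtin::BuiltinOpResolver resolver;
    const TfLiteRegistration* reg =
        resolver.FindOp(tflite::BuiltinOperator_BATCH_MATMUL, /*version=*/1);
    // Non-null once the AddBuiltin() call in this hunk is in place.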
@@ -134,6 +134,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF();
 TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
 TfLiteRegistration* Register_SELECT_V2();
 TfLiteRegistration* Register_SEGMENT_SUM();
+TfLiteRegistration* Register_BATCH_MATMUL_REF();
 
 namespace {
 
@@ -346,7 +346,8 @@ enum BuiltinOperator : byte {
   SCATTER_ND = 122,
   SELECT_V2 = 123,
   DENSIFY = 124,
-  SEGMENT_SUM = 125
+  SEGMENT_SUM = 125,
+  BATCH_MATMUL = 126
 }
 
@@ -451,7 +452,8 @@ union BuiltinOptions {
   ScatterNdOptions,
   SelectV2Options,
   DensifyOptions,
-  SegmentSumOptions
+  SegmentSumOptions,
+  BatchMatMulOptions
 }
 
 enum Padding : byte { SAME, VALID }
@@ -945,6 +947,9 @@ table DensifyOptions {
 table SegmentSumOptions {
 }
 
+table BatchMatMulOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
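The schema_generated.h hunks that follow are the mechanical result of regenerating the C++ bindings from the schema changes above; with stock FlatBuffers that is roughly (the exact invocation TensorFlow uses may differ):

    flatc --cpp --gen-mutable --gen-object-api schema.fbs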
@@ -346,6 +346,9 @@ struct DensifyOptionsT;
 struct SegmentSumOptions;
 struct SegmentSumOptionsT;
 
+struct BatchMatMulOptions;
+struct BatchMatMulOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
 
@@ -771,11 +774,12 @@ enum BuiltinOperator {
   BuiltinOperator_SELECT_V2 = 123,
   BuiltinOperator_DENSIFY = 124,
   BuiltinOperator_SEGMENT_SUM = 125,
+  BuiltinOperator_BATCH_MATMUL = 126,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM
+  BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL
 };
 
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] {
   static const BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -902,13 +906,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
     BuiltinOperator_SCATTER_ND,
     BuiltinOperator_SELECT_V2,
     BuiltinOperator_DENSIFY,
-    BuiltinOperator_SEGMENT_SUM
+    BuiltinOperator_SEGMENT_SUM,
+    BuiltinOperator_BATCH_MATMUL
   };
   return values;
 }
 
 inline const char * const *EnumNamesBuiltinOperator() {
-  static const char * const names[127] = {
+  static const char * const names[128] = {
     "ADD",
     "AVERAGE_POOL_2D",
     "CONCATENATION",
@@ -1035,13 +1040,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
     "SELECT_V2",
     "DENSIFY",
     "SEGMENT_SUM",
+    "BATCH_MATMUL",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
-  if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_SEGMENT_SUM)) return "";
+  if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOperator()[index];
 }
@@ -1148,11 +1154,12 @@ enum BuiltinOptions {
   BuiltinOptions_SelectV2Options = 98,
   BuiltinOptions_DensifyOptions = 99,
   BuiltinOptions_SegmentSumOptions = 100,
+  BuiltinOptions_BatchMatMulOptions = 101,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions
+  BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions
 };
 
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] {
   static const BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
@@ -1254,13 +1261,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
     BuiltinOptions_ScatterNdOptions,
     BuiltinOptions_SelectV2Options,
     BuiltinOptions_DensifyOptions,
-    BuiltinOptions_SegmentSumOptions
+    BuiltinOptions_SegmentSumOptions,
+    BuiltinOptions_BatchMatMulOptions
   };
   return values;
 }
 
 inline const char * const *EnumNamesBuiltinOptions() {
-  static const char * const names[102] = {
+  static const char * const names[103] = {
     "NONE",
     "Conv2DOptions",
     "DepthwiseConv2DOptions",
@@ -1362,13 +1370,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
     "SelectV2Options",
     "DensifyOptions",
     "SegmentSumOptions",
+    "BatchMatMulOptions",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
-  if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_SegmentSumOptions)) return "";
+  if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOptions()[index];
 }
@@ -1777,6 +1786,10 @@ template<> struct BuiltinOptionsTraits<tflite::SegmentSumOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
 };
 
+template<> struct BuiltinOptionsTraits<tflite::BatchMatMulOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -2609,6 +2622,14 @@ struct BuiltinOptionsUnion {
     return type == BuiltinOptions_SegmentSumOptions ?
       reinterpret_cast<const tflite::SegmentSumOptionsT *>(value) : nullptr;
   }
+  tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() {
+    return type == BuiltinOptions_BatchMatMulOptions ?
+      reinterpret_cast<tflite::BatchMatMulOptionsT *>(value) : nullptr;
+  }
+  const tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() const {
+    return type == BuiltinOptions_BatchMatMulOptions ?
+      reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value) : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -9109,6 +9130,46 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(
 
 flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct BatchMatMulOptionsT : public flatbuffers::NativeTable {
+  typedef BatchMatMulOptions TableType;
+  BatchMatMulOptionsT() {
+  }
+};
+
+struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef BatchMatMulOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  BatchMatMulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<BatchMatMulOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BatchMatMulOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit BatchMatMulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  BatchMatMulOptionsBuilder &operator=(const BatchMatMulOptionsBuilder &);
+  flatbuffers::Offset<BatchMatMulOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<BatchMatMulOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  BatchMatMulOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   tflite::BuiltinOperator builtin_code;
@@ -9545,6 +9606,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const {
     return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast<const tflite::SegmentSumOptions *>(builtin_options()) : nullptr;
   }
+  const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const {
+    return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? static_cast<const tflite::BatchMatMulOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -9981,6 +10045,10 @@ template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as<
   return builtin_options_as_SegmentSumOptions();
 }
 
+template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as<tflite::BatchMatMulOptions>() const {
+  return builtin_options_as_BatchMatMulOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -13392,6 +13460,29 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffer
       _fbb);
 }
 
+inline BatchMatMulOptionsT *BatchMatMulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new BatchMatMulOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void BatchMatMulOptions::UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<BatchMatMulOptions> BatchMatMulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateBatchMatMulOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchMatMulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateBatchMatMulOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -14197,6 +14288,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return true;
   }
 }
@@ -14615,6 +14710,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -15021,6 +15120,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
      auto ptr = reinterpret_cast<const tflite::SegmentSumOptionsT *>(value);
      return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union();
    }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value);
+      return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -15427,6 +15530,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
       value = new tflite::SegmentSumOptionsT(*reinterpret_cast<tflite::SegmentSumOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      value = new tflite::BatchMatMulOptionsT(*reinterpret_cast<tflite::BatchMatMulOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -15934,6 +16041,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<tflite::BatchMatMulOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
@@ -57,6 +57,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
           {{OperatorType::kDiv, 1}, "1.6.0"},
           {{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},
           {{OperatorType::kBatchToSpaceND, 2}, "1.14.0"},
+          {{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion},
           {{OperatorType::kCast, 1}, "1.5.0"},
           {{OperatorType::kConcatenation, 1}, "1.5.0"},
           {{OperatorType::kConcatenation, 2}, "1.14.0"},