Add BatchMatMul built-in op for TF Lite

PiperOrigin-RevId: 302489633
Change-Id: Ie4ad2abad069b1e5bc654fc51caf0bcbc99b714f

commit 828fe43cf3 (parent c6ec2565db)
@@ -152,6 +152,7 @@ typedef enum {
   kTfLiteBuiltinSelectV2 = 123,
   kTfLiteBuiltinDensify = 124,
   kTfLiteBuiltinSegmentSum = 125,
+  kTfLiteBuiltinBatchMatmul = 126,
 } TfLiteBuiltinOperator;

 #ifdef __cplusplus
@@ -840,6 +840,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_SCATTER_ND:
     case BuiltinOperator_DENSIFY:
     case BuiltinOperator_SEGMENT_SUM:
+    case BuiltinOperator_BATCH_MATMUL:
      break;
   }
   return kTfLiteOk;
@@ -426,6 +426,7 @@ cc_library(
         "arg_min_max.cc",
         "audio_spectrogram.cc",
         "basic_rnn.cc",
+        "batch_matmul.cc",
         "batch_to_space_nd.cc",
         "bidirectional_sequence_lstm.cc",
         "bidirectional_sequence_rnn.cc",
@@ -849,6 +850,19 @@ cc_test(
     ],
 )

+cc_test(
+    name = "batch_matmul_test",
+    size = "small",
+    srcs = ["batch_matmul_test.cc"],
+    deps = [
+        ":builtin_ops",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite:framework",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "cast_test",
     size = "small",
@@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace batch_matmul {
+
+static const int kInputLHSTensor = 0;
+static const int kInputRHSTensor = 1;
+static const int kOutputTensor = 0;
+
+// This file has two implementations of BatchMatMul.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+                                const RuntimeShape& extended_lhs_shape,
+                                const RuntimeShape& extended_rhs_shape,
+                                int output_rank, TfLiteTensor* output) {
+  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+  // Fill in any broadcast dimensions.
+  for (int i = 0; i < output_rank - 2; ++i) {
+    const int lhs_dim = extended_lhs_shape.Dims(i);
+    const int rhs_dim = extended_rhs_shape.Dims(i);
+    int broadcast_dim = lhs_dim;
+    if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
+      broadcast_dim = rhs_dim;
+    }
+    output_shape->data[i] = broadcast_dim;
+  }
+  // Fill in the matmul dimensions.
+  output_shape->data[output_rank - 2] =
+      extended_lhs_shape.Dims(output_rank - 2);
+  output_shape->data[output_rank - 1] =
+      extended_rhs_shape.Dims(output_rank - 1);
+  TfLiteStatus stat = context->ResizeTensor(context, output, output_shape);
+  return stat;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
+  const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, lhs_data->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, rhs_data->type, kTfLiteFloat32);
+  // Support dimensions between 2 and 5, inclusive.
+  TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
+  TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
+  TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2);
+  TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5);
+
+  const int lhs_rank = NumDimensions(lhs_data);
+  const int rhs_rank = NumDimensions(rhs_data);
+  const int output_rank = std::max(lhs_rank, rhs_rank);
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data));
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data));
+
+  // Ensure any batch dimensions obey broadcasting rules.
+  for (int i = 0; i < output_rank - 2; ++i) {
+    const int lhs_dim = extended_lhs_shape.Dims(i);
+    const int rhs_dim = extended_rhs_shape.Dims(i);
+    if (lhs_dim != rhs_dim) {
+      if (lhs_dim != 1) {
+        TF_LITE_ENSURE_EQ(context, rhs_dim, 1);
+      }
+    }
+  }
+  // Ensure other dimensions work for matrix multiplication.
+  TF_LITE_ENSURE_EQ(context, extended_lhs_shape.Dims(output_rank - 1),
+                    extended_rhs_shape.Dims(output_rank - 2));
+  return ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape,
+                            output_rank, output);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
+  const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  switch (lhs->type) {
+    case kTfLiteFloat32:
+      if (kernel_type == kGenericOptimized) {
+        optimized_ops::BatchMatMul(
+            GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
+            GetTensorData<float>(rhs), GetTensorShape(output),
+            GetTensorData<float>(output),
+            CpuBackendContext::GetFromContext(context));
+      } else {
+        reference_ops::BatchMatMul(
+            GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
+            GetTensorData<float>(rhs), GetTensorShape(output),
+            GetTensorData<float>(output));
+      }
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context,
+                         "Currently BatchMatMul doesn't support type: %s",
+                         TfLiteTypeGetName(lhs->type));
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace batch_matmul
+
+TfLiteRegistration* Register_BATCH_MATMUL_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, batch_matmul::Prepare,
+                                 batch_matmul::Eval<batch_matmul::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, batch_matmul::Prepare,
+      batch_matmul::Eval<batch_matmul::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_BATCH_MATMUL() {
+  return Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
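Editor's note: the shape logic in Prepare()/ResizeOutputTensor() above amounts to NumPy-style broadcasting over the batch dimensions plus the usual matmul rule on the last two. A minimal standalone sketch (not part of the commit; the function name is illustrative), using the shapes from Float32Test_Broadcast2 in the test file below:

#include <algorithm>
#include <cassert>
#include <vector>

// Broadcast the batch dims, then append [lhs_rows, rhs_cols].
std::vector<int> BatchMatMulOutputShape(std::vector<int> lhs,
                                        std::vector<int> rhs) {
  const size_t rank = std::max(lhs.size(), rhs.size());
  // Left-pad with 1s, mirroring RuntimeShape::ExtendedShape.
  lhs.insert(lhs.begin(), rank - lhs.size(), 1);
  rhs.insert(rhs.begin(), rank - rhs.size(), 1);
  std::vector<int> out(rank);
  for (size_t i = 0; i + 2 < rank; ++i) {
    assert(lhs[i] == rhs[i] || lhs[i] == 1 || rhs[i] == 1);  // broadcast rule
    out[i] = (lhs[i] == 1) ? rhs[i] : lhs[i];
  }
  assert(lhs[rank - 1] == rhs[rank - 2]);  // inner dimensions must match
  out[rank - 2] = lhs[rank - 2];
  out[rank - 1] = rhs[rank - 1];
  return out;
}

int main() {
  // Same shapes as Float32Test_Broadcast2: {2,1,3,2} x {3,2,4} -> {2,3,3,4}.
  assert((BatchMatMulOutputShape({2, 1, 3, 2}, {3, 2, 4}) ==
          std::vector<int>{2, 3, 3, 4}));
  return 0;
}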
@@ -0,0 +1,169 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class BatchMatMulOpModel : public SingleOpModel {
+ public:
+  BatchMatMulOpModel(const TensorData& lhs, const TensorData& rhs) {
+    lhs_id_ = AddInput(lhs);
+    rhs_id_ = AddInput(rhs);
+    output_id_ = AddOutput(lhs.type);
+    SetBuiltinOp(BuiltinOperator_BATCH_MATMUL, BuiltinOptions_NONE, 0);
+    BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
+  }
+
+  int lhs() const { return lhs_id_; }
+  int rhs() const { return rhs_id_; }
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
+  std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
+
+ protected:
+  int lhs_id_;
+  int rhs_id_;
+  int output_id_;
+};
+
+TEST(BatchMatMulOpModelTest, Float32Test_Simple) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}},
+                                  {TensorType_FLOAT32, {1, 3, 4}});
+  model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
+                                  {TensorType_FLOAT32, {2, 3, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f, 482.0f, 662.0f, 554.0f, 761.0f,
+                                626.0f, 860.0f, 698.0f, 959.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
+                                  {TensorType_FLOAT32, {3, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
+
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f, 194.0f, 266.0f, 266.0f, 365.0f,
+                                338.0f, 464.0f, 410.0f, 563.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}},
+                                  {TensorType_FLOAT32, {3, 2, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {23.0f,  53.0f,  83.0f,  29.0f,  67.0f,  105.0f, 35.0f,  81.0f,
+           127.0f, 41.0f,  95.0f,  149.0f, 47.0f,  109.0f, 171.0f, 53.0f,
+           123.0f, 193.0f, 59.0f,  137.0f, 215.0f, 65.0f,  151.0f, 237.0f,
+           71.0f,  165.0f, 259.0f, 77.0f,  179.0f, 281.0f, 83.0f,  193.0f,
+           303.0f, 89.0f,  207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
+           181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
+           233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
+           449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
+           485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFiveD) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 1, 3, 2}},
+                                  {TensorType_FLOAT32, {3, 2, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {23.0f,  53.0f,  83.0f,  29.0f,  67.0f,  105.0f, 35.0f,  81.0f,
+           127.0f, 41.0f,  95.0f,  149.0f, 47.0f,  109.0f, 171.0f, 53.0f,
+           123.0f, 193.0f, 59.0f,  137.0f, 215.0f, 65.0f,  151.0f, 237.0f,
+           71.0f,  165.0f, 259.0f, 77.0f,  179.0f, 281.0f, 83.0f,  193.0f,
+           303.0f, 89.0f,  207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
+           181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
+           233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
+           449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
+           485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 3, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {4, 5}},
+                                  {TensorType_FLOAT32, {3, 1, 5, 2}});
+  model.PopulateTensor<float>(
+      model.lhs(),
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+  model.PopulateTensor<float>(
+      model.rhs(),
+      {7,  8,  9,  10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+       22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36});

+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray({145.0f, 370.0f, 595.0f, 820.0f, 220.0f, 570.0f,
+                        920.0f, 1270.0f, 295.0f, 770.0f, 1245.0f, 1720.0f,
+                        370.0f, 970.0f, 1570.0f, 2170.0f, 445.0f, 1170.0f,
+                        1895.0f, 2620.0f, 520.0f, 1370.0f, 2220.0f, 3070.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2}));
+}
+
+}  // namespace
+}  // namespace tflite
@@ -36,6 +36,7 @@ TfLiteRegistration* Register_ARG_MAX();
 TfLiteRegistration* Register_ARG_MIN();
 TfLiteRegistration* Register_AVERAGE_POOL_2D();
 TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
+TfLiteRegistration* Register_BATCH_MATMUL();
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
 TfLiteRegistration* Register_CAST();
@@ -213,6 +213,7 @@ cc_library(
     name = "optimized_base",
     srcs = [],
     hdrs = [
+        "optimized/batch_matmul.h",
         "optimized/depthwiseconv_3x3_filter_common.h",
         "optimized/depthwiseconv_float.h",
         "optimized/depthwiseconv_multithread.h",
@@ -416,6 +417,7 @@ cc_library(
     hdrs = [
         "reference/add.h",
         "reference/arg_min_max.h",
+        "reference/batch_matmul.h",
         "reference/binary_function.h",
         "reference/ceil.h",
         "reference/comparisons.h",
@@ -0,0 +1,118 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
+                        const RuntimeShape& rhs_shape, const float* rhs_data,
+                        const RuntimeShape& output_shape, float* output_data,
+                        CpuBackendContext* context) {
+  using ::tflite::cpu_backend_gemm::Gemm;
+  using ::tflite::cpu_backend_gemm::GemmParams;
+  using ::tflite::cpu_backend_gemm::MatrixParams;
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  MatrixParams<float> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = lhs_rows;
+  lhs_params.cols = accum_depth;
+
+  MatrixParams<float> rhs_params;
+  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.rows = accum_depth;
+  rhs_params.cols = rhs_cols;
+
+  MatrixParams<float> dst_params;
+  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.rows = lhs_rows;
+  dst_params.cols = rhs_cols;
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                        b1 * batch_dim2 + b2) *
+                                           lhs_rows * rhs_cols;
+        GemmParams<float, float> gemm_params;
+        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
+                               dst_params, out_ptr, gemm_params, context);
+      }
+    }
+  }
+}
+
+}  // namespace optimized_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
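Editor's note: the extent() helper above is what implements broadcasting inside the batch loops. A dimension of size 1 yields a stride of 0, so the corresponding lhs/rhs pointer simply does not advance and the same slice is reused for every step of that loop. A standalone sketch of that behavior (illustrative, not from the commit):

#include <cassert>
#include <vector>

// Stride for dimension x of an extended shape; 0 when the dim is 1,
// i.e. the pointer stays put and the slice is broadcast.
int Extent(const std::vector<int>& dims, int x) {
  if (dims[x] == 1) return 0;
  int prod = 1;
  for (size_t i = x + 1; i < dims.size(); ++i) prod *= dims[i];
  return prod;
}

int main() {
  // rhs {3, 2, 4} extended to rank 5 becomes {1, 1, 3, 2, 4}
  // (the rhs shape used by Float32Test_Broadcast2 above).
  const std::vector<int> rhs = {1, 1, 3, 2, 4};
  assert(Extent(rhs, 0) == 0);  // broadcast: do not advance
  assert(Extent(rhs, 1) == 0);  // broadcast: do not advance
  assert(Extent(rhs, 2) == 8);  // advance by one 2x4 matrix (8 floats)
  return 0;
}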
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
+                        const RuntimeShape& rhs_shape, const float* rhs_data,
+                        const RuntimeShape& output_shape, float* output_data) {
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                        b1 * batch_dim2 + b2) *
+                                           lhs_rows * rhs_cols;
+        for (int j = 0; j < rhs_cols; ++j) {
+          for (int i = 0; i < lhs_rows; ++i) {
+            float total = 0.f;
+            for (int k = 0; k < accum_depth; ++k) {
+              total +=
+                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
+            }
+            int idx = lhs_rows * j + i;
+            out_ptr[idx] = total;
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
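Editor's note on the indexing convention in the reference loop above: rhs is read as rhs_ptr2[j * accum_depth + k] and the result is stored at lhs_rows * j + i. In other words, the rhs buffer is consumed as if transposed and the output is written column-major, which lines up with the kColMajor order chosen for rhs_params and dst_params in the optimized GEMM path. The standalone sketch below (not part of the commit) reproduces the first two expected values of Float32Test_Simple:

#include <cassert>

int main() {
  const int lhs_rows = 2, rhs_cols = 4, accum_depth = 3;
  const float lhs[] = {1, 2, 3, 4, 5, 6};
  const float rhs[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  float out[8];
  for (int j = 0; j < rhs_cols; ++j) {
    for (int i = 0; i < lhs_rows; ++i) {
      float total = 0.f;
      for (int k = 0; k < accum_depth; ++k) {
        // rhs read with transposed indexing, exactly as in the kernel.
        total += lhs[accum_depth * i + k] * rhs[j * accum_depth + k];
      }
      out[lhs_rows * j + i] = total;  // column-major output
    }
  }
  assert(out[0] == 50.f);   // 1*7 + 2*8 + 3*9
  assert(out[1] == 122.f);  // 4*7 + 5*8 + 6*9
  return 0;
}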
@@ -278,6 +278,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
   AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
+  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL());
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
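Editor's note: with the AddBuiltin() call above, any model whose flatbuffer uses BuiltinOperator_BATCH_MATMUL resolves through the standard BuiltinOpResolver; no AddCustom() hook is needed. A hedged sketch of typical interpreter construction (the model path and helper name are hypothetical):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

std::unique_ptr<tflite::Interpreter> BuildInterpreter(const char* model_path) {
  // model_path is a placeholder; the model must already encode
  // BATCH_MATMUL ops in its flatbuffer.
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  tflite::ops::builtin::BuiltinOpResolver resolver;  // now includes BATCH_MATMUL
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}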
@@ -134,6 +134,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF();
 TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
 TfLiteRegistration* Register_SELECT_V2();
 TfLiteRegistration* Register_SEGMENT_SUM();
+TfLiteRegistration* Register_BATCH_MATMUL_REF();

 namespace {
@@ -346,7 +346,8 @@ enum BuiltinOperator : byte {
   SCATTER_ND = 122,
   SELECT_V2 = 123,
   DENSIFY = 124,
-  SEGMENT_SUM = 125
+  SEGMENT_SUM = 125,
+  BATCH_MATMUL = 126
 }
@@ -451,7 +452,8 @@ union BuiltinOptions {
   ScatterNdOptions,
   SelectV2Options,
   DensifyOptions,
-  SegmentSumOptions
+  SegmentSumOptions,
+  BatchMatMulOptions
 }

 enum Padding : byte { SAME, VALID }
@@ -945,6 +947,9 @@ table DensifyOptions {
 table SegmentSumOptions {
 }

+table BatchMatMulOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
@@ -346,6 +346,9 @@ struct DensifyOptionsT;
 struct SegmentSumOptions;
 struct SegmentSumOptionsT;

+struct BatchMatMulOptions;
+struct BatchMatMulOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
@@ -771,11 +774,12 @@ enum BuiltinOperator {
   BuiltinOperator_SELECT_V2 = 123,
   BuiltinOperator_DENSIFY = 124,
   BuiltinOperator_SEGMENT_SUM = 125,
+  BuiltinOperator_BATCH_MATMUL = 126,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM
+  BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL
 };

-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] {
   static const BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -902,13 +906,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
     BuiltinOperator_SCATTER_ND,
     BuiltinOperator_SELECT_V2,
     BuiltinOperator_DENSIFY,
-    BuiltinOperator_SEGMENT_SUM
+    BuiltinOperator_SEGMENT_SUM,
+    BuiltinOperator_BATCH_MATMUL
   };
   return values;
 }

 inline const char * const *EnumNamesBuiltinOperator() {
-  static const char * const names[127] = {
+  static const char * const names[128] = {
     "ADD",
     "AVERAGE_POOL_2D",
     "CONCATENATION",
@@ -1035,13 +1040,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
     "SELECT_V2",
     "DENSIFY",
     "SEGMENT_SUM",
+    "BATCH_MATMUL",
     nullptr
   };
   return names;
 }

 inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
-  if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_SEGMENT_SUM)) return "";
+  if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOperator()[index];
 }
@@ -1148,11 +1154,12 @@ enum BuiltinOptions {
   BuiltinOptions_SelectV2Options = 98,
   BuiltinOptions_DensifyOptions = 99,
   BuiltinOptions_SegmentSumOptions = 100,
+  BuiltinOptions_BatchMatMulOptions = 101,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions
+  BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions
 };

-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] {
   static const BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
|
@ -1254,13 +1261,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
|
||||||
BuiltinOptions_ScatterNdOptions,
|
BuiltinOptions_ScatterNdOptions,
|
||||||
BuiltinOptions_SelectV2Options,
|
BuiltinOptions_SelectV2Options,
|
||||||
BuiltinOptions_DensifyOptions,
|
BuiltinOptions_DensifyOptions,
|
||||||
BuiltinOptions_SegmentSumOptions
|
BuiltinOptions_SegmentSumOptions,
|
||||||
|
BuiltinOptions_BatchMatMulOptions
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char * const *EnumNamesBuiltinOptions() {
|
inline const char * const *EnumNamesBuiltinOptions() {
|
||||||
static const char * const names[102] = {
|
static const char * const names[103] = {
|
||||||
"NONE",
|
"NONE",
|
||||||
"Conv2DOptions",
|
"Conv2DOptions",
|
||||||
"DepthwiseConv2DOptions",
|
"DepthwiseConv2DOptions",
|
||||||
|
@ -1362,13 +1370,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||||
"SelectV2Options",
|
"SelectV2Options",
|
||||||
"DensifyOptions",
|
"DensifyOptions",
|
||||||
"SegmentSumOptions",
|
"SegmentSumOptions",
|
||||||
|
"BatchMatMulOptions",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
||||||
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_SegmentSumOptions)) return "";
|
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return "";
|
||||||
const size_t index = static_cast<size_t>(e);
|
const size_t index = static_cast<size_t>(e);
|
||||||
return EnumNamesBuiltinOptions()[index];
|
return EnumNamesBuiltinOptions()[index];
|
||||||
}
|
}
|
||||||
|
@@ -1777,6 +1786,10 @@ template<> struct BuiltinOptionsTraits<tflite::SegmentSumOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
 };

+template<> struct BuiltinOptionsTraits<tflite::BatchMatMulOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -2609,6 +2622,14 @@ struct BuiltinOptionsUnion {
     return type == BuiltinOptions_SegmentSumOptions ?
       reinterpret_cast<const tflite::SegmentSumOptionsT *>(value) : nullptr;
   }
+  tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() {
+    return type == BuiltinOptions_BatchMatMulOptions ?
+      reinterpret_cast<tflite::BatchMatMulOptionsT *>(value) : nullptr;
+  }
+  const tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() const {
+    return type == BuiltinOptions_BatchMatMulOptions ?
+      reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value) : nullptr;
+  }
 };

 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -9109,6 +9130,46 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(

 flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);

+struct BatchMatMulOptionsT : public flatbuffers::NativeTable {
+  typedef BatchMatMulOptions TableType;
+  BatchMatMulOptionsT() {
+  }
+};
+
+struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef BatchMatMulOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  BatchMatMulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<BatchMatMulOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BatchMatMulOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit BatchMatMulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  BatchMatMulOptionsBuilder &operator=(const BatchMatMulOptionsBuilder &);
+  flatbuffers::Offset<BatchMatMulOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<BatchMatMulOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  BatchMatMulOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   tflite::BuiltinOperator builtin_code;
@@ -9545,6 +9606,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const {
     return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast<const tflite::SegmentSumOptions *>(builtin_options()) : nullptr;
   }
+  const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const {
+    return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? static_cast<const tflite::BatchMatMulOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -9981,6 +10045,10 @@ template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as<
   return builtin_options_as_SegmentSumOptions();
 }

+template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as<tflite::BatchMatMulOptions>() const {
+  return builtin_options_as_BatchMatMulOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -13392,6 +13460,29 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffer
       _fbb);
 }

+inline BatchMatMulOptionsT *BatchMatMulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new BatchMatMulOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void BatchMatMulOptions::UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<BatchMatMulOptions> BatchMatMulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateBatchMatMulOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchMatMulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateBatchMatMulOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -14197,6 +14288,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return true;
   }
 }
@@ -14615,6 +14710,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -15021,6 +15120,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptionsT *>(value);
       return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value);
+      return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -15427,6 +15530,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
       value = new tflite::SegmentSumOptionsT(*reinterpret_cast<tflite::SegmentSumOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      value = new tflite::BatchMatMulOptionsT(*reinterpret_cast<tflite::BatchMatMulOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -15934,6 +16041,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
      break;
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<tflite::BatchMatMulOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
@@ -57,6 +57,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
       {{OperatorType::kDiv, 1}, "1.6.0"},
       {{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},
      {{OperatorType::kBatchToSpaceND, 2}, "1.14.0"},
+      {{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion},
       {{OperatorType::kCast, 1}, "1.5.0"},
       {{OperatorType::kConcatenation, 1}, "1.5.0"},
       {{OperatorType::kConcatenation, 2}, "1.14.0"},