Add BatchMatMul built-in op for TF Lite

PiperOrigin-RevId: 302489633
Change-Id: Ie4ad2abad069b1e5bc654fc51caf0bcbc99b714f
This commit is contained in:
T.J. Alumbaugh 2020-03-23 12:19:59 -07:00 committed by TensorFlower Gardener
parent c6ec2565db
commit 828fe43cf3
14 changed files with 699 additions and 12 deletions

View File

@ -152,6 +152,7 @@ typedef enum {
kTfLiteBuiltinSelectV2 = 123,
kTfLiteBuiltinDensify = 124,
kTfLiteBuiltinSegmentSum = 125,
kTfLiteBuiltinBatchMatmul = 126,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -840,6 +840,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_BATCH_MATMUL:
break;
}
return kTfLiteOk;

View File

@ -426,6 +426,7 @@ cc_library(
"arg_min_max.cc",
"audio_spectrogram.cc",
"basic_rnn.cc",
"batch_matmul.cc",
"batch_to_space_nd.cc",
"bidirectional_sequence_lstm.cc",
"bidirectional_sequence_rnn.cc",
@ -849,6 +850,19 @@ cc_test(
],
)
cc_test(
name = "batch_matmul_test",
size = "small",
srcs = ["batch_matmul_test.cc"],
deps = [
":builtin_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "cast_test",
size = "small",

View File

@ -0,0 +1,156 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace batch_matmul {
static const int kInputLHSTensor = 0;
static const int kInputRHSTensor = 1;
static const int kOutputTensor = 0;
// This file has two implementations of Transpose.
enum KernelType {
kReference,
kGenericOptimized,
};
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const RuntimeShape& extended_lhs_shape,
const RuntimeShape& extended_rhs_shape,
int output_rank, TfLiteTensor* output) {
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
// Fill in any broadcast dimensions.
for (int i = 0; i < output_rank - 2; ++i) {
const int lhs_dim = extended_lhs_shape.Dims(i);
const int rhs_dim = extended_rhs_shape.Dims(i);
int broadcast_dim = lhs_dim;
if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
broadcast_dim = rhs_dim;
}
output_shape->data[i] = broadcast_dim;
}
// Fill in the matmul dimensions.
output_shape->data[output_rank - 2] =
extended_lhs_shape.Dims(output_rank - 2);
output_shape->data[output_rank - 1] =
extended_rhs_shape.Dims(output_rank - 1);
TfLiteStatus stat = context->ResizeTensor(context, output, output_shape);
return stat;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, lhs_data->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, rhs_data->type, kTfLiteFloat32);
// Support dimensions between 2 and 5, inclusive.
TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2);
TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5);
const int lhs_rank = NumDimensions(lhs_data);
const int rhs_rank = NumDimensions(rhs_data);
const int output_rank = std::max(lhs_rank, rhs_rank);
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data));
const RuntimeShape extended_rhs_shape =
RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data));
// Ensure any batch dimensions obey broacasting rules.
for (int i = 0; i < output_rank - 2; ++i) {
const int lhs_dim = extended_lhs_shape.Dims(i);
const int rhs_dim = extended_rhs_shape.Dims(i);
if (lhs_dim != rhs_dim) {
if (lhs_dim != 1) {
TF_LITE_ENSURE_EQ(context, rhs_dim, 1);
}
}
}
// Ensure other dimensions work for matrix multiplication.
TF_LITE_ENSURE_EQ(context, extended_lhs_shape.Dims(output_rank - 1),
extended_rhs_shape.Dims(output_rank - 2));
return ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape,
output_rank, output);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (lhs->type) {
case kTfLiteFloat32:
if (kernel_type == kGenericOptimized) {
optimized_ops::BatchMatMul(
GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
GetTensorData<float>(rhs), GetTensorShape(output),
GetTensorData<float>(output),
CpuBackendContext::GetFromContext(context));
} else {
reference_ops::BatchMatMul(
GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
GetTensorData<float>(rhs), GetTensorShape(output),
GetTensorData<float>(output));
}
break;
default:
TF_LITE_KERNEL_LOG(context,
"Currently BatchMatMul doesn't support type: %s",
TfLiteTypeGetName(lhs->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace batch_matmul
TfLiteRegistration* Register_BATCH_MATMUL_REF() {
static TfLiteRegistration r = {nullptr, nullptr, batch_matmul::Prepare,
batch_matmul::Eval<batch_matmul::kReference>};
return &r;
}
TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {
static TfLiteRegistration r = {
nullptr, nullptr, batch_matmul::Prepare,
batch_matmul::Eval<batch_matmul::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_BATCH_MATMUL() {
return Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,169 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
template <typename T>
class BatchMatMulOpModel : public SingleOpModel {
public:
BatchMatMulOpModel(const TensorData& lhs, const TensorData& rhs) {
lhs_id_ = AddInput(lhs);
rhs_id_ = AddInput(rhs);
output_id_ = AddOutput(lhs.type);
SetBuiltinOp(BuiltinOperator_BATCH_MATMUL, BuiltinOptions_NONE, 0);
BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
}
int lhs() const { return lhs_id_; }
int rhs() const { return rhs_id_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
protected:
int lhs_id_;
int rhs_id_;
int output_id_;
};
TEST(BatchMatMulOpModelTest, Float32Test_Simple) {
BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}},
{TensorType_FLOAT32, {1, 3, 4}});
model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<float>(model.rhs(),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
model.Invoke();
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
104.0f, 257.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
}
TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) {
BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
{TensorType_FLOAT32, {2, 3, 4}});
model.PopulateTensor<float>(model.lhs(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
model.PopulateTensor<float>(model.rhs(),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
model.Invoke();
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
104.0f, 257.0f, 482.0f, 662.0f, 554.0f, 761.0f,
626.0f, 860.0f, 698.0f, 959.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
}
TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) {
BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
{TensorType_FLOAT32, {3, 4}});
model.PopulateTensor<float>(model.lhs(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
model.PopulateTensor<float>(model.rhs(),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
model.Invoke();
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
104.0f, 257.0f, 194.0f, 266.0f, 266.0f, 365.0f,
338.0f, 464.0f, 410.0f, 563.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
}
TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) {
BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}},
{TensorType_FLOAT32, {3, 2, 4}});
model.PopulateTensor<float>(model.lhs(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
model.PopulateTensor<float>(model.rhs(),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
model.Invoke();
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray(
{23.0f, 53.0f, 83.0f, 29.0f, 67.0f, 105.0f, 35.0f, 81.0f,
127.0f, 41.0f, 95.0f, 149.0f, 47.0f, 109.0f, 171.0f, 53.0f,
123.0f, 193.0f, 59.0f, 137.0f, 215.0f, 65.0f, 151.0f, 237.0f,
71.0f, 165.0f, 259.0f, 77.0f, 179.0f, 281.0f, 83.0f, 193.0f,
303.0f, 89.0f, 207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
}
TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFiveD) {
BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 1, 3, 2}},
{TensorType_FLOAT32, {3, 2, 4}});
model.PopulateTensor<float>(model.lhs(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
model.PopulateTensor<float>(model.rhs(),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
model.Invoke();
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray(
{23.0f, 53.0f, 83.0f, 29.0f, 67.0f, 105.0f, 35.0f, 81.0f,
127.0f, 41.0f, 95.0f, 149.0f, 47.0f, 109.0f, 171.0f, 53.0f,
123.0f, 193.0f, 59.0f, 137.0f, 215.0f, 65.0f, 151.0f, 237.0f,
71.0f, 165.0f, 259.0f, 77.0f, 179.0f, 281.0f, 83.0f, 193.0f,
303.0f, 89.0f, 207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 3, 4}));
}
TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) {
BatchMatMulOpModel<float> model({TensorType_FLOAT32, {4, 5}},
{TensorType_FLOAT32, {3, 1, 5, 2}});
model.PopulateTensor<float>(
model.lhs(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
model.PopulateTensor<float>(
model.rhs(),
{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36});
model.Invoke();
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({145.0f, 370.0f, 595.0f, 820.0f, 220.0f, 570.0f,
920.0f, 1270.0f, 295.0f, 770.0f, 1245.0f, 1720.0f,
370.0f, 970.0f, 1570.0f, 2170.0f, 445.0f, 1170.0f,
1895.0f, 2620.0f, 520.0f, 1370.0f, 2220.0f, 3070.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2}));
}
} // namespace
} // namespace tflite

View File

@ -36,6 +36,7 @@ TfLiteRegistration* Register_ARG_MAX();
TfLiteRegistration* Register_ARG_MIN();
TfLiteRegistration* Register_AVERAGE_POOL_2D();
TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
TfLiteRegistration* Register_BATCH_MATMUL();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
TfLiteRegistration* Register_CAST();

View File

@ -213,6 +213,7 @@ cc_library(
name = "optimized_base",
srcs = [],
hdrs = [
"optimized/batch_matmul.h",
"optimized/depthwiseconv_3x3_filter_common.h",
"optimized/depthwiseconv_float.h",
"optimized/depthwiseconv_multithread.h",
@ -416,6 +417,7 @@ cc_library(
hdrs = [
"reference/add.h",
"reference/arg_min_max.h",
"reference/batch_matmul.h",
"reference/binary_function.h",
"reference/ceil.h",
"reference/comparisons.h",

View File

@ -0,0 +1,118 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace optimized_ops {
inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
const RuntimeShape& rhs_shape, const float* rhs_data,
const RuntimeShape& output_shape, float* output_data,
CpuBackendContext* context) {
using ::tflite::cpu_backend_gemm::Gemm;
using ::tflite::cpu_backend_gemm::GemmParams;
using ::tflite::cpu_backend_gemm::MatrixParams;
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(5, lhs_shape);
const RuntimeShape extended_rhs_shape =
RuntimeShape::ExtendedShape(5, rhs_shape);
// Determine which dimension is the broadcast dimension.
auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
if (lhs_dim == rhs_dim) return lhs_dim;
if (lhs_dim == 1) return rhs_dim;
TFLITE_DCHECK_EQ(rhs_dim, 1);
return lhs_dim;
};
// Compute the "extent" for iterating on this dimension.
// If we are broadcasting, then don't advance (i.e return 0).
auto extent = [](const RuntimeShape& shape, int x) {
if (shape.Dims(x) == 1) {
return 0;
}
int prod = 1;
for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
prod *= shape.Dims(i);
}
return prod;
};
const int batch_dim0 =
broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
const int batch_dim1 =
broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
const int batch_dim2 =
broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
const int lhs_ext0 = extent(extended_lhs_shape, 0);
const int lhs_ext1 = extent(extended_lhs_shape, 1);
const int lhs_ext2 = extent(extended_lhs_shape, 2);
const int rhs_ext0 = extent(extended_rhs_shape, 0);
const int rhs_ext1 = extent(extended_rhs_shape, 1);
const int rhs_ext2 = extent(extended_rhs_shape, 2);
// Set params for each matrix multiply.
const int lhs_rows = extended_lhs_shape.Dims(3);
const int rhs_cols = extended_rhs_shape.Dims(4);
const int accum_depth = extended_lhs_shape.Dims(4);
MatrixParams<float> lhs_params;
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
lhs_params.rows = lhs_rows;
lhs_params.cols = accum_depth;
MatrixParams<float> rhs_params;
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
rhs_params.rows = accum_depth;
rhs_params.cols = rhs_cols;
MatrixParams<float> dst_params;
dst_params.order = cpu_backend_gemm::Order::kColMajor;
dst_params.rows = lhs_rows;
dst_params.cols = rhs_cols;
for (int b0 = 0; b0 < batch_dim0; ++b0) {
const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
for (int b1 = 0; b1 < batch_dim1; ++b1) {
const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
for (int b2 = 0; b2 < batch_dim2; ++b2) {
const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;
GemmParams<float, float> gemm_params;
cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
dst_params, out_ptr, gemm_params, context);
}
}
}
}
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_

View File

@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/round.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
const RuntimeShape& rhs_shape, const float* rhs_data,
const RuntimeShape& output_shape, float* output_data) {
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(5, lhs_shape);
const RuntimeShape extended_rhs_shape =
RuntimeShape::ExtendedShape(5, rhs_shape);
// Determine which dimension is the broadcast dimension.
auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
if (lhs_dim == rhs_dim) return lhs_dim;
if (lhs_dim == 1) return rhs_dim;
TFLITE_DCHECK_EQ(rhs_dim, 1);
return lhs_dim;
};
// Compute the "extent" for iterating on this dimension.
// If we are broadcasting, then don't advance (i.e return 0).
auto extent = [](const RuntimeShape& shape, int x) {
if (shape.Dims(x) == 1) {
return 0;
}
int prod = 1;
for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
prod *= shape.Dims(i);
}
return prod;
};
const int batch_dim0 =
broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
const int batch_dim1 =
broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
const int batch_dim2 =
broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
const int lhs_ext0 = extent(extended_lhs_shape, 0);
const int lhs_ext1 = extent(extended_lhs_shape, 1);
const int lhs_ext2 = extent(extended_lhs_shape, 2);
const int rhs_ext0 = extent(extended_rhs_shape, 0);
const int rhs_ext1 = extent(extended_rhs_shape, 1);
const int rhs_ext2 = extent(extended_rhs_shape, 2);
// Set params for each matrix multiply.
const int lhs_rows = extended_lhs_shape.Dims(3);
const int rhs_cols = extended_rhs_shape.Dims(4);
const int accum_depth = extended_lhs_shape.Dims(4);
for (int b0 = 0; b0 < batch_dim0; ++b0) {
const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
for (int b1 = 0; b1 < batch_dim1; ++b1) {
const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
for (int b2 = 0; b2 < batch_dim2; ++b2) {
const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;
for (int j = 0; j < rhs_cols; ++j) {
for (int i = 0; i < lhs_rows; ++i) {
float total = 0.f;
for (int k = 0; k < accum_depth; ++k) {
total +=
lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
}
int idx = lhs_rows * j + i;
out_ptr[idx] = total;
}
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_

View File

@ -278,6 +278,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -134,6 +134,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF();
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_BATCH_MATMUL_REF();
namespace {

View File

@ -346,7 +346,8 @@ enum BuiltinOperator : byte {
SCATTER_ND = 122,
SELECT_V2 = 123,
DENSIFY = 124,
SEGMENT_SUM = 125
SEGMENT_SUM = 125,
BATCH_MATMUL = 126
}
@ -451,7 +452,8 @@ union BuiltinOptions {
ScatterNdOptions,
SelectV2Options,
DensifyOptions,
SegmentSumOptions
SegmentSumOptions,
BatchMatMulOptions
}
enum Padding : byte { SAME, VALID }
@ -945,6 +947,9 @@ table DensifyOptions {
table SegmentSumOptions {
}
table BatchMatMulOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -346,6 +346,9 @@ struct DensifyOptionsT;
struct SegmentSumOptions;
struct SegmentSumOptionsT;
struct BatchMatMulOptions;
struct BatchMatMulOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -771,11 +774,12 @@ enum BuiltinOperator {
BuiltinOperator_SELECT_V2 = 123,
BuiltinOperator_DENSIFY = 124,
BuiltinOperator_SEGMENT_SUM = 125,
BuiltinOperator_BATCH_MATMUL = 126,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM
BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -902,13 +906,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
BuiltinOperator_SCATTER_ND,
BuiltinOperator_SELECT_V2,
BuiltinOperator_DENSIFY,
BuiltinOperator_SEGMENT_SUM
BuiltinOperator_SEGMENT_SUM,
BuiltinOperator_BATCH_MATMUL
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[127] = {
static const char * const names[128] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@ -1035,13 +1040,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"SELECT_V2",
"DENSIFY",
"SEGMENT_SUM",
"BATCH_MATMUL",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_SEGMENT_SUM)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -1148,11 +1154,12 @@ enum BuiltinOptions {
BuiltinOptions_SelectV2Options = 98,
BuiltinOptions_DensifyOptions = 99,
BuiltinOptions_SegmentSumOptions = 100,
BuiltinOptions_BatchMatMulOptions = 101,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions
BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1254,13 +1261,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
BuiltinOptions_ScatterNdOptions,
BuiltinOptions_SelectV2Options,
BuiltinOptions_DensifyOptions,
BuiltinOptions_SegmentSumOptions
BuiltinOptions_SegmentSumOptions,
BuiltinOptions_BatchMatMulOptions
};
return values;
}
inline const char * const *EnumNamesBuiltinOptions() {
static const char * const names[102] = {
static const char * const names[103] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@ -1362,13 +1370,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"SelectV2Options",
"DensifyOptions",
"SegmentSumOptions",
"BatchMatMulOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_SegmentSumOptions)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -1777,6 +1786,10 @@ template<> struct BuiltinOptionsTraits<tflite::SegmentSumOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
};
template<> struct BuiltinOptionsTraits<tflite::BatchMatMulOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2609,6 +2622,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_SegmentSumOptions ?
reinterpret_cast<const tflite::SegmentSumOptionsT *>(value) : nullptr;
}
tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() {
return type == BuiltinOptions_BatchMatMulOptions ?
reinterpret_cast<tflite::BatchMatMulOptionsT *>(value) : nullptr;
}
const tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() const {
return type == BuiltinOptions_BatchMatMulOptions ?
reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -9109,6 +9130,46 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(
flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct BatchMatMulOptionsT : public flatbuffers::NativeTable {
typedef BatchMatMulOptions TableType;
BatchMatMulOptionsT() {
}
};
struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef BatchMatMulOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
BatchMatMulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<BatchMatMulOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct BatchMatMulOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit BatchMatMulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
BatchMatMulOptionsBuilder &operator=(const BatchMatMulOptionsBuilder &);
flatbuffers::Offset<BatchMatMulOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<BatchMatMulOptions>(end);
return o;
}
};
inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
BatchMatMulOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
tflite::BuiltinOperator builtin_code;
@ -9545,6 +9606,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast<const tflite::SegmentSumOptions *>(builtin_options()) : nullptr;
}
const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? static_cast<const tflite::BatchMatMulOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -9981,6 +10045,10 @@ template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as<
return builtin_options_as_SegmentSumOptions();
}
template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as<tflite::BatchMatMulOptions>() const {
return builtin_options_as_BatchMatMulOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -13392,6 +13460,29 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffer
_fbb);
}
inline BatchMatMulOptionsT *BatchMatMulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BatchMatMulOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void BatchMatMulOptions::UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<BatchMatMulOptions> BatchMatMulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateBatchMatMulOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchMatMulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateBatchMatMulOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -14197,6 +14288,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_BatchMatMulOptions: {
auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
@ -14615,6 +14710,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_BatchMatMulOptions: {
auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -15021,6 +15120,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const tflite::SegmentSumOptionsT *>(value);
return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_BatchMatMulOptions: {
auto ptr = reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value);
return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -15427,6 +15530,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new tflite::SegmentSumOptionsT(*reinterpret_cast<tflite::SegmentSumOptionsT *>(u.value));
break;
}
case BuiltinOptions_BatchMatMulOptions: {
value = new tflite::BatchMatMulOptionsT(*reinterpret_cast<tflite::BatchMatMulOptionsT *>(u.value));
break;
}
default:
break;
}
@ -15934,6 +16041,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_BatchMatMulOptions: {
auto ptr = reinterpret_cast<tflite::BatchMatMulOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -57,6 +57,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kDiv, 1}, "1.6.0"},
{{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},
{{OperatorType::kBatchToSpaceND, 2}, "1.14.0"},
{{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion},
{{OperatorType::kCast, 1}, "1.5.0"},
{{OperatorType::kConcatenation, 1}, "1.5.0"},
{{OperatorType::kConcatenation, 2}, "1.14.0"},