diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h
index c4e2907ffa9..85140289ac1 100644
--- a/tensorflow/lite/builtin_ops.h
+++ b/tensorflow/lite/builtin_ops.h
@@ -152,6 +152,7 @@ typedef enum {
   kTfLiteBuiltinSelectV2 = 123,
   kTfLiteBuiltinDensify = 124,
   kTfLiteBuiltinSegmentSum = 125,
+  kTfLiteBuiltinBatchMatmul = 126,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 6621e608d35..83b4159cce0 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -840,6 +840,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_SCATTER_ND:
     case BuiltinOperator_DENSIFY:
     case BuiltinOperator_SEGMENT_SUM:
+    case BuiltinOperator_BATCH_MATMUL:
       break;
   }
   return kTfLiteOk;
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 1f04cc3ee47..872d3c0822b 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -426,6 +426,7 @@ cc_library(
         "arg_min_max.cc",
         "audio_spectrogram.cc",
         "basic_rnn.cc",
+        "batch_matmul.cc",
         "batch_to_space_nd.cc",
         "bidirectional_sequence_lstm.cc",
         "bidirectional_sequence_rnn.cc",
@@ -849,6 +850,19 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "batch_matmul_test",
+    size = "small",
+    srcs = ["batch_matmul_test.cc"],
+    deps = [
+        ":builtin_ops",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite:framework",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "cast_test",
     size = "small",
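(With the cc_test target above, the new kernel test should be runnable in the usual way, e.g. `bazel test //tensorflow/lite/kernels:batch_matmul_test`.)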
diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc
new file mode 100644
index 00000000000..30bc624a218
--- /dev/null
+++ b/tensorflow/lite/kernels/batch_matmul.cc
@@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace batch_matmul {
+
+static const int kInputLHSTensor = 0;
+static const int kInputRHSTensor = 1;
+static const int kOutputTensor = 0;
+
+// This file has two implementations of BatchMatMul: a reference kernel and a
+// generic optimized kernel backed by cpu_backend_gemm.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+                                const RuntimeShape& extended_lhs_shape,
+                                const RuntimeShape& extended_rhs_shape,
+                                int output_rank, TfLiteTensor* output) {
+  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+  // Fill in any broadcast dimensions.
+  for (int i = 0; i < output_rank - 2; ++i) {
+    const int lhs_dim = extended_lhs_shape.Dims(i);
+    const int rhs_dim = extended_rhs_shape.Dims(i);
+    int broadcast_dim = lhs_dim;
+    if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
+      broadcast_dim = rhs_dim;
+    }
+    output_shape->data[i] = broadcast_dim;
+  }
+  // Fill in the matmul dimensions.
+  output_shape->data[output_rank - 2] =
+      extended_lhs_shape.Dims(output_rank - 2);
+  output_shape->data[output_rank - 1] =
+      extended_rhs_shape.Dims(output_rank - 1);
+  TfLiteStatus stat = context->ResizeTensor(context, output, output_shape);
+  return stat;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* lhs_data = GetInput(context, node, kInputLHSTensor);
+  const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  TF_LITE_ENSURE_EQ(context, lhs_data->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, rhs_data->type, kTfLiteFloat32);
+  // Support dimensions between 2 and 5, inclusive.
+  TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
+  TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
+  TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2);
+  TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5);
+
+  const int lhs_rank = NumDimensions(lhs_data);
+  const int rhs_rank = NumDimensions(rhs_data);
+  const int output_rank = std::max(lhs_rank, rhs_rank);
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data));
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data));
+
+  // Ensure any batch dimensions obey broadcasting rules.
+  for (int i = 0; i < output_rank - 2; ++i) {
+    const int lhs_dim = extended_lhs_shape.Dims(i);
+    const int rhs_dim = extended_rhs_shape.Dims(i);
+    if (lhs_dim != rhs_dim) {
+      if (lhs_dim != 1) {
+        TF_LITE_ENSURE_EQ(context, rhs_dim, 1);
+      }
+    }
+  }
+  // Ensure other dimensions work for matrix multiplication.
+  TF_LITE_ENSURE_EQ(context, extended_lhs_shape.Dims(output_rank - 1),
+                    extended_rhs_shape.Dims(output_rank - 2));
+  return ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape,
+                            output_rank, output);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* lhs = GetInput(context, node, kInputLHSTensor);
+  const TfLiteTensor* rhs = GetInput(context, node, kInputRHSTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  switch (lhs->type) {
+    case kTfLiteFloat32:
+      if (kernel_type == kGenericOptimized) {
+        optimized_ops::BatchMatMul(
+            GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
+            GetTensorData<float>(rhs), GetTensorShape(output),
+            GetTensorData<float>(output),
+            CpuBackendContext::GetFromContext(context));
+      } else {
+        reference_ops::BatchMatMul(
+            GetTensorShape(lhs), GetTensorData<float>(lhs), GetTensorShape(rhs),
+            GetTensorData<float>(rhs), GetTensorShape(output),
+            GetTensorData<float>(output));
+      }
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context,
+                         "Currently BatchMatMul doesn't support type: %s",
+                         TfLiteTypeGetName(lhs->type));
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace batch_matmul
+
+TfLiteRegistration* Register_BATCH_MATMUL_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, batch_matmul::Prepare,
+                                 batch_matmul::Eval<batch_matmul::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {
+  static TfLiteRegistration r = {
+      nullptr, nullptr, batch_matmul::Prepare,
+      batch_matmul::Eval<batch_matmul::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_BATCH_MATMUL() {
+  return Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
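(Usage sketch, not part of the patch: since Register_BATCH_MATMUL() resolves to the cpu_backend_gemm-backed kernel, a client that wants the reference path, e.g. to chase a numerical difference, could re-register the builtin. This assumes a visible declaration of Register_BATCH_MATMUL_REF(); the patch itself only declares it inside register_ref.cc.

  tflite::ops::builtin::BuiltinOpResolver resolver;
  // MutableOpResolver::AddBuiltin overwrites any earlier registration for the
  // same builtin, so this replaces the default optimized kernel.
  resolver.AddBuiltin(tflite::BuiltinOperator_BATCH_MATMUL,
                      tflite::ops::builtin::Register_BATCH_MATMUL_REF());
)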
diff --git a/tensorflow/lite/kernels/batch_matmul_test.cc b/tensorflow/lite/kernels/batch_matmul_test.cc
new file mode 100644
index 00000000000..9b33ebef542
--- /dev/null
+++ b/tensorflow/lite/kernels/batch_matmul_test.cc
@@ -0,0 +1,169 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class BatchMatMulOpModel : public SingleOpModel {
+ public:
+  BatchMatMulOpModel(const TensorData& lhs, const TensorData& rhs) {
+    lhs_id_ = AddInput(lhs);
+    rhs_id_ = AddInput(rhs);
+    output_id_ = AddOutput(lhs.type);
+    SetBuiltinOp(BuiltinOperator_BATCH_MATMUL, BuiltinOptions_NONE, 0);
+    BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
+  }
+
+  int lhs() const { return lhs_id_; }
+  int rhs() const { return rhs_id_; }
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
+  std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
+
+ protected:
+  int lhs_id_;
+  int rhs_id_;
+  int output_id_;
+};
+
+TEST(BatchMatMulOpModelTest, Float32Test_Simple) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}},
+                                  {TensorType_FLOAT32, {1, 3, 4}});
+  model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
+                                  {TensorType_FLOAT32, {2, 3, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f, 482.0f, 662.0f, 554.0f, 761.0f,
+                                626.0f, 860.0f, 698.0f, 959.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
+                                  {TensorType_FLOAT32, {3, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
+
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({50.0f, 122.0f, 68.0f, 167.0f, 86.0f, 212.0f,
+                                104.0f, 257.0f, 194.0f, 266.0f, 266.0f, 365.0f,
+                                338.0f, 464.0f, 410.0f, 563.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}},
+                                  {TensorType_FLOAT32, {3, 2, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {23.0f,  53.0f,  83.0f,  29.0f,  67.0f,  105.0f, 35.0f,  81.0f,
+           127.0f, 41.0f,  95.0f,  149.0f, 47.0f,  109.0f, 171.0f, 53.0f,
+           123.0f, 193.0f, 59.0f,  137.0f, 215.0f, 65.0f,  151.0f, 237.0f,
+           71.0f,  165.0f, 259.0f, 77.0f,  179.0f, 281.0f, 83.0f,  193.0f,
+           303.0f, 89.0f,  207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
+           181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
+           233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
+           449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
+           485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFiveD) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 1, 3, 2}},
+                                  {TensorType_FLOAT32, {3, 2, 4}});
+  model.PopulateTensor<float>(model.lhs(),
+                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  model.PopulateTensor<float>(model.rhs(),
+                              {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+                               19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray(
+          {23.0f,  53.0f,  83.0f,  29.0f,  67.0f,  105.0f, 35.0f,  81.0f,
+           127.0f, 41.0f,  95.0f,  149.0f, 47.0f,  109.0f, 171.0f, 53.0f,
+           123.0f, 193.0f, 59.0f,  137.0f, 215.0f, 65.0f,  151.0f, 237.0f,
+           71.0f,  165.0f, 259.0f, 77.0f,  179.0f, 281.0f, 83.0f,  193.0f,
+           303.0f, 89.0f,  207.0f, 325.0f, 113.0f, 143.0f, 173.0f, 143.0f,
+           181.0f, 219.0f, 173.0f, 219.0f, 265.0f, 203.0f, 257.0f, 311.0f,
+           233.0f, 295.0f, 357.0f, 263.0f, 333.0f, 403.0f, 293.0f, 371.0f,
+           449.0f, 323.0f, 409.0f, 495.0f, 353.0f, 447.0f, 541.0f, 383.0f,
+           485.0f, 587.0f, 413.0f, 523.0f, 633.0f, 443.0f, 561.0f, 679.0f}));
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 3, 4}));
+}
+
+TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) {
+  BatchMatMulOpModel<float> model({TensorType_FLOAT32, {4, 5}},
+                                  {TensorType_FLOAT32, {3, 1, 5, 2}});
+  model.PopulateTensor<float>(
+      model.lhs(),
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+  model.PopulateTensor<float>(
+      model.rhs(),
+      {7,  8,  9,  10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+       22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36});
+
+  model.Invoke();
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray({145.0f, 370.0f, 595.0f, 820.0f, 220.0f, 570.0f,
+                        920.0f, 1270.0f, 295.0f, 770.0f, 1245.0f, 1720.0f,
+                        370.0f, 970.0f, 1570.0f, 2170.0f, 445.0f, 1170.0f,
+                        1895.0f, 2620.0f, 520.0f, 1370.0f, 2220.0f, 3070.0f}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2}));
+}
+
+}  // namespace
+}  // namespace tflite
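(Note on the expected arrays above: they encode the kernels' storage conventions rather than plain row-major math. The RHS is read as if column-major and the output is written column-major; see rhs_params/dst_params in the optimized header added below, and the index math in the reference kernel. A worked check for Float32Test_Simple, where the LHS is 2x3 and the accumulation depth is 3:

  // out_ptr[lhs_rows * j + i] = sum_k lhs[i * depth + k] * rhs[j * depth + k]
  // j = 0, i = 0:  1*7  + 2*8  + 3*9  =  50   -> GetOutput()[0]
  // j = 0, i = 1:  4*7  + 5*8  + 6*9  = 122   -> GetOutput()[1]
  // j = 1, i = 0:  1*10 + 2*11 + 3*12 =  68   -> GetOutput()[2]
)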
diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h
index e5f00ddd229..1c73f06487b 100644
--- a/tensorflow/lite/kernels/builtin_op_kernels.h
+++ b/tensorflow/lite/kernels/builtin_op_kernels.h
@@ -36,6 +36,7 @@ TfLiteRegistration* Register_ARG_MAX();
 TfLiteRegistration* Register_ARG_MIN();
 TfLiteRegistration* Register_AVERAGE_POOL_2D();
 TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
+TfLiteRegistration* Register_BATCH_MATMUL();
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
 TfLiteRegistration* Register_CAST();
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index c9e6c082b53..e7612e39c71 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -213,6 +213,7 @@ cc_library(
     name = "optimized_base",
     srcs = [],
     hdrs = [
+        "optimized/batch_matmul.h",
         "optimized/depthwiseconv_3x3_filter_common.h",
         "optimized/depthwiseconv_float.h",
         "optimized/depthwiseconv_multithread.h",
@@ -416,6 +417,7 @@ cc_library(
     hdrs = [
         "reference/add.h",
         "reference/arg_min_max.h",
+        "reference/batch_matmul.h",
        "reference/binary_function.h",
         "reference/ceil.h",
         "reference/comparisons.h",
diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h
new file mode 100644
index 00000000000..03cef848026
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h
@@ -0,0 +1,118 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
+                        const RuntimeShape& rhs_shape, const float* rhs_data,
+                        const RuntimeShape& output_shape, float* output_data,
+                        CpuBackendContext* context) {
+  using ::tflite::cpu_backend_gemm::Gemm;
+  using ::tflite::cpu_backend_gemm::GemmParams;
+  using ::tflite::cpu_backend_gemm::MatrixParams;
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  MatrixParams<float> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = lhs_rows;
+  lhs_params.cols = accum_depth;
+
+  MatrixParams<float> rhs_params;
+  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.rows = accum_depth;
+  rhs_params.cols = rhs_cols;
+
+  MatrixParams<float> dst_params;
+  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.rows = lhs_rows;
+  dst_params.cols = rhs_cols;
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                        b1 * batch_dim2 + b2) *
+                                           lhs_rows * rhs_cols;
+        GemmParams<float, float> gemm_params;
+        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
+                               dst_params, out_ptr, gemm_params, context);
+      }
+    }
+  }
+}
+
+}  // namespace optimized_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
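(Usage sketch, not part of the patch: calling the optimized kernel directly on small host buffers, with shapes mirroring Float32Test_Simple. It assumes a CpuBackendContext instance is available, e.g. a default-constructed one in a standalone benchmark, and that cpu_backend_context.h plus the header above are included.

  void ExampleBatchMatMul(tflite::CpuBackendContext* context) {
    const float lhs[6] = {1, 2, 3, 4, 5, 6};                              // {1, 2, 3}
    const float rhs[12] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};  // {1, 3, 4}
    float out[8];  // {1, 2, 4}, written column-major per dst_params above
    tflite::optimized_ops::BatchMatMul(
        tflite::RuntimeShape({1, 2, 3}), lhs, tflite::RuntimeShape({1, 3, 4}),
        rhs, tflite::RuntimeShape({1, 2, 4}), out, context);
    // out now holds {50, 122, 68, 167, 86, 212, 104, 257}.
  }
)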
diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h
new file mode 100644
index 00000000000..4fe84aa3388
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
+                        const RuntimeShape& rhs_shape, const float* rhs_data,
+                        const RuntimeShape& output_shape, float* output_data) {
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e. return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                        b1 * batch_dim2 + b2) *
+                                           lhs_rows * rhs_cols;
+        for (int j = 0; j < rhs_cols; ++j) {
+          for (int i = 0; i < lhs_rows; ++i) {
+            float total = 0.f;
+            for (int k = 0; k < accum_depth; ++k) {
+              total +=
+                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
+            }
+            int idx = lhs_rows * j + i;
+            out_ptr[idx] = total;
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index cf9f8b99ee4..1e148a0c1f5 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -278,6 +278,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
   AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
+  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL());
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
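(To make the extent computation concrete, consider the shapes from Float32Test_Broadcast2 above, extended to rank 5 as both kernels do:

  // lhs {2, 1, 3, 2} -> extended [1, 2, 1, 3, 2]; rhs {3, 2, 4} -> [1, 1, 3, 2, 4]
  // batch dims:  broadcast_dim per axis -> 1, 2, 3  (6 matrix multiplies)
  // lhs extents: {0, 6, 0}  (axis 1 is real, so stride = 3 * 2 = 6 floats;
  //                          axes 0 and 2 are broadcast, so stride = 0)
  // rhs extents: {0, 0, 8}  (axis 2 is real, so stride = 2 * 4 = 8 floats)
  // Each b1 step advances only the LHS to its next 3x2 matrix, and each b2
  // step advances only the RHS to its next 2x4 matrix, producing the
  // {2, 3, 3, 4} output the test expects.
)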
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 2381e8f8c9d..426f8a8e896 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -134,6 +134,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF();
 TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
 TfLiteRegistration* Register_SELECT_V2();
 TfLiteRegistration* Register_SEGMENT_SUM();
+TfLiteRegistration* Register_BATCH_MATMUL_REF();
 
 namespace {
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 5c12d74c067..24cd73eef7a 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -346,7 +346,8 @@ enum BuiltinOperator : byte {
   SCATTER_ND = 122,
   SELECT_V2 = 123,
   DENSIFY = 124,
-  SEGMENT_SUM = 125
+  SEGMENT_SUM = 125,
+  BATCH_MATMUL = 126
 }
 
@@ -451,7 +452,8 @@ union BuiltinOptions {
   ScatterNdOptions,
   SelectV2Options,
   DensifyOptions,
-  SegmentSumOptions
+  SegmentSumOptions,
+  BatchMatMulOptions
 }
 
 enum Padding : byte { SAME, VALID }
@@ -945,6 +947,9 @@ table DensifyOptions {
 table SegmentSumOptions {
 }
 
+table BatchMatMulOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index 8caf2409b96..609eac198fb 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -346,6 +346,9 @@ struct DensifyOptionsT;
 struct SegmentSumOptions;
 struct SegmentSumOptionsT;
 
+struct BatchMatMulOptions;
+struct BatchMatMulOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
 
@@ -771,11 +774,12 @@ enum BuiltinOperator {
   BuiltinOperator_SELECT_V2 = 123,
   BuiltinOperator_DENSIFY = 124,
   BuiltinOperator_SEGMENT_SUM = 125,
+  BuiltinOperator_BATCH_MATMUL = 126,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM
+  BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL
 };
 
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] {
   static const BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -902,13 +906,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
     BuiltinOperator_SCATTER_ND,
     BuiltinOperator_SELECT_V2,
     BuiltinOperator_DENSIFY,
-    BuiltinOperator_SEGMENT_SUM
+    BuiltinOperator_SEGMENT_SUM,
+    BuiltinOperator_BATCH_MATMUL
   };
   return values;
 }
 
 inline const char * const *EnumNamesBuiltinOperator() {
-  static const char * const names[127] = {
+  static const char * const names[128] = {
     "ADD",
     "AVERAGE_POOL_2D",
     "CONCATENATION",
@@ -1035,13 +1040,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
     "SELECT_V2",
     "DENSIFY",
     "SEGMENT_SUM",
+    "BATCH_MATMUL",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
-  if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_SEGMENT_SUM)) return "";
+  if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOperator()[index];
 }
@@ -1148,11 +1154,12 @@ enum BuiltinOptions {
   BuiltinOptions_SelectV2Options = 98,
   BuiltinOptions_DensifyOptions = 99,
   BuiltinOptions_SegmentSumOptions = 100,
+  BuiltinOptions_BatchMatMulOptions = 101,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions
+  BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions
 };
 
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] {
   static const BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
@@ -1254,13 +1261,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
     BuiltinOptions_ScatterNdOptions,
     BuiltinOptions_SelectV2Options,
     BuiltinOptions_DensifyOptions,
-    BuiltinOptions_SegmentSumOptions
+    BuiltinOptions_SegmentSumOptions,
+    BuiltinOptions_BatchMatMulOptions
   };
   return values;
 }
 
 inline const char * const *EnumNamesBuiltinOptions() {
-  static const char * const names[102] = {
+  static const char * const names[103] = {
     "NONE",
     "Conv2DOptions",
     "DepthwiseConv2DOptions",
@@ -1362,13 +1370,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
     "SelectV2Options",
     "DensifyOptions",
     "SegmentSumOptions",
+    "BatchMatMulOptions",
     nullptr
   };
   return names;
 }
 
 inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
-  if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_SegmentSumOptions)) return "";
+  if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return "";
   const size_t index = static_cast<size_t>(e);
   return EnumNamesBuiltinOptions()[index];
 }
@@ -1777,6 +1786,10 @@ template<> struct BuiltinOptionsTraits<tflite::SegmentSumOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
 };
 
+template<> struct BuiltinOptionsTraits<tflite::BatchMatMulOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -2609,6 +2622,14 @@ struct BuiltinOptionsUnion {
     return type == BuiltinOptions_SegmentSumOptions ?
       reinterpret_cast<tflite::SegmentSumOptionsT *>(value) : nullptr;
   }
+  tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() {
+    return type == BuiltinOptions_BatchMatMulOptions ?
+      reinterpret_cast<tflite::BatchMatMulOptionsT *>(value) : nullptr;
+  }
+  const tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() const {
+    return type == BuiltinOptions_BatchMatMulOptions ?
+      reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value) : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -9109,6 +9130,46 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(
 
 flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct BatchMatMulOptionsT : public flatbuffers::NativeTable {
+  typedef BatchMatMulOptions TableType;
+  BatchMatMulOptionsT() {
+  }
+};
+
+struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef BatchMatMulOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  BatchMatMulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<BatchMatMulOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BatchMatMulOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit BatchMatMulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  BatchMatMulOptionsBuilder &operator=(const BatchMatMulOptionsBuilder &);
+  flatbuffers::Offset<BatchMatMulOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<BatchMatMulOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  BatchMatMulOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   tflite::BuiltinOperator builtin_code;
@@ -9545,6 +9606,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const {
     return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast<const tflite::SegmentSumOptions *>(builtin_options()) : nullptr;
   }
+  const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const {
+    return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ?
+        static_cast<const tflite::BatchMatMulOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -9981,6 +10045,10 @@ template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as<
   return builtin_options_as_SegmentSumOptions();
 }
 
+template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as<tflite::BatchMatMulOptions>() const {
+  return builtin_options_as_BatchMatMulOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -13392,6 +13460,29 @@ inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffer
       _fbb);
 }
 
+inline BatchMatMulOptionsT *BatchMatMulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new BatchMatMulOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void BatchMatMulOptions::UnPackTo(BatchMatMulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<BatchMatMulOptions> BatchMatMulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateBatchMatMulOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchMatMulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateBatchMatMulOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -14197,6 +14288,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return true;
   }
 }
@@ -14615,6 +14710,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
       auto ptr = reinterpret_cast<const tflite::SegmentSumOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -15021,6 +15120,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
      auto ptr = reinterpret_cast<const tflite::SegmentSumOptionsT *>(value);
       return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value);
+      return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -15427,6 +15530,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
      value = new tflite::SegmentSumOptionsT(*reinterpret_cast<tflite::SegmentSumOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      value = new tflite::BatchMatMulOptionsT(*reinterpret_cast<tflite::BatchMatMulOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -15934,6 +16041,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_BatchMatMulOptions: {
+      auto ptr = reinterpret_cast<tflite::BatchMatMulOptionsT *>(value);
+      delete ptr;
+      break;
+    }
    default: break;
   }
   value = nullptr;
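(For orientation, a reader of the generated API above might access the new options table like this; the snippet assumes a loaded tflite::Model* named model with at least one operator, and the variable names are illustrative:

  const tflite::SubGraph* subgraph = model->subgraphs()->Get(0);
  const tflite::Operator* op = subgraph->operators()->Get(0);
  if (const tflite::BatchMatMulOptions* options =
          op->builtin_options_as_BatchMatMulOptions()) {
    // BatchMatMulOptions is an empty table for now; a non-null pointer just
    // confirms the operator carries this options type.
    (void)options;
  }
)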
diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc
index b375706f6c7..bbec4f91646 100644
--- a/tensorflow/lite/toco/tflite/op_version.cc
+++ b/tensorflow/lite/toco/tflite/op_version.cc
@@ -57,6 +57,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
           {{OperatorType::kDiv, 1}, "1.6.0"},
           {{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},
           {{OperatorType::kBatchToSpaceND, 2}, "1.14.0"},
+          {{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion},
           {{OperatorType::kCast, 1}, "1.5.0"},
           {{OperatorType::kConcatenation, 1}, "1.5.0"},
           {{OperatorType::kConcatenation, 2}, "1.14.0"},