From bc8bb3ba84087a86931f87226663eddb9fda7faf Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Mon, 20 Jul 2020 20:34:46 -0700 Subject: [PATCH] Add builtin BroadcastTo Op to TFLite Converter support will be added in a follow-up CL. PiperOrigin-RevId: 322281991 Change-Id: I9a96d0dff3a089a9b43b85c955cc416717e26aa9 --- RELEASE.md | 3 +- tensorflow/lite/builtin_ops.h | 1 + tensorflow/lite/c/common.c | 23 ++ tensorflow/lite/c/common.h | 3 + .../lite/core/api/flatbuffer_conversions.cc | 1 + tensorflow/lite/kernels/BUILD | 14 + tensorflow/lite/kernels/broadcast_to.cc | 136 ++++++++++ tensorflow/lite/kernels/broadcast_to_test.cc | 255 ++++++++++++++++++ tensorflow/lite/kernels/builtin_op_kernels.h | 1 + tensorflow/lite/kernels/internal/BUILD | 1 + tensorflow/lite/kernels/internal/common.h | 7 + .../kernels/internal/reference/broadcast_to.h | 90 +++++++ tensorflow/lite/kernels/register.cc | 1 + tensorflow/lite/kernels/register_ref.cc | 2 + tensorflow/lite/schema/schema.fbs | 9 +- tensorflow/lite/schema/schema_generated.h | 134 ++++++++- tensorflow/lite/toco/model.h | 1 + tensorflow/lite/toco/tflite/op_version.cc | 1 + .../benchmark/experimental/c/c_api_types.h | 3 + .../lite/tools/versioning/runtime_version.cc | 1 + .../tools/versioning/runtime_version_test.cc | 2 +- 21 files changed, 674 insertions(+), 15 deletions(-) create mode 100644 tensorflow/lite/kernels/broadcast_to.cc create mode 100644 tensorflow/lite/kernels/broadcast_to_test.cc create mode 100644 tensorflow/lite/kernels/internal/reference/broadcast_to.h diff --git a/RELEASE.md b/RELEASE.md index c4fa615cf4d..12b5168954b 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -54,7 +54,8 @@ * `tf.function`/AutoGraph: * * `tf.lite`: - * + * Better support for ops with high-dimensional broadcasting inputs by adding + `BroadcastTo` ops when necessary. * `tf.random`: * * Math and Linear Algebra: diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 85140289ac1..c6440729738 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -153,6 +153,7 @@ typedef enum { kTfLiteBuiltinDensify = 124, kTfLiteBuiltinSegmentSum = 125, kTfLiteBuiltinBatchMatmul = 126, + kTfLiteBuiltinBroadcastTo = 127, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index 0264f420b12..4bbbcbbf03c 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -219,6 +219,29 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "Unknown type"; } +// Size of string is not constant, return 0 in such case. +int TfLiteTypeGetSize(TfLiteType type) { + switch (type) { + case kTfLiteUInt8: + case kTfLiteInt8: + return 1; + case kTfLiteBool: + return sizeof(bool); + case kTfLiteInt16: + case kTfLiteFloat16: + return 2; + case kTfLiteFloat32: + case kTfLiteInt32: + return 4; + case kTfLiteInt64: + case kTfLiteComplex64: + case kTfLiteFloat64: + return 8; + default: + return 0; + } +} + TfLiteDelegate TfLiteDelegateCreate() { TfLiteDelegate d = { .data_ = NULL, diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 89b25892914..692a8eaf7a2 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -268,6 +268,9 @@ typedef enum { // Return the name of a given type, for error reporting purposes. const char* TfLiteTypeGetName(TfLiteType type); +// Return the size of given type in bytes. Return 0 in in case of string. +int TfLiteTypeGetSize(TfLiteType type); + // SupportedQuantizationTypes. typedef enum TfLiteQuantizationType { // No quantization. diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 0652c64f6c2..059ad97f551 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -820,6 +820,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SCATTER_ND: case BuiltinOperator_DENSIFY: case BuiltinOperator_SEGMENT_SUM: + case BuiltinOperator_BROADCAST_TO: return kTfLiteOk; } return kTfLiteError; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 4351a2c93a2..3157081dd21 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -491,6 +491,7 @@ BUILTIN_KERNEL_SRCS = [ "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", + "broadcast_to.cc", "cast.cc", "ceil.cc", "comparisons.cc", @@ -984,6 +985,19 @@ cc_test( ], ) +cc_test( + name = "broadcast_to_test", + size = "small", + srcs = ["broadcast_to_test.cc"], + deps = [ + ":builtin_ops", + ":test_main", + ":test_util", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "cast_test", size = "small", diff --git a/tensorflow/lite/kernels/broadcast_to.cc b/tensorflow/lite/kernels/broadcast_to.cc new file mode 100644 index 00000000000..0e7baca2277 --- /dev/null +++ b/tensorflow/lite/kernels/broadcast_to.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h" + +#include + +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace broadcastto { + +constexpr int kInputTensor = 0; +constexpr int kShapeTensor = 1; +constexpr int kOutputTensor = 0; +constexpr int kMaxDims = 8; + +struct BroadcastToContext { + BroadcastToContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, kInputTensor); + shape = GetInput(context, node, kShapeTensor); + output = GetOutput(context, node, kOutputTensor); + } + const TfLiteTensor* input; + const TfLiteTensor* shape; + TfLiteTensor* output; +}; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BroadcastToContext* op_context) { + // Ensures the shape is 1D tensor. + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1); + + // Ensure output dims is not less than input dims. + int input_num_dims = NumDimensions(op_context->input); + int output_num_dims = SizeOfDimension(op_context->shape, 0); + TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims, + "Output shape must be broadcastable from input shape."); + TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims, + "BroadcastTo only supports 1-8D tensor."); + + // Check if output shape is broadcastable from input shape. + auto get_shape_data = [op_context](int i) -> int32_t { + if (op_context->shape->type == kTfLiteInt32) { + return GetTensorData(op_context->shape)[i]; + } else { + return GetTensorData(op_context->shape)[i]; + } + }; + + int extending_dims = output_num_dims - input_num_dims; + for (int idx = 0; idx < input_num_dims; ++idx) { + TF_LITE_ENSURE_MSG(context, + (SizeOfDimension(op_context->input, idx) == 1 || + SizeOfDimension(op_context->input, idx) == + get_shape_data(extending_dims + idx)), + "Output shape must be broadcastable from input shape."); + } + // Resizing the shape of the output tensor. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims); + std::unique_ptr + scoped_output_shape(output_shape, TfLiteIntArrayFree); + for (int idx = 0; idx < output_num_dims; ++idx) { + output_shape->data[idx] = get_shape_data(idx); + } + + return context->ResizeTensor(context, op_context->output, + scoped_output_shape.release()); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TF_LITE_ENSURE_MSG(context, + (NumDimensions(GetInput(context, node, 0)) <= kMaxDims), + "BroadcastTo only supports 1-8D tensor."); + + BroadcastToContext op_context(context, node); + TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 || + op_context.shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + // Not yet support string type due to the use of memcopy with fixed size. + TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString); + + if (IsConstantTensor(op_context.shape)) { + return ResizeOutputTensor(context, &op_context); + } + + SetTensorToDynamic(op_context.output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + BroadcastToContext op_context(context, node); + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + + // BroadcastTo op support upto 8 dims, matching the support of Tensorflow. + reference_ops::BroadcastTo( + GetTensorShape(op_context.input), op_context.input->data.raw, + GetTensorShape(op_context.output), op_context.output->data.raw, + op_context.input->type); + return kTfLiteOk; +} + +} // namespace broadcastto + +TfLiteRegistration* Register_BROADCAST_TO() { + static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare, + broadcastto::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/broadcast_to_test.cc b/tensorflow/lite/kernels/broadcast_to_test.cc new file mode 100644 index 00000000000..a36ed352055 --- /dev/null +++ b/tensorflow/lite/kernels/broadcast_to_test.cc @@ -0,0 +1,255 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { +using ::testing::ElementsAreArray; + +template +class BroadcastToOpModel : public SingleOpModel { + public: + // BroadcastTo with dynamic shape. + BroadcastToOpModel(std::initializer_list input_shape, + std::initializer_list shape_shape) { + input_ = AddInput({GetTensorType(), input_shape}); + shape_ = AddInput({GetTensorType(), shape_shape}); + output_ = AddOutput(GetTensorType()); + SetBuiltinOp(BuiltinOperator_BROADCAST_TO, + BuiltinOptions_BroadcastToOptions, + CreateBroadcastToOptions(builder_).Union()); + BuildInterpreter({input_shape, shape_shape}); + } + + // BroadcastTo with const shape. + BroadcastToOpModel(std::initializer_list input_shape, + std::initializer_list shape_shape, + std::initializer_list shape_values) { + input_ = AddInput({GetTensorType(), input_shape}); + shape_ = + AddConstInput(GetTensorType(), shape_values, shape_shape); + output_ = AddOutput(GetTensorType()); + SetBuiltinOp(BuiltinOperator_BROADCAST_TO, + BuiltinOptions_BroadcastToOptions, + CreateBroadcastToOptions(builder_).Union()); + BuildInterpreter({input_shape, shape_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetShape(std::initializer_list data) { + PopulateTensor(shape_, data); + } + + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int shape_; + int output_; +}; + +template +class BroadcastToOpTest : public ::testing::Test {}; + +using DataTypes = ::testing::Types; +TYPED_TEST_SUITE(BroadcastToOpTest, DataTypes); + +#ifdef GTEST_HAS_DEATH_TEST +TYPED_TEST(BroadcastToOpTest, ShapeMustBe1D) { + EXPECT_DEATH( + BroadcastToOpModel({2, 3, 4, 4}, {2, 2}, {2, 3, 4, 4}), ""); + // Non-constant Shape tensor. + BroadcastToOpModel m({2, 3, 4, 4}, {2, 2}); + m.SetShape({2, 3, 4, 4}); + EXPECT_THAT(m.InvokeUnchecked(), kTfLiteError); +} + +TYPED_TEST(BroadcastToOpTest, TooManyDimensions) { + EXPECT_DEATH(BroadcastToOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}, + {2, 2, 3, 4, 5, 6, 7, 8, 9}), + "BroadcastTo only supports 1-8D tensor."); + EXPECT_DEATH(BroadcastToOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}), + "BroadcastTo only supports 1-8D tensor."); +} + +TYPED_TEST(BroadcastToOpTest, MismatchDimension) { + EXPECT_DEATH(BroadcastToOpModel({2, 4, 1, 2}, {4}, {2, 4, 1, 3}), + "Output shape must be broadcastable from input shape."); + EXPECT_DEATH( + BroadcastToOpModel({2, 4, 1, 2, 3}, {4}, {2, 4, 1, 2}), + "Output shape must be broadcastable from input shape."); + + // Non-constant Shape tensor. + BroadcastToOpModel m1({2, 4, 1, 2}, {4}); + m1.SetShape({2, 3, 4, 4}); + EXPECT_THAT(m1.InvokeUnchecked(), kTfLiteError); + BroadcastToOpModel m2({2, 4, 1, 2}, {5}); + m2.SetShape({1, 2, 3, 4, 4}); + EXPECT_THAT(m2.InvokeUnchecked(), kTfLiteError); +} +#endif + +TYPED_TEST(BroadcastToOpTest, BroadcastTo1DConstTest) { + BroadcastToOpModel m({1}, {1}, {4}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo4DConstTest) { + BroadcastToOpModel m({1, 1, 1, 2}, {4}, {1, 1, 2, 2}); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo8DConstTest) { + BroadcastToOpModel m({1, 1, 1, 1, 1, 1, 2, 1}, {8}, + {1, 1, 1, 1, 1, 1, 2, 2}); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo1DDynamicTest) { + BroadcastToOpModel m({1}, {1}); + m.SetInput({3}); + m.SetShape({4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo4DDynamicTest) { + BroadcastToOpModel m({1, 1, 1, 2}, {4}); + m.SetInput({3, 4}); + m.SetShape({1, 1, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo8DDynamicTest) { + BroadcastToOpModel m({1, 1, 1, 1, 1, 1, 2, 1}, {8}); + m.SetInput({3, 4}); + m.SetShape({1, 1, 1, 1, 1, 1, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DConstTest) { + BroadcastToOpModel m({1, 3, 1, 2}, {4}, {3, 3, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, + 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DDynamicTest) { + BroadcastToOpModel m({1, 3, 1, 2}, {4}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetShape({3, 3, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, + 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DConstTest) { + BroadcastToOpModel m({1, 2, 1, 3, 1, 2}, {6}, {2, 2, 1, 3, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12, + 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) { + BroadcastToOpModel m({1, 2, 1, 3, 1, 2}, {6}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetShape({2, 2, 1, 3, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12, + 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12})); +} + +TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) { + BroadcastToOpModel m({3, 1, 2}, {4}, {3, 3, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, + 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) { + BroadcastToOpModel m({3, 1, 2}, {3}, {3, 1, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) { + BroadcastToOpModel m({1, 1, 1, 1, 1, 1, 2, 1}, {8}, + {1, 1, 1, 1, 1, 1, 2, 2}); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +TYPED_TEST(BroadcastToOpTest, Int64ShapeDDynamicTest) { + BroadcastToOpModel m({1, 1, 1, 1, 1, 1, 2, 1}, {8}); + m.SetInput({3, 4}); + m.SetShape({1, 1, 1, 1, 1, 1, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index 1c73f06487b..fea25f8605c 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -39,6 +39,7 @@ TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_BATCH_MATMUL(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); +TfLiteRegistration* Register_BROADCAST_TO(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_CEIL(); TfLiteRegistration* Register_CONCATENATION(); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 5acabeb45cd..075c1da9865 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -441,6 +441,7 @@ cc_library( "reference/arg_min_max.h", "reference/batch_matmul.h", "reference/binary_function.h", + "reference/broadcast_to.h", "reference/ceil.h", "reference/comparisons.h", "reference/concatenation.h", diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index c45aff9e47b..10cb164e696 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -665,6 +665,13 @@ inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) { indexes[4] * desc.strides[4]; } +inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) { + return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] + + indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] + + indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] + + indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7]; +} + // Given the dimensions of the operands for an element-wise binary broadcast, // adjusts them so that they can be directly iterated over with simple loops. // Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and diff --git a/tensorflow/lite/kernels/internal/reference/broadcast_to.h b/tensorflow/lite/kernels/internal/reference/broadcast_to.h new file mode 100644 index 00000000000..69f4531ba14 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/broadcast_to.h @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" + +namespace tflite { +namespace reference_ops { +template +void BroadcastImpl(const NdArrayDesc& input_desc, const char* input_data, + const NdArrayDesc& output_desc, char* output_data, + int indexes[N], int dim, const int last_broadcasting_dim, + const int type_size) { + // Copy data from input to output. + if (dim == last_broadcasting_dim) { + int copy_size = output_desc.strides[dim] * type_size; + const char* data_src = + input_data + SubscriptToIndex(input_desc, indexes) * type_size; + char* data_dst = + output_data + SubscriptToIndex(output_desc, indexes) * type_size; + for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) { + memcpy(data_dst, data_src, copy_size); + } + return; + } + + // Recursive call to find the next broadcasting. + for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim]; + ++indexes[dim]) { + BroadcastImpl(input_desc, input_data, output_desc, output_data, indexes, + dim + 1, last_broadcasting_dim, type_size); + } + + // Duplicate data in output tensor. + indexes[dim] = 0; + if (input_desc.extents[dim] != output_desc.extents[dim]) { + int copy_size = output_desc.strides[dim] * type_size; + char* data_src = + output_data + SubscriptToIndex(output_desc, indexes) * type_size; + char* data_dst = data_src + copy_size; + for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) { + memcpy(data_dst, data_src, copy_size); + } + } +} + +template +inline void BroadcastTo(const RuntimeShape& unextended_input_shape, + const char* input_data, + const RuntimeShape& unextended_output_shape, + char* output_data, TfLiteType data_type) { + NdArrayDesc input_desc; + NdArrayDesc output_desc; + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape), + &input_desc); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape), + &output_desc); + + // Get the last dimension has broadcasting. At this dimension, the data is + // copied from input tensor to output tensor. + int last_broadcast_dim = 0; + for (int i = N - 1; i > 0; --i) { + if (input_desc.extents[i] != output_desc.extents[i]) { + last_broadcast_dim = i; + break; + } + } + + // Broadcasting using memcpy. + int indexes[N] = {0}; + BroadcastImpl(input_desc, input_data, output_desc, output_data, indexes, 0, + last_broadcast_dim, TfLiteTypeGetSize(data_type)); +} +} // namespace reference_ops +} // namespace tflite +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_ diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 275340ec225..da7480d448d 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -292,6 +292,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(), /* min_version = */ 1, /* max_version = */ 2); + AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 233520e2165..b1dd31ab168 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -139,6 +139,7 @@ TfLiteRegistration* Register_DEPTH_TO_SPACE_REF(); TfLiteRegistration* Register_SELECT_V2(); TfLiteRegistration* Register_SEGMENT_SUM(); TfLiteRegistration* Register_BATCH_MATMUL_REF(); +TfLiteRegistration* Register_BROADCAST_TO(); namespace { @@ -207,6 +208,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { Register_SPACE_TO_BATCH_ND_REF()); AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND_REF()); + AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO()); AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF()); AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 878acde1e16..949d769b457 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -349,7 +349,8 @@ enum BuiltinOperator : byte { SELECT_V2 = 123, DENSIFY = 124, SEGMENT_SUM = 125, - BATCH_MATMUL = 126 + BATCH_MATMUL = 126, + BROADCAST_TO = 127 } @@ -455,7 +456,8 @@ union BuiltinOptions { SelectV2Options, DensifyOptions, SegmentSumOptions, - BatchMatMulOptions + BatchMatMulOptions, + BroadcastToOptions } enum Padding : byte { SAME, VALID } @@ -975,6 +977,9 @@ table BatchMatMulOptions { adj_y:bool; } +table BroadcastToOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index a6117dc72ab..7bf79e52e27 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -349,6 +349,9 @@ struct SegmentSumOptionsT; struct BatchMatMulOptions; struct BatchMatMulOptionsT; +struct BroadcastToOptions; +struct BroadcastToOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -781,11 +784,12 @@ enum BuiltinOperator { BuiltinOperator_DENSIFY = 124, BuiltinOperator_SEGMENT_SUM = 125, BuiltinOperator_BATCH_MATMUL = 126, + BuiltinOperator_BROADCAST_TO = 127, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL + BuiltinOperator_MAX = BuiltinOperator_BROADCAST_TO }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[128] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -913,13 +917,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] { BuiltinOperator_SELECT_V2, BuiltinOperator_DENSIFY, BuiltinOperator_SEGMENT_SUM, - BuiltinOperator_BATCH_MATMUL + BuiltinOperator_BATCH_MATMUL, + BuiltinOperator_BROADCAST_TO }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[128] = { + static const char * const names[129] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1047,13 +1052,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "DENSIFY", "SEGMENT_SUM", "BATCH_MATMUL", + "BROADCAST_TO", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_TO)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1161,11 +1167,12 @@ enum BuiltinOptions { BuiltinOptions_DensifyOptions = 99, BuiltinOptions_SegmentSumOptions = 100, BuiltinOptions_BatchMatMulOptions = 101, + BuiltinOptions_BroadcastToOptions = 102, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions + BuiltinOptions_MAX = BuiltinOptions_BroadcastToOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[103] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1268,13 +1275,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] { BuiltinOptions_SelectV2Options, BuiltinOptions_DensifyOptions, BuiltinOptions_SegmentSumOptions, - BuiltinOptions_BatchMatMulOptions + BuiltinOptions_BatchMatMulOptions, + BuiltinOptions_BroadcastToOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[103] = { + static const char * const names[104] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1377,13 +1385,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "DensifyOptions", "SegmentSumOptions", "BatchMatMulOptions", + "BroadcastToOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BroadcastToOptions)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -1796,6 +1805,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2636,6 +2649,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_BatchMatMulOptions ? reinterpret_cast(value) : nullptr; } + tflite::BroadcastToOptionsT *AsBroadcastToOptions() { + return type == BuiltinOptions_BroadcastToOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BroadcastToOptionsT *AsBroadcastToOptions() const { + return type == BuiltinOptions_BroadcastToOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -9310,6 +9331,46 @@ inline flatbuffers::Offset CreateBatchMatMulOptions( flatbuffers::Offset CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct BroadcastToOptionsT : public flatbuffers::NativeTable { + typedef BroadcastToOptions TableType; + BroadcastToOptionsT() { + } +}; + +struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BroadcastToOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BroadcastToOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BroadcastToOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit BroadcastToOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BroadcastToOptionsBuilder &operator=(const BroadcastToOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateBroadcastToOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + BroadcastToOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; tflite::BuiltinOperator builtin_code; @@ -9749,6 +9810,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const { return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? static_cast(builtin_options()) : nullptr; } + const tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -10189,6 +10253,10 @@ template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as return builtin_options_as_BatchMatMulOptions(); } +template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as() const { + return builtin_options_as_BroadcastToOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -13656,6 +13724,29 @@ inline flatbuffers::Offset CreateBatchMatMulOptions(flatbuff _adj_y); } +inline BroadcastToOptionsT *BroadcastToOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BroadcastToOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BroadcastToOptions::UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset BroadcastToOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBroadcastToOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BroadcastToOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBroadcastToOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -14465,6 +14556,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -14887,6 +14982,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -15297,6 +15396,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(value); + return CreateBroadcastToOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -15707,6 +15810,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new tflite::BatchMatMulOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_BroadcastToOptions: { + value = new tflite::BroadcastToOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -16219,6 +16326,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; @@ -16282,4 +16394,4 @@ inline std::unique_ptr UnPackSizePrefixedModel( } // namespace tflite -#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ +#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ \ No newline at end of file diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index b42fed6fbc1..2478ca6f4a3 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -43,6 +43,7 @@ enum class OperatorType : uint8 { kAveragePool, kBatchMatMul, kBatchNormalization, + kBroadcastTo, kCeil, kConv, kConcatenation, diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index b16f282bedd..3793bb50c9f 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -63,6 +63,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kBatchToSpaceND, 1}, "1.6.0"}, {{OperatorType::kBatchToSpaceND, 2}, "1.14.0"}, {{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion}, + {{OperatorType::kBroadcastTo, 1}, kPendingReleaseOpVersion}, {{OperatorType::kCast, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 1}, "1.5.0"}, {{OperatorType::kConcatenation, 2}, "1.14.0"}, diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 89b25892914..692a8eaf7a2 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -268,6 +268,9 @@ typedef enum { // Return the name of a given type, for error reporting purposes. const char* TfLiteTypeGetName(TfLiteType type); +// Return the size of given type in bytes. Return 0 in in case of string. +int TfLiteTypeGetSize(TfLiteType type); + // SupportedQuantizationTypes. typedef enum TfLiteQuantizationType { // No quantization. diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index c2e3f279a90..8938c0b5d4b 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -59,6 +59,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_AVERAGE_POOL_2D, 3}, "2.3.0"}, {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"}, {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"}, + {{BuiltinOperator_BROADCAST_TO, 1}, kPendingReleaseVersion}, {{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, {{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, {{BuiltinOperator_CONV_2D, 3}, "1.14.0"}, diff --git a/tensorflow/lite/tools/versioning/runtime_version_test.cc b/tensorflow/lite/tools/versioning/runtime_version_test.cc index c32de228cc3..df1ca46410c 100644 --- a/tensorflow/lite/tools/versioning/runtime_version_test.cc +++ b/tensorflow/lite/tools/versioning/runtime_version_test.cc @@ -47,7 +47,7 @@ TEST(OpVersionTest, OpversionMissing) { EXPECT_NE(runtime_version, "") << "Please add the version " << version << " of " << tflite::EnumNamesBuiltinOperator()[op_code] - << " runtime_version.cc"; + << " to runtime_version.cc"; } } }