From 4565291a981f23bd2ab6a9f13029855deec3783d Mon Sep 17 00:00:00 2001 From: Jaesung Chung <jaesung@google.com> Date: Tue, 3 Nov 2020 18:27:02 -0800 Subject: [PATCH] Add builtin BroadcastTo Op to TFLite Converter support will be added in a follow-up CL. PiperOrigin-RevId: 340563780 Change-Id: I4ea49fef309518b6447cb117365653b9b1d7d6a1 --- tensorflow/lite/builtin_ops.h | 1 + .../lite/core/api/flatbuffer_conversions.cc | 1 + tensorflow/lite/kernels/BUILD | 14 + tensorflow/lite/kernels/broadcast_to.cc | 136 ++++++++++ tensorflow/lite/kernels/broadcast_to_test.cc | 255 ++++++++++++++++++ tensorflow/lite/kernels/builtin_op_kernels.h | 1 + tensorflow/lite/kernels/internal/BUILD | 2 + tensorflow/lite/kernels/internal/common.h | 7 + .../kernels/internal/reference/broadcast_to.h | 90 +++++++ tensorflow/lite/kernels/kernel_util.cc | 41 +++ tensorflow/lite/kernels/kernel_util.h | 4 + tensorflow/lite/kernels/register.cc | 6 + tensorflow/lite/kernels/register_ref.cc | 7 + tensorflow/lite/schema/schema.fbs | 9 +- tensorflow/lite/schema/schema_generated.h | 132 ++++++++- .../lite/tools/versioning/op_version.cc | 5 + .../lite/tools/versioning/runtime_version.cc | 4 + .../tools/versioning/runtime_version_test.cc | 2 +- 18 files changed, 704 insertions(+), 13 deletions(-) create mode 100644 tensorflow/lite/kernels/broadcast_to.cc create mode 100644 tensorflow/lite/kernels/broadcast_to_test.cc create mode 100644 tensorflow/lite/kernels/internal/reference/broadcast_to.h diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 71cefe62f92..e597f5c34c0 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -157,6 +157,7 @@ typedef enum { kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127, kTfLiteBuiltinCumsum = 128, kTfLiteBuiltinCallOnce = 129, + kTfLiteBuiltinBroadcastTo = 130, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index dee2a990761..16118d41e65 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -822,6 +822,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SCATTER_ND: case BuiltinOperator_DENSIFY: case BuiltinOperator_SEGMENT_SUM: + case BuiltinOperator_BROADCAST_TO: return kTfLiteOk; case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES: return kTfLiteError; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 9cc5d0452ec..5f67a2e2def 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -544,6 +544,7 @@ BUILTIN_KERNEL_SRCS = [ "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", + "broadcast_to.cc", "call_once.cc", "cast.cc", "ceil.cc", @@ -1017,6 +1018,19 @@ cc_test( ], ) +cc_test( + name = "broadcast_to_test", + size = "small", + srcs = ["broadcast_to_test.cc"], + deps = [ + ":builtin_ops", + ":test_main", + ":test_util", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "cast_test", size = "small", diff --git a/tensorflow/lite/kernels/broadcast_to.cc b/tensorflow/lite/kernels/broadcast_to.cc new file mode 100644 index 00000000000..0e7baca2277 --- /dev/null +++ b/tensorflow/lite/kernels/broadcast_to.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h" + +#include <string.h> + +#include <cstdint> +#include <memory> + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace broadcastto { + +constexpr int kInputTensor = 0; +constexpr int kShapeTensor = 1; +constexpr int kOutputTensor = 0; +constexpr int kMaxDims = 8; + +struct BroadcastToContext { + BroadcastToContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, kInputTensor); + shape = GetInput(context, node, kShapeTensor); + output = GetOutput(context, node, kOutputTensor); + } + const TfLiteTensor* input; + const TfLiteTensor* shape; + TfLiteTensor* output; +}; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BroadcastToContext* op_context) { + // Ensures the shape is 1D tensor. + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1); + + // Ensure output dims is not less than input dims. + int input_num_dims = NumDimensions(op_context->input); + int output_num_dims = SizeOfDimension(op_context->shape, 0); + TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims, + "Output shape must be broadcastable from input shape."); + TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims, + "BroadcastTo only supports 1-8D tensor."); + + // Check if output shape is broadcastable from input shape. + auto get_shape_data = [op_context](int i) -> int32_t { + if (op_context->shape->type == kTfLiteInt32) { + return GetTensorData<int32_t>(op_context->shape)[i]; + } else { + return GetTensorData<int64_t>(op_context->shape)[i]; + } + }; + + int extending_dims = output_num_dims - input_num_dims; + for (int idx = 0; idx < input_num_dims; ++idx) { + TF_LITE_ENSURE_MSG(context, + (SizeOfDimension(op_context->input, idx) == 1 || + SizeOfDimension(op_context->input, idx) == + get_shape_data(extending_dims + idx)), + "Output shape must be broadcastable from input shape."); + } + // Resizing the shape of the output tensor. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims); + std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> + scoped_output_shape(output_shape, TfLiteIntArrayFree); + for (int idx = 0; idx < output_num_dims; ++idx) { + output_shape->data[idx] = get_shape_data(idx); + } + + return context->ResizeTensor(context, op_context->output, + scoped_output_shape.release()); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TF_LITE_ENSURE_MSG(context, + (NumDimensions(GetInput(context, node, 0)) <= kMaxDims), + "BroadcastTo only supports 1-8D tensor."); + + BroadcastToContext op_context(context, node); + TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 || + op_context.shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + // Not yet support string type due to the use of memcopy with fixed size. + TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString); + + if (IsConstantTensor(op_context.shape)) { + return ResizeOutputTensor(context, &op_context); + } + + SetTensorToDynamic(op_context.output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + BroadcastToContext op_context(context, node); + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + + // BroadcastTo op support upto 8 dims, matching the support of Tensorflow. + reference_ops::BroadcastTo<kMaxDims>( + GetTensorShape(op_context.input), op_context.input->data.raw, + GetTensorShape(op_context.output), op_context.output->data.raw, + op_context.input->type); + return kTfLiteOk; +} + +} // namespace broadcastto + +TfLiteRegistration* Register_BROADCAST_TO() { + static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare, + broadcastto::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/broadcast_to_test.cc b/tensorflow/lite/kernels/broadcast_to_test.cc new file mode 100644 index 00000000000..a36ed352055 --- /dev/null +++ b/tensorflow/lite/kernels/broadcast_to_test.cc @@ -0,0 +1,255 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <cstdint> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { +using ::testing::ElementsAreArray; + +template <class InputType, class ShapeType = int32_t> +class BroadcastToOpModel : public SingleOpModel { + public: + // BroadcastTo with dynamic shape. + BroadcastToOpModel(std::initializer_list<int> input_shape, + std::initializer_list<int> shape_shape) { + input_ = AddInput({GetTensorType<InputType>(), input_shape}); + shape_ = AddInput({GetTensorType<ShapeType>(), shape_shape}); + output_ = AddOutput(GetTensorType<InputType>()); + SetBuiltinOp(BuiltinOperator_BROADCAST_TO, + BuiltinOptions_BroadcastToOptions, + CreateBroadcastToOptions(builder_).Union()); + BuildInterpreter({input_shape, shape_shape}); + } + + // BroadcastTo with const shape. + BroadcastToOpModel(std::initializer_list<int> input_shape, + std::initializer_list<int> shape_shape, + std::initializer_list<ShapeType> shape_values) { + input_ = AddInput({GetTensorType<InputType>(), input_shape}); + shape_ = + AddConstInput(GetTensorType<ShapeType>(), shape_values, shape_shape); + output_ = AddOutput(GetTensorType<InputType>()); + SetBuiltinOp(BuiltinOperator_BROADCAST_TO, + BuiltinOptions_BroadcastToOptions, + CreateBroadcastToOptions(builder_).Union()); + BuildInterpreter({input_shape, shape_shape}); + } + + void SetInput(std::initializer_list<InputType> data) { + PopulateTensor(input_, data); + } + + void SetShape(std::initializer_list<ShapeType> data) { + PopulateTensor(shape_, data); + } + + std::vector<InputType> GetOutput() { + return ExtractVector<InputType>(output_); + } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int shape_; + int output_; +}; + +template <typename T> +class BroadcastToOpTest : public ::testing::Test {}; + +using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>; +TYPED_TEST_SUITE(BroadcastToOpTest, DataTypes); + +#ifdef GTEST_HAS_DEATH_TEST +TYPED_TEST(BroadcastToOpTest, ShapeMustBe1D) { + EXPECT_DEATH( + BroadcastToOpModel<TypeParam>({2, 3, 4, 4}, {2, 2}, {2, 3, 4, 4}), ""); + // Non-constant Shape tensor. + BroadcastToOpModel<TypeParam> m({2, 3, 4, 4}, {2, 2}); + m.SetShape({2, 3, 4, 4}); + EXPECT_THAT(m.InvokeUnchecked(), kTfLiteError); +} + +TYPED_TEST(BroadcastToOpTest, TooManyDimensions) { + EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}, + {2, 2, 3, 4, 5, 6, 7, 8, 9}), + "BroadcastTo only supports 1-8D tensor."); + EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}), + "BroadcastTo only supports 1-8D tensor."); +} + +TYPED_TEST(BroadcastToOpTest, MismatchDimension) { + EXPECT_DEATH(BroadcastToOpModel<TypeParam>({2, 4, 1, 2}, {4}, {2, 4, 1, 3}), + "Output shape must be broadcastable from input shape."); + EXPECT_DEATH( + BroadcastToOpModel<TypeParam>({2, 4, 1, 2, 3}, {4}, {2, 4, 1, 2}), + "Output shape must be broadcastable from input shape."); + + // Non-constant Shape tensor. + BroadcastToOpModel<TypeParam> m1({2, 4, 1, 2}, {4}); + m1.SetShape({2, 3, 4, 4}); + EXPECT_THAT(m1.InvokeUnchecked(), kTfLiteError); + BroadcastToOpModel<TypeParam> m2({2, 4, 1, 2}, {5}); + m2.SetShape({1, 2, 3, 4, 4}); + EXPECT_THAT(m2.InvokeUnchecked(), kTfLiteError); +} +#endif + +TYPED_TEST(BroadcastToOpTest, BroadcastTo1DConstTest) { + BroadcastToOpModel<TypeParam> m({1}, {1}, {4}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo4DConstTest) { + BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4}, {1, 1, 2, 2}); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo8DConstTest) { + BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8}, + {1, 1, 1, 1, 1, 1, 2, 2}); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo1DDynamicTest) { + BroadcastToOpModel<TypeParam> m({1}, {1}); + m.SetInput({3}); + m.SetShape({4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo4DDynamicTest) { + BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4}); + m.SetInput({3, 4}); + m.SetShape({1, 1, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4})); +} + +TYPED_TEST(BroadcastToOpTest, BroadcastTo8DDynamicTest) { + BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8}); + m.SetInput({3, 4}); + m.SetShape({1, 1, 1, 1, 1, 1, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DConstTest) { + BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4}, {3, 3, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, + 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DDynamicTest) { + BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetShape({3, 3, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, + 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DConstTest) { + BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6}, {2, 2, 1, 3, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12, + 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12})); +} + +TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) { + BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetShape({2, 2, 1, 3, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12, + 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, + 7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12})); +} + +TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) { + BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, + 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) { + BroadcastToOpModel<TypeParam> m({3, 1, 2}, {3}, {3, 1, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) { + BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8}, + {1, 1, 1, 1, 1, 1, 2, 2}); + m.SetInput({3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +TYPED_TEST(BroadcastToOpTest, Int64ShapeDDynamicTest) { + BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8}); + m.SetInput({3, 4}); + m.SetShape({1, 1, 1, 1, 1, 1, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4})); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index f1ba36e59b8..6ab4493e4d0 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -39,6 +39,7 @@ TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_BATCH_MATMUL(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); +TfLiteRegistration* Register_BROADCAST_TO(); TfLiteRegistration* Register_CALL_ONCE(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_CEIL(); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 94135c6adbe..819d711df84 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -452,6 +452,7 @@ cc_library( "reference/arg_min_max.h", "reference/batch_matmul.h", "reference/binary_function.h", + "reference/broadcast_to.h", "reference/ceil.h", "reference/comparisons.h", "reference/concatenation.h", @@ -521,6 +522,7 @@ cc_library( "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/tools/optimize/sparsity:format_converter", ] + select({ diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index dec6c9721a3..662a1864025 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -696,6 +696,13 @@ inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) { indexes[4] * desc.strides[4]; } +inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) { + return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] + + indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] + + indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] + + indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7]; +} + // Given the dimensions of the operands for an element-wise binary broadcast, // adjusts them so that they can be directly iterated over with simple loops. // Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and diff --git a/tensorflow/lite/kernels/internal/reference/broadcast_to.h b/tensorflow/lite/kernels/internal/reference/broadcast_to.h new file mode 100644 index 00000000000..09ffa704cca --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/broadcast_to.h @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_ + +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace reference_ops { +template <int N> +void BroadcastImpl(const NdArrayDesc<N>& input_desc, const char* input_data, + const NdArrayDesc<N>& output_desc, char* output_data, + int indexes[N], int dim, const int last_broadcasting_dim, + const int type_size) { + // Copy data from input to output. + if (dim == last_broadcasting_dim) { + int copy_size = output_desc.strides[dim] * type_size; + const char* data_src = + input_data + SubscriptToIndex(input_desc, indexes) * type_size; + char* data_dst = + output_data + SubscriptToIndex(output_desc, indexes) * type_size; + for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) { + memcpy(data_dst, data_src, copy_size); + } + return; + } + + // Recursive call to find the next broadcasting. + for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim]; + ++indexes[dim]) { + BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, + dim + 1, last_broadcasting_dim, type_size); + } + + // Duplicate data in output tensor. + indexes[dim] = 0; + if (input_desc.extents[dim] != output_desc.extents[dim]) { + int copy_size = output_desc.strides[dim] * type_size; + char* data_src = + output_data + SubscriptToIndex(output_desc, indexes) * type_size; + char* data_dst = data_src + copy_size; + for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) { + memcpy(data_dst, data_src, copy_size); + } + } +} + +template <int N> +inline void BroadcastTo(const RuntimeShape& unextended_input_shape, + const char* input_data, + const RuntimeShape& unextended_output_shape, + char* output_data, TfLiteType data_type) { + NdArrayDesc<N> input_desc; + NdArrayDesc<N> output_desc; + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape), + &input_desc); + CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape), + &output_desc); + + // Get the last dimension has broadcasting. At this dimension, the data is + // copied from input tensor to output tensor. + int last_broadcast_dim = 0; + for (int i = N - 1; i > 0; --i) { + if (input_desc.extents[i] != output_desc.extents[i]) { + last_broadcast_dim = i; + break; + } + } + + // Broadcasting using memcpy. + int indexes[N] = {0}; + BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0, + last_broadcast_dim, TfLiteTypeGetSize(data_type)); +} +} // namespace reference_ops +} // namespace tflite +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_ diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index a834d8ab913..f986655f413 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include <stdlib.h> #include <algorithm> +#include <complex> #include <limits> #include <memory> @@ -434,4 +435,44 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, } #endif // TF_LITE_STATIC_MEMORY +// Size of string is not constant, return 0 in such case. +int TfLiteTypeGetSize(TfLiteType type) { + switch (type) { + case kTfLiteUInt8: + TF_LITE_ASSERT_EQ(sizeof(uint8_t), 1); + return 1; + case kTfLiteInt8: + TF_LITE_ASSERT_EQ(sizeof(int8_t), 1); + return 1; + case kTfLiteBool: + return sizeof(bool); + case kTfLiteInt16: + TF_LITE_ASSERT_EQ(sizeof(int16_t), 2); + return 2; + case kTfLiteFloat16: + TF_LITE_ASSERT_EQ(sizeof(int16_t), 2); + return 2; + case kTfLiteFloat32: + TF_LITE_ASSERT_EQ(sizeof(float), 4); + return 4; + case kTfLiteInt32: + TF_LITE_ASSERT_EQ(sizeof(int32_t), 4); + return 4; + case kTfLiteInt64: + TF_LITE_ASSERT_EQ(sizeof(int64_t), 8); + return 8; + case kTfLiteFloat64: + TF_LITE_ASSERT_EQ(sizeof(double), 8); + return 8; + case kTfLiteComplex64: + TF_LITE_ASSERT_EQ(sizeof(std::complex<float>), 8); + return 8; + case kTfLiteComplex128: + TF_LITE_ASSERT_EQ(sizeof(std::complex<double>), 16); + return 16; + default: + return 0; + } +} + } // namespace tflite diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 06f24b8e7d1..7a1aa165405 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -284,6 +284,10 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, const TfLiteTensor* input2, const TfLiteTensor* input3, TfLiteIntArray** output_shape); + +// Return the size of given type in bytes. Return 0 in in case of string. +int TfLiteTypeGetSize(TfLiteType type); + } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 997d6d86a4e..33486922ea3 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -297,6 +297,12 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version = */ 1, /* max_version = */ 3); AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM()); + // The version one of broadcast to op won't be not supported since the version + // one was rollbacked and the builtin op code number has been changed because + // of builtin op code shortage problem. + AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(), + /* min_version = */ 2, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_CALL_ONCE, tflite::ops::builtin::Register_CALL_ONCE()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index 1fdfd3a073d..d5f5ba75833 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -156,6 +156,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF(); TfLiteRegistration* Register_DEPTH_TO_SPACE_REF(); TfLiteRegistration* Register_SELECT_V2(); TfLiteRegistration* Register_SEGMENT_SUM(); +TfLiteRegistration* Register_BROADCAST_TO(); namespace { @@ -262,6 +263,12 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF(), /* min_version = */ 1, /* max_version = */ 2); + // The version one of broadcast to op won't be not supported since the version + // one was rollbacked and the builtin op code number has been changed because + // of builtin op code shortage problem. + AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(), + /* min_version = */ 2, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORM_REF()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index ef1592193f7..2a0981b40e3 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -353,7 +353,8 @@ enum BuiltinOperator : int32 { BATCH_MATMUL = 126, PLACEHOLDER_FOR_GREATER_OP_CODES = 127, CUMSUM = 128, - CALL_ONCE = 129 + CALL_ONCE = 129, + BROADCAST_TO = 130 } @@ -461,7 +462,8 @@ union BuiltinOptions { SegmentSumOptions, BatchMatMulOptions, CumsumOptions, - CallOnceOptions + CallOnceOptions, + BroadcastToOptions } enum Padding : byte { SAME, VALID } @@ -994,6 +996,9 @@ table CumsumOptions { reverse:bool; } +table BroadcastToOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index dd9b655c6e6..84442b86b22 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -355,6 +355,9 @@ struct BatchMatMulOptionsT; struct CumsumOptions; struct CumsumOptionsT; +struct BroadcastToOptions; +struct BroadcastToOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -796,11 +799,12 @@ enum BuiltinOperator { BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127, BuiltinOperator_CUMSUM = 128, BuiltinOperator_CALL_ONCE = 129, + BuiltinOperator_BROADCAST_TO = 130, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_CALL_ONCE + BuiltinOperator_MAX = BuiltinOperator_BROADCAST_TO }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[131] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -931,13 +935,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] { BuiltinOperator_BATCH_MATMUL, BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES, BuiltinOperator_CUMSUM, - BuiltinOperator_CALL_ONCE + BuiltinOperator_CALL_ONCE, + BuiltinOperator_BROADCAST_TO }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[131] = { + static const char * const names[132] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1068,13 +1073,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "PLACEHOLDER_FOR_GREATER_OP_CODES", "CUMSUM", "CALL_ONCE", + "BROADCAST_TO", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CALL_ONCE)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_TO)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOperator()[index]; } @@ -1184,11 +1190,12 @@ enum BuiltinOptions { BuiltinOptions_BatchMatMulOptions = 101, BuiltinOptions_CumsumOptions = 102, BuiltinOptions_CallOnceOptions = 103, + BuiltinOptions_BroadcastToOptions = 104, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_CallOnceOptions + BuiltinOptions_MAX = BuiltinOptions_BroadcastToOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[105] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1293,13 +1300,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] { BuiltinOptions_SegmentSumOptions, BuiltinOptions_BatchMatMulOptions, BuiltinOptions_CumsumOptions, - BuiltinOptions_CallOnceOptions + BuiltinOptions_CallOnceOptions, + BuiltinOptions_BroadcastToOptions }; return values; } inline const char * const *EnumNamesBuiltinOptions() { - static const char * const names[105] = { + static const char * const names[106] = { "NONE", "Conv2DOptions", "DepthwiseConv2DOptions", @@ -1404,13 +1412,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "BatchMatMulOptions", "CumsumOptions", "CallOnceOptions", + "BroadcastToOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_CallOnceOptions)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BroadcastToOptions)) return ""; const size_t index = static_cast<size_t>(e); return EnumNamesBuiltinOptions()[index]; } @@ -1831,6 +1840,10 @@ template<> struct BuiltinOptionsTraits<tflite::CallOnceOptions> { static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions; }; +template<> struct BuiltinOptionsTraits<tflite::BroadcastToOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2687,6 +2700,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_CallOnceOptions ? reinterpret_cast<const tflite::CallOnceOptionsT *>(value) : nullptr; } + tflite::BroadcastToOptionsT *AsBroadcastToOptions() { + return type == BuiltinOptions_BroadcastToOptions ? + reinterpret_cast<tflite::BroadcastToOptionsT *>(value) : nullptr; + } + const tflite::BroadcastToOptionsT *AsBroadcastToOptions() const { + return type == BuiltinOptions_BroadcastToOptions ? + reinterpret_cast<const tflite::BroadcastToOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -9505,6 +9526,46 @@ inline flatbuffers::Offset<CumsumOptions> CreateCumsumOptions( flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct BroadcastToOptionsT : public flatbuffers::NativeTable { + typedef BroadcastToOptions TableType; + BroadcastToOptionsT() { + } +}; + +struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BroadcastToOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BroadcastToOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<BroadcastToOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BroadcastToOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit BroadcastToOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BroadcastToOptionsBuilder &operator=(const BroadcastToOptionsBuilder &); + flatbuffers::Offset<BroadcastToOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BroadcastToOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + BroadcastToOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; int8_t deprecated_builtin_code; @@ -9964,6 +10025,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const { return builtin_options_type() == tflite::BuiltinOptions_CallOnceOptions ? static_cast<const tflite::CallOnceOptions *>(builtin_options()) : nullptr; } + const tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast<const tflite::BroadcastToOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -10412,6 +10476,10 @@ template<> inline const tflite::CallOnceOptions *Operator::builtin_options_as<tf return builtin_options_as_CallOnceOptions(); } +template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as<tflite::BroadcastToOptions>() const { + return builtin_options_as_BroadcastToOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -14143,6 +14211,29 @@ inline flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(flatbuffers::FlatB _reverse); } +inline BroadcastToOptionsT *BroadcastToOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BroadcastToOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BroadcastToOptions::UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<BroadcastToOptions> BroadcastToOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBroadcastToOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BroadcastToOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBroadcastToOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -15030,6 +15121,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -15460,6 +15555,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -15878,6 +15977,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const tflite::CallOnceOptionsT *>(value); return CreateCallOnceOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast<const tflite::BroadcastToOptionsT *>(value); + return CreateBroadcastToOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -16296,6 +16399,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new tflite::CallOnceOptionsT(*reinterpret_cast<tflite::CallOnceOptionsT *>(u.value)); break; } + case BuiltinOptions_BroadcastToOptions: { + value = new tflite::BroadcastToOptionsT(*reinterpret_cast<tflite::BroadcastToOptionsT *>(u.value)); + break; + } default: break; } @@ -16818,6 +16925,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast<tflite::BroadcastToOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 49777026fc2..4532c7ab988 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -610,6 +610,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 2; } return 1; + // The version one of broadcast to op won't be not supported since the + // version one was rollbacked and the builtin op code number has been + // changed because of builtin op code shortage problem. + case BuiltinOperator_BROADCAST_TO: + return 2; default: return 1; } diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 10a4958e4af..433224c86c6 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -61,6 +61,10 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"}, {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"}, {{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"}, + // The version one of broadcast to op won't be not supported since + // the version one was rollbacked and the builtin op code number + // has been changed because of builtin op code shortage problem. + {{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion}, {{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, {{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, {{BuiltinOperator_CONV_2D, 3}, "1.14.0"}, diff --git a/tensorflow/lite/tools/versioning/runtime_version_test.cc b/tensorflow/lite/tools/versioning/runtime_version_test.cc index c32de228cc3..df1ca46410c 100644 --- a/tensorflow/lite/tools/versioning/runtime_version_test.cc +++ b/tensorflow/lite/tools/versioning/runtime_version_test.cc @@ -47,7 +47,7 @@ TEST(OpVersionTest, OpversionMissing) { EXPECT_NE(runtime_version, "") << "Please add the version " << version << " of " << tflite::EnumNamesBuiltinOperator()[op_code] - << " runtime_version.cc"; + << " to runtime_version.cc"; } } }