Add builtin BroadcastTo Op to TFLite

Converter support will be added in a follow-up CL.

PiperOrigin-RevId: 340563780
Change-Id: I4ea49fef309518b6447cb117365653b9b1d7d6a1
This commit is contained in:
Jaesung Chung 2020-11-03 18:27:02 -08:00 committed by TensorFlower Gardener
parent b37cacd6ea
commit 4565291a98
18 changed files with 704 additions and 13 deletions

View File

@ -157,6 +157,7 @@ typedef enum {
kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127,
kTfLiteBuiltinCumsum = 128,
kTfLiteBuiltinCallOnce = 129,
kTfLiteBuiltinBroadcastTo = 130,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -822,6 +822,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_BROADCAST_TO:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;

View File

@ -544,6 +544,7 @@ BUILTIN_KERNEL_SRCS = [
"batch_to_space_nd.cc",
"bidirectional_sequence_lstm.cc",
"bidirectional_sequence_rnn.cc",
"broadcast_to.cc",
"call_once.cc",
"cast.cc",
"ceil.cc",
@ -1017,6 +1018,19 @@ cc_test(
],
)
cc_test(
name = "broadcast_to_test",
size = "small",
srcs = ["broadcast_to_test.cc"],
deps = [
":builtin_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "cast_test",
size = "small",

View File

@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"
#include <string.h>
#include <cstdint>
#include <memory>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace broadcastto {
constexpr int kInputTensor = 0;
constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
constexpr int kMaxDims = 8;
struct BroadcastToContext {
BroadcastToContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, kInputTensor);
shape = GetInput(context, node, kShapeTensor);
output = GetOutput(context, node, kOutputTensor);
}
const TfLiteTensor* input;
const TfLiteTensor* shape;
TfLiteTensor* output;
};
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
BroadcastToContext* op_context) {
// Ensures the shape is 1D tensor.
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1);
// Ensure output dims is not less than input dims.
int input_num_dims = NumDimensions(op_context->input);
int output_num_dims = SizeOfDimension(op_context->shape, 0);
TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
"Output shape must be broadcastable from input shape.");
TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
"BroadcastTo only supports 1-8D tensor.");
// Check if output shape is broadcastable from input shape.
auto get_shape_data = [op_context](int i) -> int32_t {
if (op_context->shape->type == kTfLiteInt32) {
return GetTensorData<int32_t>(op_context->shape)[i];
} else {
return GetTensorData<int64_t>(op_context->shape)[i];
}
};
int extending_dims = output_num_dims - input_num_dims;
for (int idx = 0; idx < input_num_dims; ++idx) {
TF_LITE_ENSURE_MSG(context,
(SizeOfDimension(op_context->input, idx) == 1 ||
SizeOfDimension(op_context->input, idx) ==
get_shape_data(extending_dims + idx)),
"Output shape must be broadcastable from input shape.");
}
// Resizing the shape of the output tensor.
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims);
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
scoped_output_shape(output_shape, TfLiteIntArrayFree);
for (int idx = 0; idx < output_num_dims; ++idx) {
output_shape->data[idx] = get_shape_data(idx);
}
return context->ResizeTensor(context, op_context->output,
scoped_output_shape.release());
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_MSG(context,
(NumDimensions(GetInput(context, node, 0)) <= kMaxDims),
"BroadcastTo only supports 1-8D tensor.");
BroadcastToContext op_context(context, node);
TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 ||
op_context.shape->type == kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
// Not yet support string type due to the use of memcopy with fixed size.
TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString);
if (IsConstantTensor(op_context.shape)) {
return ResizeOutputTensor(context, &op_context);
}
SetTensorToDynamic(op_context.output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
BroadcastToContext op_context(context, node);
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
// BroadcastTo op support upto 8 dims, matching the support of Tensorflow.
reference_ops::BroadcastTo<kMaxDims>(
GetTensorShape(op_context.input), op_context.input->data.raw,
GetTensorShape(op_context.output), op_context.output->data.raw,
op_context.input->type);
return kTfLiteOk;
}
} // namespace broadcastto
TfLiteRegistration* Register_BROADCAST_TO() {
static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare,
broadcastto::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,255 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
template <class InputType, class ShapeType = int32_t>
class BroadcastToOpModel : public SingleOpModel {
public:
// BroadcastTo with dynamic shape.
BroadcastToOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> shape_shape) {
input_ = AddInput({GetTensorType<InputType>(), input_shape});
shape_ = AddInput({GetTensorType<ShapeType>(), shape_shape});
output_ = AddOutput(GetTensorType<InputType>());
SetBuiltinOp(BuiltinOperator_BROADCAST_TO,
BuiltinOptions_BroadcastToOptions,
CreateBroadcastToOptions(builder_).Union());
BuildInterpreter({input_shape, shape_shape});
}
// BroadcastTo with const shape.
BroadcastToOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> shape_shape,
std::initializer_list<ShapeType> shape_values) {
input_ = AddInput({GetTensorType<InputType>(), input_shape});
shape_ =
AddConstInput(GetTensorType<ShapeType>(), shape_values, shape_shape);
output_ = AddOutput(GetTensorType<InputType>());
SetBuiltinOp(BuiltinOperator_BROADCAST_TO,
BuiltinOptions_BroadcastToOptions,
CreateBroadcastToOptions(builder_).Union());
BuildInterpreter({input_shape, shape_shape});
}
void SetInput(std::initializer_list<InputType> data) {
PopulateTensor(input_, data);
}
void SetShape(std::initializer_list<ShapeType> data) {
PopulateTensor(shape_, data);
}
std::vector<InputType> GetOutput() {
return ExtractVector<InputType>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int input_;
int shape_;
int output_;
};
template <typename T>
class BroadcastToOpTest : public ::testing::Test {};
using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>;
TYPED_TEST_SUITE(BroadcastToOpTest, DataTypes);
#ifdef GTEST_HAS_DEATH_TEST
TYPED_TEST(BroadcastToOpTest, ShapeMustBe1D) {
EXPECT_DEATH(
BroadcastToOpModel<TypeParam>({2, 3, 4, 4}, {2, 2}, {2, 3, 4, 4}), "");
// Non-constant Shape tensor.
BroadcastToOpModel<TypeParam> m({2, 3, 4, 4}, {2, 2});
m.SetShape({2, 3, 4, 4});
EXPECT_THAT(m.InvokeUnchecked(), kTfLiteError);
}
TYPED_TEST(BroadcastToOpTest, TooManyDimensions) {
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9},
{2, 2, 3, 4, 5, 6, 7, 8, 9}),
"BroadcastTo only supports 1-8D tensor.");
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}),
"BroadcastTo only supports 1-8D tensor.");
}
TYPED_TEST(BroadcastToOpTest, MismatchDimension) {
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({2, 4, 1, 2}, {4}, {2, 4, 1, 3}),
"Output shape must be broadcastable from input shape.");
EXPECT_DEATH(
BroadcastToOpModel<TypeParam>({2, 4, 1, 2, 3}, {4}, {2, 4, 1, 2}),
"Output shape must be broadcastable from input shape.");
// Non-constant Shape tensor.
BroadcastToOpModel<TypeParam> m1({2, 4, 1, 2}, {4});
m1.SetShape({2, 3, 4, 4});
EXPECT_THAT(m1.InvokeUnchecked(), kTfLiteError);
BroadcastToOpModel<TypeParam> m2({2, 4, 1, 2}, {5});
m2.SetShape({1, 2, 3, 4, 4});
EXPECT_THAT(m2.InvokeUnchecked(), kTfLiteError);
}
#endif
TYPED_TEST(BroadcastToOpTest, BroadcastTo1DConstTest) {
BroadcastToOpModel<TypeParam> m({1}, {1}, {4});
m.SetInput({3});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo4DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4}, {1, 1, 2, 2});
m.SetInput({3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo8DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
{1, 1, 1, 1, 1, 1, 2, 2});
m.SetInput({3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo1DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1}, {1});
m.SetInput({3});
m.SetShape({4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo4DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4});
m.SetInput({3, 4});
m.SetShape({1, 1, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo8DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8});
m.SetInput({3, 4});
m.SetShape({1, 1, 1, 1, 1, 1, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4}, {3, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4});
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetShape({3, 3, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6}, {2, 2, 1, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12,
1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetShape({2, 2, 1, 3, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12,
1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) {
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {3}, {3, 1, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
{1, 1, 1, 1, 1, 1, 2, 2});
m.SetInput({3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
TYPED_TEST(BroadcastToOpTest, Int64ShapeDDynamicTest) {
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8});
m.SetInput({3, 4});
m.SetShape({1, 1, 1, 1, 1, 1, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
} // namespace
} // namespace tflite

View File

@ -39,6 +39,7 @@ TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
TfLiteRegistration* Register_BATCH_MATMUL();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
TfLiteRegistration* Register_BROADCAST_TO();
TfLiteRegistration* Register_CALL_ONCE();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_CEIL();

View File

@ -452,6 +452,7 @@ cc_library(
"reference/arg_min_max.h",
"reference/batch_matmul.h",
"reference/binary_function.h",
"reference/broadcast_to.h",
"reference/ceil.h",
"reference/comparisons.h",
"reference/concatenation.h",
@ -521,6 +522,7 @@ cc_library(
"@ruy//ruy/profiler:instrumentation",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
] + select({

View File

@ -696,6 +696,13 @@ inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
indexes[4] * desc.strides[4];
}
inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}
// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace reference_ops {
template <int N>
void BroadcastImpl(const NdArrayDesc<N>& input_desc, const char* input_data,
const NdArrayDesc<N>& output_desc, char* output_data,
int indexes[N], int dim, const int last_broadcasting_dim,
const int type_size) {
// Copy data from input to output.
if (dim == last_broadcasting_dim) {
int copy_size = output_desc.strides[dim] * type_size;
const char* data_src =
input_data + SubscriptToIndex(input_desc, indexes) * type_size;
char* data_dst =
output_data + SubscriptToIndex(output_desc, indexes) * type_size;
for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
memcpy(data_dst, data_src, copy_size);
}
return;
}
// Recursive call to find the next broadcasting.
for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim];
++indexes[dim]) {
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes,
dim + 1, last_broadcasting_dim, type_size);
}
// Duplicate data in output tensor.
indexes[dim] = 0;
if (input_desc.extents[dim] != output_desc.extents[dim]) {
int copy_size = output_desc.strides[dim] * type_size;
char* data_src =
output_data + SubscriptToIndex(output_desc, indexes) * type_size;
char* data_dst = data_src + copy_size;
for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
memcpy(data_dst, data_src, copy_size);
}
}
}
template <int N>
inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
const char* input_data,
const RuntimeShape& unextended_output_shape,
char* output_data, TfLiteType data_type) {
NdArrayDesc<N> input_desc;
NdArrayDesc<N> output_desc;
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
&input_desc);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
&output_desc);
// Get the last dimension has broadcasting. At this dimension, the data is
// copied from input tensor to output tensor.
int last_broadcast_dim = 0;
for (int i = N - 1; i > 0; --i) {
if (input_desc.extents[i] != output_desc.extents[i]) {
last_broadcast_dim = i;
break;
}
}
// Broadcasting using memcpy.
int indexes[N] = {0};
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
last_broadcast_dim, TfLiteTypeGetSize(data_type));
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <stdlib.h>
#include <algorithm>
#include <complex>
#include <limits>
#include <memory>
@ -434,4 +435,44 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
}
#endif // TF_LITE_STATIC_MEMORY
// Size of string is not constant, return 0 in such case.
int TfLiteTypeGetSize(TfLiteType type) {
switch (type) {
case kTfLiteUInt8:
TF_LITE_ASSERT_EQ(sizeof(uint8_t), 1);
return 1;
case kTfLiteInt8:
TF_LITE_ASSERT_EQ(sizeof(int8_t), 1);
return 1;
case kTfLiteBool:
return sizeof(bool);
case kTfLiteInt16:
TF_LITE_ASSERT_EQ(sizeof(int16_t), 2);
return 2;
case kTfLiteFloat16:
TF_LITE_ASSERT_EQ(sizeof(int16_t), 2);
return 2;
case kTfLiteFloat32:
TF_LITE_ASSERT_EQ(sizeof(float), 4);
return 4;
case kTfLiteInt32:
TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
return 4;
case kTfLiteInt64:
TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
return 8;
case kTfLiteFloat64:
TF_LITE_ASSERT_EQ(sizeof(double), 8);
return 8;
case kTfLiteComplex64:
TF_LITE_ASSERT_EQ(sizeof(std::complex<float>), 8);
return 8;
case kTfLiteComplex128:
TF_LITE_ASSERT_EQ(sizeof(std::complex<double>), 16);
return 16;
default:
return 0;
}
}
} // namespace tflite

View File

@ -284,6 +284,10 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const TfLiteTensor* input2,
const TfLiteTensor* input3,
TfLiteIntArray** output_shape);
// Return the size of given type in bytes. Return 0 in in case of string.
int TfLiteTypeGetSize(TfLiteType type);
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_

View File

@ -297,6 +297,12 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
// The version one of broadcast to op won't be not supported since the version
// one was rollbacked and the builtin op code number has been changed because
// of builtin op code shortage problem.
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
/* min_version = */ 2,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CALL_ONCE,
tflite::ops::builtin::Register_CALL_ONCE());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());

View File

@ -156,6 +156,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF();
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_BROADCAST_TO();
namespace {
@ -262,6 +263,12 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF(),
/* min_version = */ 1,
/* max_version = */ 2);
// The version one of broadcast to op won't be not supported since the version
// one was rollbacked and the builtin op code number has been changed because
// of builtin op code shortage problem.
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
/* min_version = */ 2,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORM_REF());
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,

View File

@ -353,7 +353,8 @@ enum BuiltinOperator : int32 {
BATCH_MATMUL = 126,
PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
CUMSUM = 128,
CALL_ONCE = 129
CALL_ONCE = 129,
BROADCAST_TO = 130
}
@ -461,7 +462,8 @@ union BuiltinOptions {
SegmentSumOptions,
BatchMatMulOptions,
CumsumOptions,
CallOnceOptions
CallOnceOptions,
BroadcastToOptions
}
enum Padding : byte { SAME, VALID }
@ -994,6 +996,9 @@ table CumsumOptions {
reverse:bool;
}
table BroadcastToOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -355,6 +355,9 @@ struct BatchMatMulOptionsT;
struct CumsumOptions;
struct CumsumOptionsT;
struct BroadcastToOptions;
struct BroadcastToOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -796,11 +799,12 @@ enum BuiltinOperator {
BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
BuiltinOperator_CUMSUM = 128,
BuiltinOperator_CALL_ONCE = 129,
BuiltinOperator_BROADCAST_TO = 130,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_CALL_ONCE
BuiltinOperator_MAX = BuiltinOperator_BROADCAST_TO
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[131] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -931,13 +935,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] {
BuiltinOperator_BATCH_MATMUL,
BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES,
BuiltinOperator_CUMSUM,
BuiltinOperator_CALL_ONCE
BuiltinOperator_CALL_ONCE,
BuiltinOperator_BROADCAST_TO
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[131] = {
static const char * const names[132] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@ -1068,13 +1073,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"PLACEHOLDER_FOR_GREATER_OP_CODES",
"CUMSUM",
"CALL_ONCE",
"BROADCAST_TO",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CALL_ONCE)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_TO)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -1184,11 +1190,12 @@ enum BuiltinOptions {
BuiltinOptions_BatchMatMulOptions = 101,
BuiltinOptions_CumsumOptions = 102,
BuiltinOptions_CallOnceOptions = 103,
BuiltinOptions_BroadcastToOptions = 104,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_CallOnceOptions
BuiltinOptions_MAX = BuiltinOptions_BroadcastToOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[105] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1293,13 +1300,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] {
BuiltinOptions_SegmentSumOptions,
BuiltinOptions_BatchMatMulOptions,
BuiltinOptions_CumsumOptions,
BuiltinOptions_CallOnceOptions
BuiltinOptions_CallOnceOptions,
BuiltinOptions_BroadcastToOptions
};
return values;
}
inline const char * const *EnumNamesBuiltinOptions() {
static const char * const names[105] = {
static const char * const names[106] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@ -1404,13 +1412,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"BatchMatMulOptions",
"CumsumOptions",
"CallOnceOptions",
"BroadcastToOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_CallOnceOptions)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BroadcastToOptions)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -1831,6 +1840,10 @@ template<> struct BuiltinOptionsTraits<tflite::CallOnceOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions;
};
template<> struct BuiltinOptionsTraits<tflite::BroadcastToOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2687,6 +2700,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_CallOnceOptions ?
reinterpret_cast<const tflite::CallOnceOptionsT *>(value) : nullptr;
}
tflite::BroadcastToOptionsT *AsBroadcastToOptions() {
return type == BuiltinOptions_BroadcastToOptions ?
reinterpret_cast<tflite::BroadcastToOptionsT *>(value) : nullptr;
}
const tflite::BroadcastToOptionsT *AsBroadcastToOptions() const {
return type == BuiltinOptions_BroadcastToOptions ?
reinterpret_cast<const tflite::BroadcastToOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -9505,6 +9526,46 @@ inline flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(
flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct BroadcastToOptionsT : public flatbuffers::NativeTable {
typedef BroadcastToOptions TableType;
BroadcastToOptionsT() {
}
};
struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef BroadcastToOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
BroadcastToOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<BroadcastToOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct BroadcastToOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit BroadcastToOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
BroadcastToOptionsBuilder &operator=(const BroadcastToOptionsBuilder &);
flatbuffers::Offset<BroadcastToOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<BroadcastToOptions>(end);
return o;
}
};
inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
BroadcastToOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
int8_t deprecated_builtin_code;
@ -9964,6 +10025,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_CallOnceOptions ? static_cast<const tflite::CallOnceOptions *>(builtin_options()) : nullptr;
}
const tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast<const tflite::BroadcastToOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -10412,6 +10476,10 @@ template<> inline const tflite::CallOnceOptions *Operator::builtin_options_as<tf
return builtin_options_as_CallOnceOptions();
}
template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as<tflite::BroadcastToOptions>() const {
return builtin_options_as_BroadcastToOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -14143,6 +14211,29 @@ inline flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(flatbuffers::FlatB
_reverse);
}
inline BroadcastToOptionsT *BroadcastToOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BroadcastToOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void BroadcastToOptions::UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<BroadcastToOptions> BroadcastToOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateBroadcastToOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BroadcastToOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateBroadcastToOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -15030,6 +15121,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
@ -15460,6 +15555,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -15878,6 +15977,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const tflite::CallOnceOptionsT *>(value);
return CreateCallOnceOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<const tflite::BroadcastToOptionsT *>(value);
return CreateBroadcastToOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -16296,6 +16399,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new tflite::CallOnceOptionsT(*reinterpret_cast<tflite::CallOnceOptionsT *>(u.value));
break;
}
case BuiltinOptions_BroadcastToOptions: {
value = new tflite::BroadcastToOptionsT(*reinterpret_cast<tflite::BroadcastToOptionsT *>(u.value));
break;
}
default:
break;
}
@ -16818,6 +16925,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<tflite::BroadcastToOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -610,6 +610,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 2;
}
return 1;
// The version one of broadcast to op won't be not supported since the
// version one was rollbacked and the builtin op code number has been
// changed because of builtin op code shortage problem.
case BuiltinOperator_BROADCAST_TO:
return 2;
default:
return 1;
}

View File

@ -61,6 +61,10 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"},
{{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"},
{{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"},
// The version one of broadcast to op won't be not supported since
// the version one was rollbacked and the builtin op code number
// has been changed because of builtin op code shortage problem.
{{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion},
{{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
{{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
{{BuiltinOperator_CONV_2D, 3}, "1.14.0"},

View File

@ -47,7 +47,7 @@ TEST(OpVersionTest, OpversionMissing) {
EXPECT_NE(runtime_version, "")
<< "Please add the version " << version << " of "
<< tflite::EnumNamesBuiltinOperator()[op_code]
<< " runtime_version.cc";
<< " to runtime_version.cc";
}
}
}