Add builtin BroadcastTo Op to TFLite

Converter support will be added in a follow-up CL.

PiperOrigin-RevId: 322281991
Change-Id: I9a96d0dff3a089a9b43b85c955cc416717e26aa9
Thai Nguyen 2020-07-20 20:34:46 -07:00 committed by TensorFlower Gardener
parent 918f876bf8
commit bc8bb3ba84
21 changed files with 674 additions and 15 deletions

View File

@ -54,7 +54,8 @@
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* Better support for ops with high-dimensional broadcasting inputs by adding
`BroadcastTo` ops when necessary.
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:

View File

@ -153,6 +153,7 @@ typedef enum {
kTfLiteBuiltinDensify = 124,
kTfLiteBuiltinSegmentSum = 125,
kTfLiteBuiltinBatchMatmul = 126,
kTfLiteBuiltinBroadcastTo = 127,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -219,6 +219,29 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "Unknown type";
}
// The size of the string type is not constant, so return 0 in that case.
int TfLiteTypeGetSize(TfLiteType type) {
switch (type) {
case kTfLiteUInt8:
case kTfLiteInt8:
return 1;
case kTfLiteBool:
return sizeof(bool);
case kTfLiteInt16:
case kTfLiteFloat16:
return 2;
case kTfLiteFloat32:
case kTfLiteInt32:
return 4;
case kTfLiteInt64:
case kTfLiteComplex64:
case kTfLiteFloat64:
return 8;
default:
return 0;
}
}
TfLiteDelegate TfLiteDelegateCreate() {
TfLiteDelegate d = {
.data_ = NULL,

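As a point of reference, the sketch below shows one way the new TfLiteTypeGetSize helper could be combined with a tensor's dimensions to size a fixed-size buffer. The TensorByteSize helper is hypothetical and not part of this change.

#include <stddef.h>

#include "tensorflow/lite/c/common.h"

// Hypothetical helper (not from this commit): byte size of a fixed-size
// tensor, computed from its element type and shape.
size_t TensorByteSize(const TfLiteTensor* tensor) {
  const int element_size = TfLiteTypeGetSize(tensor->type);
  if (element_size == 0) return 0;  // String tensors have no fixed element size.
  size_t num_elements = 1;
  for (int i = 0; i < tensor->dims->size; ++i) {
    num_elements *= tensor->dims->data[i];
  }
  return num_elements * element_size;
}
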
View File

@ -268,6 +268,9 @@ typedef enum {
// Return the name of a given type, for error reporting purposes.
const char* TfLiteTypeGetName(TfLiteType type);
// Return the size of the given type in bytes. Return 0 in case of strings.
int TfLiteTypeGetSize(TfLiteType type);
// SupportedQuantizationTypes.
typedef enum TfLiteQuantizationType {
// No quantization.

View File

@ -820,6 +820,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_BROADCAST_TO:
return kTfLiteOk;
}
return kTfLiteError;

View File

@ -491,6 +491,7 @@ BUILTIN_KERNEL_SRCS = [
"batch_to_space_nd.cc",
"bidirectional_sequence_lstm.cc",
"bidirectional_sequence_rnn.cc",
"broadcast_to.cc",
"cast.cc",
"ceil.cc",
"comparisons.cc",
@ -984,6 +985,19 @@ cc_test(
],
)
cc_test(
name = "broadcast_to_test",
size = "small",
srcs = ["broadcast_to_test.cc"],
deps = [
":builtin_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "cast_test",
size = "small",

View File

@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"
#include <string.h>
#include <cstdint>
#include <memory>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace broadcastto {
constexpr int kInputTensor = 0;
constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
constexpr int kMaxDims = 8;
struct BroadcastToContext {
BroadcastToContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, kInputTensor);
shape = GetInput(context, node, kShapeTensor);
output = GetOutput(context, node, kOutputTensor);
}
const TfLiteTensor* input;
const TfLiteTensor* shape;
TfLiteTensor* output;
};
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
BroadcastToContext* op_context) {
// Ensure the shape is a 1-D tensor.
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1);
// Ensure the output has at least as many dimensions as the input.
int input_num_dims = NumDimensions(op_context->input);
int output_num_dims = SizeOfDimension(op_context->shape, 0);
TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
"Output shape must be broadcastable from input shape.");
TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
"BroadcastTo only supports 1-8D tensor.");
// Check if output shape is broadcastable from input shape.
auto get_shape_data = [op_context](int i) -> int32_t {
if (op_context->shape->type == kTfLiteInt32) {
return GetTensorData<int32_t>(op_context->shape)[i];
} else {
return GetTensorData<int64_t>(op_context->shape)[i];
}
};
int extending_dims = output_num_dims - input_num_dims;
for (int idx = 0; idx < input_num_dims; ++idx) {
TF_LITE_ENSURE_MSG(context,
(SizeOfDimension(op_context->input, idx) == 1 ||
SizeOfDimension(op_context->input, idx) ==
get_shape_data(extending_dims + idx)),
"Output shape must be broadcastable from input shape.");
}
// Resize the output tensor to the requested shape.
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims);
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
scoped_output_shape(output_shape, TfLiteIntArrayFree);
for (int idx = 0; idx < output_num_dims; ++idx) {
output_shape->data[idx] = get_shape_data(idx);
}
return context->ResizeTensor(context, op_context->output,
scoped_output_shape.release());
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_MSG(context,
(NumDimensions(GetInput(context, node, 0)) <= kMaxDims),
"BroadcastTo only supports 1-8D tensor.");
BroadcastToContext op_context(context, node);
TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 ||
op_context.shape->type == kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
// String type is not yet supported due to the use of memcpy with fixed size.
TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString);
if (IsConstantTensor(op_context.shape)) {
return ResizeOutputTensor(context, &op_context);
}
SetTensorToDynamic(op_context.output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
BroadcastToContext op_context(context, node);
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
// The BroadcastTo op supports up to 8 dims, matching the support of TensorFlow.
reference_ops::BroadcastTo<kMaxDims>(
GetTensorShape(op_context.input), op_context.input->data.raw,
GetTensorShape(op_context.output), op_context.output->data.raw,
op_context.input->type);
return kTfLiteOk;
}
} // namespace broadcastto
TfLiteRegistration* Register_BROADCAST_TO() {
static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare,
broadcastto::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

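To make the shape check in ResizeOutputTensor above concrete, the broadcastability rule can be restated as a standalone predicate. This is an illustrative sketch using the same right-alignment convention, not code from the commit.

#include <vector>

// Illustrative sketch: after right-aligning the input shape against the
// requested output shape, every input dimension must either be 1 or match
// the corresponding output dimension.
bool IsBroadcastableTo(const std::vector<int>& input_shape,
                       const std::vector<int>& output_shape) {
  if (input_shape.size() > output_shape.size()) return false;
  const int offset = output_shape.size() - input_shape.size();
  for (int i = 0; i < static_cast<int>(input_shape.size()); ++i) {
    if (input_shape[i] != 1 && input_shape[i] != output_shape[offset + i]) {
      return false;
    }
  }
  return true;
}
// For example, IsBroadcastableTo({3, 1, 2}, {3, 3, 2, 2}) is true, while
// IsBroadcastableTo({2, 4, 1, 2}, {2, 4, 1, 3}) is false, matching the
// MismatchDimension death test.
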
View File

@ -0,0 +1,255 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
template <class InputType, class ShapeType = int32_t>
class BroadcastToOpModel : public SingleOpModel {
public:
// BroadcastTo with dynamic shape.
BroadcastToOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> shape_shape) {
input_ = AddInput({GetTensorType<InputType>(), input_shape});
shape_ = AddInput({GetTensorType<ShapeType>(), shape_shape});
output_ = AddOutput(GetTensorType<InputType>());
SetBuiltinOp(BuiltinOperator_BROADCAST_TO,
BuiltinOptions_BroadcastToOptions,
CreateBroadcastToOptions(builder_).Union());
BuildInterpreter({input_shape, shape_shape});
}
// BroadcastTo with const shape.
BroadcastToOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> shape_shape,
std::initializer_list<ShapeType> shape_values) {
input_ = AddInput({GetTensorType<InputType>(), input_shape});
shape_ =
AddConstInput(GetTensorType<ShapeType>(), shape_values, shape_shape);
output_ = AddOutput(GetTensorType<InputType>());
SetBuiltinOp(BuiltinOperator_BROADCAST_TO,
BuiltinOptions_BroadcastToOptions,
CreateBroadcastToOptions(builder_).Union());
BuildInterpreter({input_shape, shape_shape});
}
void SetInput(std::initializer_list<InputType> data) {
PopulateTensor(input_, data);
}
void SetShape(std::initializer_list<ShapeType> data) {
PopulateTensor(shape_, data);
}
std::vector<InputType> GetOutput() {
return ExtractVector<InputType>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int input_;
int shape_;
int output_;
};
template <typename T>
class BroadcastToOpTest : public ::testing::Test {};
using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>;
TYPED_TEST_SUITE(BroadcastToOpTest, DataTypes);
#ifdef GTEST_HAS_DEATH_TEST
TYPED_TEST(BroadcastToOpTest, ShapeMustBe1D) {
EXPECT_DEATH(
BroadcastToOpModel<TypeParam>({2, 3, 4, 4}, {2, 2}, {2, 3, 4, 4}), "");
// Non-constant Shape tensor.
BroadcastToOpModel<TypeParam> m({2, 3, 4, 4}, {2, 2});
m.SetShape({2, 3, 4, 4});
EXPECT_THAT(m.InvokeUnchecked(), kTfLiteError);
}
TYPED_TEST(BroadcastToOpTest, TooManyDimensions) {
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9},
{2, 2, 3, 4, 5, 6, 7, 8, 9}),
"BroadcastTo only supports 1-8D tensor.");
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}),
"BroadcastTo only supports 1-8D tensor.");
}
TYPED_TEST(BroadcastToOpTest, MismatchDimension) {
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({2, 4, 1, 2}, {4}, {2, 4, 1, 3}),
"Output shape must be broadcastable from input shape.");
EXPECT_DEATH(
BroadcastToOpModel<TypeParam>({2, 4, 1, 2, 3}, {4}, {2, 4, 1, 2}),
"Output shape must be broadcastable from input shape.");
// Non-constant Shape tensor.
BroadcastToOpModel<TypeParam> m1({2, 4, 1, 2}, {4});
m1.SetShape({2, 3, 4, 4});
EXPECT_THAT(m1.InvokeUnchecked(), kTfLiteError);
BroadcastToOpModel<TypeParam> m2({2, 4, 1, 2}, {5});
m2.SetShape({1, 2, 3, 4, 4});
EXPECT_THAT(m2.InvokeUnchecked(), kTfLiteError);
}
#endif
TYPED_TEST(BroadcastToOpTest, BroadcastTo1DConstTest) {
BroadcastToOpModel<TypeParam> m({1}, {1}, {4});
m.SetInput({3});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo4DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4}, {1, 1, 2, 2});
m.SetInput({3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo8DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
{1, 1, 1, 1, 1, 1, 2, 2});
m.SetInput({3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo1DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1}, {1});
m.SetInput({3});
m.SetShape({4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo4DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4});
m.SetInput({3, 4});
m.SetShape({1, 1, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4}));
}
TYPED_TEST(BroadcastToOpTest, BroadcastTo8DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8});
m.SetInput({3, 4});
m.SetShape({1, 1, 1, 1, 1, 1, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4}, {3, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4});
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetShape({3, 3, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DConstTest) {
BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6}, {2, 2, 1, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12,
1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) {
BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetShape({2, 2, 1, 3, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2}));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12,
1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) {
BroadcastToOpModel<TypeParam> m({3, 1, 2}, {3}, {3, 1, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
{1, 1, 1, 1, 1, 1, 2, 2});
m.SetInput({3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
TYPED_TEST(BroadcastToOpTest, Int64ShapeDynamicTest) {
BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8});
m.SetInput({3, 4});
m.SetShape({1, 1, 1, 1, 1, 1, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
} // namespace
} // namespace tflite

View File

@ -39,6 +39,7 @@ TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
TfLiteRegistration* Register_BATCH_MATMUL();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
TfLiteRegistration* Register_BROADCAST_TO();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_CEIL();
TfLiteRegistration* Register_CONCATENATION();

View File

@ -441,6 +441,7 @@ cc_library(
"reference/arg_min_max.h",
"reference/batch_matmul.h",
"reference/binary_function.h",
"reference/broadcast_to.h",
"reference/ceil.h",
"reference/comparisons.h",
"reference/concatenation.h",

View File

@ -665,6 +665,13 @@ inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
indexes[4] * desc.strides[4];
}
inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
}
// Given the dimensions of the operands for an element-wise binary broadcast,
// adjusts them so that they can be directly iterated over with simple loops.
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_ops {
template <int N>
void BroadcastImpl(const NdArrayDesc<N>& input_desc, const char* input_data,
const NdArrayDesc<N>& output_desc, char* output_data,
int indexes[N], int dim, const int last_broadcasting_dim,
const int type_size) {
// Copy data from input to output.
if (dim == last_broadcasting_dim) {
int copy_size = output_desc.strides[dim] * type_size;
const char* data_src =
input_data + SubscriptToIndex(input_desc, indexes) * type_size;
char* data_dst =
output_data + SubscriptToIndex(output_desc, indexes) * type_size;
for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
memcpy(data_dst, data_src, copy_size);
}
return;
}
// Recurse to find the next broadcast dimension.
for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim];
++indexes[dim]) {
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes,
dim + 1, last_broadcasting_dim, type_size);
}
// Duplicate data in output tensor.
indexes[dim] = 0;
if (input_desc.extents[dim] != output_desc.extents[dim]) {
int copy_size = output_desc.strides[dim] * type_size;
char* data_src =
output_data + SubscriptToIndex(output_desc, indexes) * type_size;
char* data_dst = data_src + copy_size;
for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
memcpy(data_dst, data_src, copy_size);
}
}
}
template <int N>
inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
const char* input_data,
const RuntimeShape& unextended_output_shape,
char* output_data, TfLiteType data_type) {
NdArrayDesc<N> input_desc;
NdArrayDesc<N> output_desc;
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
&input_desc);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
&output_desc);
// Find the last dimension that requires broadcasting. At this dimension, the
// data is copied from the input tensor to the output tensor.
int last_broadcast_dim = 0;
for (int i = N - 1; i > 0; --i) {
if (input_desc.extents[i] != output_desc.extents[i]) {
last_broadcast_dim = i;
break;
}
}
// Broadcasting using memcpy.
int indexes[N] = {0};
BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
last_broadcast_dim, TfLiteTypeGetSize(data_type));
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_

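As a rough illustration of how the builtin kernel's Eval drives this routine, a direct call could look like the sketch below. The shapes and buffers are made up for the example, and the RuntimeShape initializer-list constructor from types.h is assumed.

#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"

// Hypothetical standalone call (not from the commit): broadcast a {1, 2}
// float tensor to {2, 2} via raw byte pointers, as broadcastto::Eval does
// with kMaxDims = 8.
void BroadcastToExample() {
  const float input[2] = {3.f, 4.f};  // logical shape {1, 2}
  float output[4];                    // logical shape {2, 2}
  tflite::reference_ops::BroadcastTo<8>(
      tflite::RuntimeShape({1, 2}), reinterpret_cast<const char*>(input),
      tflite::RuntimeShape({2, 2}), reinterpret_cast<char*>(output),
      kTfLiteFloat32);
  // output now holds {3, 4, 3, 4}.
}
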
View File

@ -292,6 +292,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -139,6 +139,7 @@ TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_BATCH_MATMUL_REF();
TfLiteRegistration* Register_BROADCAST_TO();
namespace {
@ -207,6 +208,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
Register_SPACE_TO_BATCH_ND_REF());
AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND,
Register_BATCH_TO_SPACE_ND_REF());
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO());
AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF());
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF());
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,

View File

@ -349,7 +349,8 @@ enum BuiltinOperator : byte {
SELECT_V2 = 123,
DENSIFY = 124,
SEGMENT_SUM = 125,
BATCH_MATMUL = 126
BATCH_MATMUL = 126,
BROADCAST_TO = 127
}
@ -455,7 +456,8 @@ union BuiltinOptions {
SelectV2Options,
DensifyOptions,
SegmentSumOptions,
BatchMatMulOptions
BatchMatMulOptions,
BroadcastToOptions
}
enum Padding : byte { SAME, VALID }
@ -975,6 +977,9 @@ table BatchMatMulOptions {
adj_y:bool;
}
table BroadcastToOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -349,6 +349,9 @@ struct SegmentSumOptionsT;
struct BatchMatMulOptions;
struct BatchMatMulOptionsT;
struct BroadcastToOptions;
struct BroadcastToOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -781,11 +784,12 @@ enum BuiltinOperator {
BuiltinOperator_DENSIFY = 124,
BuiltinOperator_SEGMENT_SUM = 125,
BuiltinOperator_BATCH_MATMUL = 126,
BuiltinOperator_BROADCAST_TO = 127,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL
BuiltinOperator_MAX = BuiltinOperator_BROADCAST_TO
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[128] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -913,13 +917,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] {
BuiltinOperator_SELECT_V2,
BuiltinOperator_DENSIFY,
BuiltinOperator_SEGMENT_SUM,
BuiltinOperator_BATCH_MATMUL
BuiltinOperator_BATCH_MATMUL,
BuiltinOperator_BROADCAST_TO
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[128] = {
static const char * const names[129] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@ -1047,13 +1052,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"DENSIFY",
"SEGMENT_SUM",
"BATCH_MATMUL",
"BROADCAST_TO",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_TO)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -1161,11 +1167,12 @@ enum BuiltinOptions {
BuiltinOptions_DensifyOptions = 99,
BuiltinOptions_SegmentSumOptions = 100,
BuiltinOptions_BatchMatMulOptions = 101,
BuiltinOptions_BroadcastToOptions = 102,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_BatchMatMulOptions
BuiltinOptions_MAX = BuiltinOptions_BroadcastToOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[103] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1268,13 +1275,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[102] {
BuiltinOptions_SelectV2Options,
BuiltinOptions_DensifyOptions,
BuiltinOptions_SegmentSumOptions,
BuiltinOptions_BatchMatMulOptions
BuiltinOptions_BatchMatMulOptions,
BuiltinOptions_BroadcastToOptions
};
return values;
}
inline const char * const *EnumNamesBuiltinOptions() {
static const char * const names[103] = {
static const char * const names[104] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@ -1377,13 +1385,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"DensifyOptions",
"SegmentSumOptions",
"BatchMatMulOptions",
"BroadcastToOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BatchMatMulOptions)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BroadcastToOptions)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -1796,6 +1805,10 @@ template<> struct BuiltinOptionsTraits<tflite::BatchMatMulOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions;
};
template<> struct BuiltinOptionsTraits<tflite::BroadcastToOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2636,6 +2649,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_BatchMatMulOptions ?
reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value) : nullptr;
}
tflite::BroadcastToOptionsT *AsBroadcastToOptions() {
return type == BuiltinOptions_BroadcastToOptions ?
reinterpret_cast<tflite::BroadcastToOptionsT *>(value) : nullptr;
}
const tflite::BroadcastToOptionsT *AsBroadcastToOptions() const {
return type == BuiltinOptions_BroadcastToOptions ?
reinterpret_cast<const tflite::BroadcastToOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -9310,6 +9331,46 @@ inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(
flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct BroadcastToOptionsT : public flatbuffers::NativeTable {
typedef BroadcastToOptions TableType;
BroadcastToOptionsT() {
}
};
struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef BroadcastToOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
BroadcastToOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<BroadcastToOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct BroadcastToOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit BroadcastToOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
BroadcastToOptionsBuilder &operator=(const BroadcastToOptionsBuilder &);
flatbuffers::Offset<BroadcastToOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<BroadcastToOptions>(end);
return o;
}
};
inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
BroadcastToOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
tflite::BuiltinOperator builtin_code;
@ -9749,6 +9810,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? static_cast<const tflite::BatchMatMulOptions *>(builtin_options()) : nullptr;
}
const tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast<const tflite::BroadcastToOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -10189,6 +10253,10 @@ template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as
return builtin_options_as_BatchMatMulOptions();
}
template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as<tflite::BroadcastToOptions>() const {
return builtin_options_as_BroadcastToOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -13656,6 +13724,29 @@ inline flatbuffers::Offset<BatchMatMulOptions> CreateBatchMatMulOptions(flatbuff
_adj_y);
}
inline BroadcastToOptionsT *BroadcastToOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BroadcastToOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void BroadcastToOptions::UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<BroadcastToOptions> BroadcastToOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateBroadcastToOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BroadcastToOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateBroadcastToOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -14465,6 +14556,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
@ -14887,6 +14982,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const tflite::BatchMatMulOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -15297,6 +15396,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const tflite::BatchMatMulOptionsT *>(value);
return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<const tflite::BroadcastToOptionsT *>(value);
return CreateBroadcastToOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -15707,6 +15810,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new tflite::BatchMatMulOptionsT(*reinterpret_cast<tflite::BatchMatMulOptionsT *>(u.value));
break;
}
case BuiltinOptions_BroadcastToOptions: {
value = new tflite::BroadcastToOptionsT(*reinterpret_cast<tflite::BroadcastToOptionsT *>(u.value));
break;
}
default:
break;
}
@ -16219,6 +16326,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_BroadcastToOptions: {
auto ptr = reinterpret_cast<tflite::BroadcastToOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;
@ -16282,4 +16394,4 @@ inline std::unique_ptr<tflite::ModelT> UnPackSizePrefixedModel(
} // namespace tflite
#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_
#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_

View File

@ -43,6 +43,7 @@ enum class OperatorType : uint8 {
kAveragePool,
kBatchMatMul,
kBatchNormalization,
kBroadcastTo,
kCeil,
kConv,
kConcatenation,

View File

@ -63,6 +63,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},
{{OperatorType::kBatchToSpaceND, 2}, "1.14.0"},
{{OperatorType::kBatchMatMul, 1}, kPendingReleaseOpVersion},
{{OperatorType::kBroadcastTo, 1}, kPendingReleaseOpVersion},
{{OperatorType::kCast, 1}, "1.5.0"},
{{OperatorType::kConcatenation, 1}, "1.5.0"},
{{OperatorType::kConcatenation, 2}, "1.14.0"},

View File

@ -268,6 +268,9 @@ typedef enum {
// Return the name of a given type, for error reporting purposes.
const char* TfLiteTypeGetName(TfLiteType type);
// Return the size of the given type in bytes. Return 0 in case of strings.
int TfLiteTypeGetSize(TfLiteType type);
// SupportedQuantizationTypes.
typedef enum TfLiteQuantizationType {
// No quantization.

View File

@ -59,6 +59,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_AVERAGE_POOL_2D, 3}, "2.3.0"},
{{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"},
{{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"},
{{BuiltinOperator_BROADCAST_TO, 1}, kPendingReleaseVersion},
{{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
{{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
{{BuiltinOperator_CONV_2D, 3}, "1.14.0"},

View File

@ -47,7 +47,7 @@ TEST(OpVersionTest, OpversionMissing) {
EXPECT_NE(runtime_version, "")
<< "Please add the version " << version << " of "
<< tflite::EnumNamesBuiltinOperator()[op_code]
<< " runtime_version.cc";
<< " to runtime_version.cc";
}
}
}