Add builtin BroadcastTo Op to TFLite
Converter support will be added in a follow-up CL. PiperOrigin-RevId: 340563780 Change-Id: I4ea49fef309518b6447cb117365653b9b1d7d6a1
This commit is contained in:
parent
b37cacd6ea
commit
4565291a98
@ -157,6 +157,7 @@ typedef enum {
|
||||
kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127,
|
||||
kTfLiteBuiltinCumsum = 128,
|
||||
kTfLiteBuiltinCallOnce = 129,
|
||||
kTfLiteBuiltinBroadcastTo = 130,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -822,6 +822,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_SCATTER_ND:
|
||||
case BuiltinOperator_DENSIFY:
|
||||
case BuiltinOperator_SEGMENT_SUM:
|
||||
case BuiltinOperator_BROADCAST_TO:
|
||||
return kTfLiteOk;
|
||||
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
|
||||
return kTfLiteError;
|
||||
|
@ -544,6 +544,7 @@ BUILTIN_KERNEL_SRCS = [
|
||||
"batch_to_space_nd.cc",
|
||||
"bidirectional_sequence_lstm.cc",
|
||||
"bidirectional_sequence_rnn.cc",
|
||||
"broadcast_to.cc",
|
||||
"call_once.cc",
|
||||
"cast.cc",
|
||||
"ceil.cc",
|
||||
@ -1017,6 +1018,19 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "broadcast_to_test",
|
||||
size = "small",
|
||||
srcs = ["broadcast_to_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "cast_test",
|
||||
size = "small",
|
||||
|
136
tensorflow/lite/kernels/broadcast_to.cc
Normal file
136
tensorflow/lite/kernels/broadcast_to.cc
Normal file
@ -0,0 +1,136 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace broadcastto {
|
||||
|
||||
// Tensor indices within the node's input/output lists.
constexpr int kInputTensor = 0;
constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
// Maximum tensor rank supported by this kernel (see Prepare/ResizeOutputTensor
// checks below).
constexpr int kMaxDims = 8;

// Convenience bundle of the three tensors the op touches; resolved once per
// Prepare/Eval call.
struct BroadcastToContext {
  BroadcastToContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, kInputTensor);
    shape = GetInput(context, node, kShapeTensor);
    output = GetOutput(context, node, kOutputTensor);
  }
  const TfLiteTensor* input;  // Data to broadcast.
  const TfLiteTensor* shape;  // 1-D tensor holding the target shape.
  TfLiteTensor* output;       // Broadcast result.
};
|
||||
|
||||
// Validates that the input can be broadcast to the requested shape and
// resizes the output tensor accordingly.  Returns an error status if the
// shape tensor is not 1-D, the requested rank exceeds kMaxDims, or any input
// dimension is neither 1 nor equal to the matching output dimension.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                BroadcastToContext* op_context) {
  // The shape tensor must be a 1-D vector of dimension sizes.
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1);

  const int input_rank = NumDimensions(op_context->input);
  const int output_rank = SizeOfDimension(op_context->shape, 0);
  TF_LITE_ENSURE_MSG(context, input_rank <= output_rank,
                     "Output shape must be broadcastable from input shape.");
  TF_LITE_ENSURE_MSG(context, output_rank <= kMaxDims,
                     "BroadcastTo only supports 1-8D tensor.");

  // The shape tensor holds either int32 or int64 values; read uniformly as
  // int32 (Prepare has already restricted the type to one of these two).
  auto requested_dim = [op_context](int i) -> int32_t {
    return op_context->shape->type == kTfLiteInt32
               ? GetTensorData<int32_t>(op_context->shape)[i]
               : static_cast<int32_t>(
                     GetTensorData<int64_t>(op_context->shape)[i]);
  };

  // Input dims align with the trailing output dims; each must be 1 or match
  // the requested output dim exactly.
  const int leading_dims = output_rank - input_rank;
  for (int d = 0; d < input_rank; ++d) {
    const int input_extent = SizeOfDimension(op_context->input, d);
    TF_LITE_ENSURE_MSG(
        context,
        (input_extent == 1 || input_extent == requested_dim(leading_dims + d)),
        "Output shape must be broadcastable from input shape.");
  }

  // Build the new output shape.  The guard frees it on any early error path;
  // on success, ownership transfers to ResizeTensor via release().
  TfLiteIntArray* new_shape = TfLiteIntArrayCreate(output_rank);
  std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> guard(
      new_shape, TfLiteIntArrayFree);
  for (int d = 0; d < output_rank; ++d) {
    new_shape->data[d] = requested_dim(d);
  }

  return context->ResizeTensor(context, op_context->output, guard.release());
}
|
||||
|
||||
// Prepare: validates tensor counts and types.  If the target shape is a
// constant tensor the output is resized here; otherwise the output is marked
// dynamic and resized during Eval.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, NumInputs(node) == 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  BroadcastToContext op_context(context, node);
  TF_LITE_ENSURE_MSG(context, (NumDimensions(op_context.input) <= kMaxDims),
                     "BroadcastTo only supports 1-8D tensor.");
  TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 ||
                              op_context.shape->type == kTfLiteInt64);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

  // Strings are variable-length, but the kernel copies elements with
  // fixed-size memcpy, so string tensors are rejected.
  TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString);

  // A non-constant shape is only known at runtime: defer the resize to Eval.
  if (!IsConstantTensor(op_context.shape)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}
|
||||
|
||||
// Eval: resizes the output if its shape was deferred to runtime (dynamic
// shape tensor), then broadcasts the raw input bytes into the output.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  BroadcastToContext op_context(context, node);
  if (IsDynamicTensor(op_context.output)) {
    // Shape tensor was not constant at Prepare time; size the output now.
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

  // BroadcastTo op supports up to 8 dims, matching the support of TensorFlow.
  reference_ops::BroadcastTo<kMaxDims>(
      GetTensorShape(op_context.input), op_context.input->data.raw,
      GetTensorShape(op_context.output), op_context.output->data.raw,
      op_context.input->type);
  return kTfLiteOk;
}
|
||||
|
||||
} // namespace broadcastto
|
||||
|
||||
// Registration entry point for the builtin BROADCAST_TO kernel.  The op keeps
// no per-node state, so the init and free hooks are null.
TfLiteRegistration* Register_BROADCAST_TO() {
  static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare,
                                 broadcastto::Eval};
  return &r;
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
255
tensorflow/lite/kernels/broadcast_to_test.cc
Normal file
255
tensorflow/lite/kernels/broadcast_to_test.cc
Normal file
@ -0,0 +1,255 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
// Test harness wrapping a single BROADCAST_TO op.  InputType is the element
// type of the data tensor; ShapeType (int32_t or int64_t) is the element type
// of the target-shape tensor.
template <class InputType, class ShapeType = int32_t>
class BroadcastToOpModel : public SingleOpModel {
 public:
  // BroadcastTo with dynamic shape.
  BroadcastToOpModel(std::initializer_list<int> input_shape,
                     std::initializer_list<int> shape_shape) {
    input_ = AddInput({GetTensorType<InputType>(), input_shape});
    shape_ = AddInput({GetTensorType<ShapeType>(), shape_shape});
    output_ = AddOutput(GetTensorType<InputType>());
    SetBuiltinOp(BuiltinOperator_BROADCAST_TO,
                 BuiltinOptions_BroadcastToOptions,
                 CreateBroadcastToOptions(builder_).Union());
    BuildInterpreter({input_shape, shape_shape});
  }

  // BroadcastTo with const shape.
  BroadcastToOpModel(std::initializer_list<int> input_shape,
                     std::initializer_list<int> shape_shape,
                     std::initializer_list<ShapeType> shape_values) {
    input_ = AddInput({GetTensorType<InputType>(), input_shape});
    shape_ =
        AddConstInput(GetTensorType<ShapeType>(), shape_values, shape_shape);
    output_ = AddOutput(GetTensorType<InputType>());
    SetBuiltinOp(BuiltinOperator_BROADCAST_TO,
                 BuiltinOptions_BroadcastToOptions,
                 CreateBroadcastToOptions(builder_).Union());
    BuildInterpreter({input_shape, shape_shape});
  }

  // Populates the data tensor.
  void SetInput(std::initializer_list<InputType> data) {
    PopulateTensor(input_, data);
  }

  // Populates the (dynamic) shape tensor.
  void SetShape(std::initializer_list<ShapeType> data) {
    PopulateTensor(shape_, data);
  }

  // Returns the output values after Invoke.
  std::vector<InputType> GetOutput() {
    return ExtractVector<InputType>(output_);
  }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 protected:
  int input_;   // Tensor index of the data input.
  int shape_;   // Tensor index of the target-shape input.
  int output_;  // Tensor index of the result.
};
|
||||
|
||||
// Typed test fixture, instantiated once per supported element type.
template <typename T>
class BroadcastToOpTest : public ::testing::Test {};

using DataTypes = ::testing::Types<float, uint8_t, int8_t, int16_t, int32_t>;
TYPED_TEST_SUITE(BroadcastToOpTest, DataTypes);
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
TYPED_TEST(BroadcastToOpTest, ShapeMustBe1D) {
|
||||
EXPECT_DEATH(
|
||||
BroadcastToOpModel<TypeParam>({2, 3, 4, 4}, {2, 2}, {2, 3, 4, 4}), "");
|
||||
// Non-constant Shape tensor.
|
||||
BroadcastToOpModel<TypeParam> m({2, 3, 4, 4}, {2, 2});
|
||||
m.SetShape({2, 3, 4, 4});
|
||||
EXPECT_THAT(m.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, TooManyDimensions) {
|
||||
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9},
|
||||
{2, 2, 3, 4, 5, 6, 7, 8, 9}),
|
||||
"BroadcastTo only supports 1-8D tensor.");
|
||||
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9}),
|
||||
"BroadcastTo only supports 1-8D tensor.");
|
||||
}
|
||||
|
||||
TYPED_TEST(BroadcastToOpTest, MismatchDimension) {
|
||||
EXPECT_DEATH(BroadcastToOpModel<TypeParam>({2, 4, 1, 2}, {4}, {2, 4, 1, 3}),
|
||||
"Output shape must be broadcastable from input shape.");
|
||||
EXPECT_DEATH(
|
||||
BroadcastToOpModel<TypeParam>({2, 4, 1, 2, 3}, {4}, {2, 4, 1, 2}),
|
||||
"Output shape must be broadcastable from input shape.");
|
||||
|
||||
// Non-constant Shape tensor.
|
||||
BroadcastToOpModel<TypeParam> m1({2, 4, 1, 2}, {4});
|
||||
m1.SetShape({2, 3, 4, 4});
|
||||
EXPECT_THAT(m1.InvokeUnchecked(), kTfLiteError);
|
||||
BroadcastToOpModel<TypeParam> m2({2, 4, 1, 2}, {5});
|
||||
m2.SetShape({1, 2, 3, 4, 4});
|
||||
EXPECT_THAT(m2.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Scalar-like 1-D input replicated 4x, shape given as a const tensor.
TYPED_TEST(BroadcastToOpTest, BroadcastTo1DConstTest) {
  BroadcastToOpModel<TypeParam> m({1}, {1}, {4});
  m.SetInput({3});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3}));
}

// 4-D broadcast along one inner dimension, const shape.
TYPED_TEST(BroadcastToOpTest, BroadcastTo4DConstTest) {
  BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4}, {1, 1, 2, 2});
  m.SetInput({3, 4});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4}));
}

// Maximum supported rank (8-D), const shape, broadcast on the last dim.
TYPED_TEST(BroadcastToOpTest, BroadcastTo8DConstTest) {
  BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
                                  {1, 1, 1, 1, 1, 1, 2, 2});
  m.SetInput({3, 4});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}

// Same as the 1-D const case but with the shape supplied at runtime.
TYPED_TEST(BroadcastToOpTest, BroadcastTo1DDynamicTest) {
  BroadcastToOpModel<TypeParam> m({1}, {1});
  m.SetInput({3});
  m.SetShape({4});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 3}));
}

// 4-D broadcast with a runtime-supplied shape.
TYPED_TEST(BroadcastToOpTest, BroadcastTo4DDynamicTest) {
  BroadcastToOpModel<TypeParam> m({1, 1, 1, 2}, {4});
  m.SetInput({3, 4});
  m.SetShape({1, 1, 2, 2});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 4, 3, 4}));
}

// 8-D broadcast with a runtime-supplied shape.
TYPED_TEST(BroadcastToOpTest, BroadcastTo8DDynamicTest) {
  BroadcastToOpModel<TypeParam> m({1, 1, 1, 1, 1, 1, 2, 1}, {8});
  m.SetInput({3, 4});
  m.SetShape({1, 1, 1, 1, 1, 1, 2, 2});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
|
||||
|
||||
// Broadcast along multiple dimensions simultaneously (dims 0 and 2), const
// shape.
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DConstTest) {
  BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4}, {3, 3, 2, 2});
  m.SetInput({1, 2, 3, 4, 5, 6});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
  EXPECT_THAT(
      m.GetOutput(),
      ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
                        3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}

// Same multi-dimension broadcast with a runtime-supplied shape.
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast4DDynamicTest) {
  BroadcastToOpModel<TypeParam> m({1, 3, 1, 2}, {4});
  m.SetInput({1, 2, 3, 4, 5, 6});
  m.SetShape({3, 3, 2, 2});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
  EXPECT_THAT(
      m.GetOutput(),
      ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
                        3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}

// 6-D broadcast across non-adjacent dimensions, const shape.
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DConstTest) {
  BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6}, {2, 2, 1, 3, 2, 2});
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
                                7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12,
                                1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
                                7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}

// 6-D broadcast with a runtime-supplied shape.
TYPED_TEST(BroadcastToOpTest, ComplexBroadcast6DDynamicTest) {
  BroadcastToOpModel<TypeParam> m({1, 2, 1, 3, 1, 2}, {6});
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  m.SetShape({2, 2, 1, 3, 2, 2});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1, 3, 2, 2}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
                                7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12,
                                1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6,
                                7, 8, 7, 8, 9, 10, 9, 10, 11, 12, 11, 12}));
}
|
||||
|
||||
// The output rank (4) exceeds the input rank (3): leading dims are added.
TYPED_TEST(BroadcastToOpTest, ExtendingShape4DConstTest) {
  BroadcastToOpModel<TypeParam> m({3, 1, 2}, {4}, {3, 3, 2, 2});
  m.SetInput({1, 2, 3, 4, 5, 6});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 2, 2}));
  EXPECT_THAT(
      m.GetOutput(),
      ElementsAreArray({1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4,
                        3, 4, 5, 6, 5, 6, 1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 5, 6}));
}

// Target shape equals the input shape: the op degenerates to a plain copy.
TYPED_TEST(BroadcastToOpTest, NoBroadcastingConstTest) {
  BroadcastToOpModel<TypeParam> m({3, 1, 2}, {3}, {3, 1, 2});
  m.SetInput({1, 2, 3, 4, 5, 6});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}

// Shape tensor of type int64, const.
TYPED_TEST(BroadcastToOpTest, Int64ShapeConstTest) {
  BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8},
                                           {1, 1, 1, 1, 1, 1, 2, 2});
  m.SetInput({3, 4});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}

// Shape tensor of type int64, supplied at runtime.
TYPED_TEST(BroadcastToOpTest, Int64ShapeDDynamicTest) {
  BroadcastToOpModel<TypeParam, int64_t> m({1, 1, 1, 1, 1, 1, 2, 1}, {8});
  m.SetInput({3, 4});
  m.SetShape({1, 1, 1, 1, 1, 1, 2, 2});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1, 1, 1, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 4, 4}));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
@ -39,6 +39,7 @@ TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
|
||||
TfLiteRegistration* Register_BATCH_MATMUL();
|
||||
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
|
||||
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
|
||||
TfLiteRegistration* Register_BROADCAST_TO();
|
||||
TfLiteRegistration* Register_CALL_ONCE();
|
||||
TfLiteRegistration* Register_CAST();
|
||||
TfLiteRegistration* Register_CEIL();
|
||||
|
@ -452,6 +452,7 @@ cc_library(
|
||||
"reference/arg_min_max.h",
|
||||
"reference/batch_matmul.h",
|
||||
"reference/binary_function.h",
|
||||
"reference/broadcast_to.h",
|
||||
"reference/ceil.h",
|
||||
"reference/comparisons.h",
|
||||
"reference/concatenation.h",
|
||||
@ -521,6 +522,7 @@ cc_library(
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||
] + select({
|
||||
|
@ -696,6 +696,13 @@ inline int SubscriptToIndex(const NdArrayDesc<5>& desc, int indexes[5]) {
|
||||
indexes[4] * desc.strides[4];
|
||||
}
|
||||
|
||||
inline int SubscriptToIndex(const NdArrayDesc<8>& desc, int indexes[8]) {
|
||||
return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
|
||||
indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] +
|
||||
indexes[4] * desc.strides[4] + indexes[5] * desc.strides[5] +
|
||||
indexes[6] * desc.strides[6] + indexes[7] * desc.strides[7];
|
||||
}
|
||||
|
||||
// Given the dimensions of the operands for an element-wise binary broadcast,
|
||||
// adjusts them so that they can be directly iterated over with simple loops.
|
||||
// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
|
||||
|
90
tensorflow/lite/kernels/internal/reference/broadcast_to.h
Normal file
90
tensorflow/lite/kernels/internal/reference/broadcast_to.h
Normal file
@ -0,0 +1,90 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
|
||||
|
||||
#include <cstring>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace reference_ops {
|
||||
// Recursively broadcasts `input_data` into `output_data`, one dimension at a
// time, copying contiguous slabs with memcpy.
//
// `indexes` tracks the current multi-dimensional position, `dim` is the
// dimension currently processed, `last_broadcasting_dim` is the innermost
// dimension whose input/output extents differ, and `type_size` is the element
// size in bytes.
template <int N>
void BroadcastImpl(const NdArrayDesc<N>& input_desc, const char* input_data,
                   const NdArrayDesc<N>& output_desc, char* output_data,
                   int indexes[N], int dim, const int last_broadcasting_dim,
                   const int type_size) {
  // Copy data from input to output.
  if (dim == last_broadcasting_dim) {
    // Replicate one contiguous input slab output_desc.extents[dim] times.
    int copy_size = output_desc.strides[dim] * type_size;
    const char* data_src =
        input_data + SubscriptToIndex(input_desc, indexes) * type_size;
    char* data_dst =
        output_data + SubscriptToIndex(output_desc, indexes) * type_size;
    for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
      memcpy(data_dst, data_src, copy_size);
    }
    return;
  }

  // Recursive call to find the next broadcasting.
  for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim];
       ++indexes[dim]) {
    BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes,
                     dim + 1, last_broadcasting_dim, type_size);
  }

  // Duplicate data in output tensor.  After the inner dims are filled, the
  // first slab at this dim is replicated across the remaining extent.
  indexes[dim] = 0;
  if (input_desc.extents[dim] != output_desc.extents[dim]) {
    int copy_size = output_desc.strides[dim] * type_size;
    char* data_src =
        output_data + SubscriptToIndex(output_desc, indexes) * type_size;
    char* data_dst = data_src + copy_size;
    for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
      memcpy(data_dst, data_src, copy_size);
    }
  }
}
|
||||
|
||||
// Broadcasts `input_data` (shape `unextended_input_shape`) into `output_data`
// (shape `unextended_output_shape`).  Both shapes are first extended to rank
// N.  `data_type` only determines the element size: the copy is performed
// with type-agnostic memcpy, so variable-length types (strings) are not
// supported.
template <int N>
inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
                        const char* input_data,
                        const RuntimeShape& unextended_output_shape,
                        char* output_data, TfLiteType data_type) {
  NdArrayDesc<N> input_desc;
  NdArrayDesc<N> output_desc;
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
                 &input_desc);
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
                 &output_desc);

  // Find the innermost dimension that needs broadcasting.  At this dimension
  // the data is copied from the input tensor to the output tensor.  If no
  // extent differs, dim 0 is used and the data is copied slab-by-slab.
  int last_broadcast_dim = 0;
  for (int i = N - 1; i > 0; --i) {
    if (input_desc.extents[i] != output_desc.extents[i]) {
      last_broadcast_dim = i;
      break;
    }
  }

  // Broadcasting using memcpy.
  int indexes[N] = {0};
  BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
                   last_broadcast_dim, TfLiteTypeGetSize(data_type));
}
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
@ -434,4 +435,44 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
|
||||
}
|
||||
#endif // TF_LITE_STATIC_MEMORY
|
||||
|
||||
// Size of string is not constant, return 0 in such case.
|
||||
int TfLiteTypeGetSize(TfLiteType type) {
|
||||
switch (type) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ASSERT_EQ(sizeof(uint8_t), 1);
|
||||
return 1;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ASSERT_EQ(sizeof(int8_t), 1);
|
||||
return 1;
|
||||
case kTfLiteBool:
|
||||
return sizeof(bool);
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_ASSERT_EQ(sizeof(int16_t), 2);
|
||||
return 2;
|
||||
case kTfLiteFloat16:
|
||||
TF_LITE_ASSERT_EQ(sizeof(int16_t), 2);
|
||||
return 2;
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_ASSERT_EQ(sizeof(float), 4);
|
||||
return 4;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
|
||||
return 4;
|
||||
case kTfLiteInt64:
|
||||
TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
|
||||
return 8;
|
||||
case kTfLiteFloat64:
|
||||
TF_LITE_ASSERT_EQ(sizeof(double), 8);
|
||||
return 8;
|
||||
case kTfLiteComplex64:
|
||||
TF_LITE_ASSERT_EQ(sizeof(std::complex<float>), 8);
|
||||
return 8;
|
||||
case kTfLiteComplex128:
|
||||
TF_LITE_ASSERT_EQ(sizeof(std::complex<double>), 16);
|
||||
return 16;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -284,6 +284,10 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
|
||||
const TfLiteTensor* input2,
|
||||
const TfLiteTensor* input3,
|
||||
TfLiteIntArray** output_shape);
|
||||
|
||||
// Returns the size of the given type in bytes. Returns 0 in case of string.
|
||||
int TfLiteTypeGetSize(TfLiteType type);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
|
||||
|
@ -297,6 +297,12 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_CUMSUM, Register_CUMSUM());
|
||||
// Version 1 of the BroadcastTo op is not supported, since version one was
// rolled back and the builtin op code number has been changed because of the
// builtin op code shortage problem.
|
||||
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
|
||||
/* min_version = */ 2,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_CALL_ONCE,
|
||||
tflite::ops::builtin::Register_CALL_ONCE());
|
||||
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
|
||||
|
@ -156,6 +156,7 @@ TfLiteRegistration* Register_HARD_SWISH_REF();
|
||||
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
|
||||
TfLiteRegistration* Register_SELECT_V2();
|
||||
TfLiteRegistration* Register_SEGMENT_SUM();
|
||||
TfLiteRegistration* Register_BROADCAST_TO();
|
||||
|
||||
namespace {
|
||||
|
||||
@ -262,6 +263,12 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
// Version 1 of the BroadcastTo op is not supported, since version one was
// rolled back and the builtin op code number has been changed because of the
// builtin op code shortage problem.
|
||||
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
|
||||
/* min_version = */ 2,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
|
||||
Register_LOCAL_RESPONSE_NORM_REF());
|
||||
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
|
||||
|
@ -353,7 +353,8 @@ enum BuiltinOperator : int32 {
|
||||
BATCH_MATMUL = 126,
|
||||
PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
|
||||
CUMSUM = 128,
|
||||
CALL_ONCE = 129
|
||||
CALL_ONCE = 129,
|
||||
BROADCAST_TO = 130
|
||||
}
|
||||
|
||||
|
||||
@ -461,7 +462,8 @@ union BuiltinOptions {
|
||||
SegmentSumOptions,
|
||||
BatchMatMulOptions,
|
||||
CumsumOptions,
|
||||
CallOnceOptions
|
||||
CallOnceOptions,
|
||||
BroadcastToOptions
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -994,6 +996,9 @@ table CumsumOptions {
|
||||
reverse:bool;
|
||||
}
|
||||
|
||||
table BroadcastToOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -355,6 +355,9 @@ struct BatchMatMulOptionsT;
|
||||
struct CumsumOptions;
|
||||
struct CumsumOptionsT;
|
||||
|
||||
struct BroadcastToOptions;
|
||||
struct BroadcastToOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -796,11 +799,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
|
||||
BuiltinOperator_CUMSUM = 128,
|
||||
BuiltinOperator_CALL_ONCE = 129,
|
||||
BuiltinOperator_BROADCAST_TO = 130,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_CALL_ONCE
|
||||
BuiltinOperator_MAX = BuiltinOperator_BROADCAST_TO
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[131] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -931,13 +935,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[130] {
|
||||
BuiltinOperator_BATCH_MATMUL,
|
||||
BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES,
|
||||
BuiltinOperator_CUMSUM,
|
||||
BuiltinOperator_CALL_ONCE
|
||||
BuiltinOperator_CALL_ONCE,
|
||||
BuiltinOperator_BROADCAST_TO
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesBuiltinOperator() {
|
||||
static const char * const names[131] = {
|
||||
static const char * const names[132] = {
|
||||
"ADD",
|
||||
"AVERAGE_POOL_2D",
|
||||
"CONCATENATION",
|
||||
@ -1068,13 +1073,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"PLACEHOLDER_FOR_GREATER_OP_CODES",
|
||||
"CUMSUM",
|
||||
"CALL_ONCE",
|
||||
"BROADCAST_TO",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
||||
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CALL_ONCE)) return "";
|
||||
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BROADCAST_TO)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOperator()[index];
|
||||
}
|
||||
@ -1184,11 +1190,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_BatchMatMulOptions = 101,
|
||||
BuiltinOptions_CumsumOptions = 102,
|
||||
BuiltinOptions_CallOnceOptions = 103,
|
||||
BuiltinOptions_BroadcastToOptions = 104,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_CallOnceOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_BroadcastToOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[105] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -1293,13 +1300,14 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[104] {
|
||||
BuiltinOptions_SegmentSumOptions,
|
||||
BuiltinOptions_BatchMatMulOptions,
|
||||
BuiltinOptions_CumsumOptions,
|
||||
BuiltinOptions_CallOnceOptions
|
||||
BuiltinOptions_CallOnceOptions,
|
||||
BuiltinOptions_BroadcastToOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesBuiltinOptions() {
|
||||
static const char * const names[105] = {
|
||||
static const char * const names[106] = {
|
||||
"NONE",
|
||||
"Conv2DOptions",
|
||||
"DepthwiseConv2DOptions",
|
||||
@ -1404,13 +1412,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"BatchMatMulOptions",
|
||||
"CumsumOptions",
|
||||
"CallOnceOptions",
|
||||
"BroadcastToOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
||||
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_CallOnceOptions)) return "";
|
||||
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_BroadcastToOptions)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOptions()[index];
|
||||
}
|
||||
@ -1831,6 +1840,10 @@ template<> struct BuiltinOptionsTraits<tflite::CallOnceOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<tflite::BroadcastToOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2687,6 +2700,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_CallOnceOptions ?
|
||||
reinterpret_cast<const tflite::CallOnceOptionsT *>(value) : nullptr;
|
||||
}
|
||||
tflite::BroadcastToOptionsT *AsBroadcastToOptions() {
|
||||
return type == BuiltinOptions_BroadcastToOptions ?
|
||||
reinterpret_cast<tflite::BroadcastToOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const tflite::BroadcastToOptionsT *AsBroadcastToOptions() const {
|
||||
return type == BuiltinOptions_BroadcastToOptions ?
|
||||
reinterpret_cast<const tflite::BroadcastToOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -9505,6 +9526,46 @@ inline flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(
|
||||
|
||||
flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct BroadcastToOptionsT : public flatbuffers::NativeTable {
|
||||
typedef BroadcastToOptions TableType;
|
||||
BroadcastToOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef BroadcastToOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
BroadcastToOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<BroadcastToOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct BroadcastToOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit BroadcastToOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
BroadcastToOptionsBuilder &operator=(const BroadcastToOptionsBuilder &);
|
||||
flatbuffers::Offset<BroadcastToOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<BroadcastToOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
BroadcastToOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
int8_t deprecated_builtin_code;
|
||||
@ -9964,6 +10025,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const {
|
||||
return builtin_options_type() == tflite::BuiltinOptions_CallOnceOptions ? static_cast<const tflite::CallOnceOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const {
|
||||
return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast<const tflite::BroadcastToOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -10412,6 +10476,10 @@ template<> inline const tflite::CallOnceOptions *Operator::builtin_options_as<tf
|
||||
return builtin_options_as_CallOnceOptions();
|
||||
}
|
||||
|
||||
template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as<tflite::BroadcastToOptions>() const {
|
||||
return builtin_options_as_BroadcastToOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -14143,6 +14211,29 @@ inline flatbuffers::Offset<CumsumOptions> CreateCumsumOptions(flatbuffers::FlatB
|
||||
_reverse);
|
||||
}
|
||||
|
||||
inline BroadcastToOptionsT *BroadcastToOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new BroadcastToOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void BroadcastToOptions::UnPackTo(BroadcastToOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<BroadcastToOptions> BroadcastToOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateBroadcastToOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BroadcastToOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateBroadcastToOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -15030,6 +15121,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_BroadcastToOptions: {
|
||||
auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return true;
|
||||
}
|
||||
}
|
||||
@ -15460,6 +15555,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const tflite::CallOnceOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_BroadcastToOptions: {
|
||||
auto ptr = reinterpret_cast<const tflite::BroadcastToOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -15878,6 +15977,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const tflite::CallOnceOptionsT *>(value);
|
||||
return CreateCallOnceOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_BroadcastToOptions: {
|
||||
auto ptr = reinterpret_cast<const tflite::BroadcastToOptionsT *>(value);
|
||||
return CreateBroadcastToOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -16296,6 +16399,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new tflite::CallOnceOptionsT(*reinterpret_cast<tflite::CallOnceOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_BroadcastToOptions: {
|
||||
value = new tflite::BroadcastToOptionsT(*reinterpret_cast<tflite::BroadcastToOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -16818,6 +16925,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_BroadcastToOptions: {
|
||||
auto ptr = reinterpret_cast<tflite::BroadcastToOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -610,6 +610,11 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
// The version one of broadcast to op won't be not supported since the
|
||||
// version one was rollbacked and the builtin op code number has been
|
||||
// changed because of builtin op code shortage problem.
|
||||
case BuiltinOperator_BROADCAST_TO:
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
|
@ -61,6 +61,10 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
{{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"},
|
||||
{{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"},
|
||||
{{BuiltinOperator_BATCH_MATMUL, 3}, "2.4.0"},
|
||||
// The version one of broadcast to op won't be not supported since
|
||||
// the version one was rollbacked and the builtin op code number
|
||||
// has been changed because of builtin op code shortage problem.
|
||||
{{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
|
||||
{{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_CONV_2D, 3}, "1.14.0"},
|
||||
|
@ -47,7 +47,7 @@ TEST(OpVersionTest, OpversionMissing) {
|
||||
EXPECT_NE(runtime_version, "")
|
||||
<< "Please add the version " << version << " of "
|
||||
<< tflite::EnumNamesBuiltinOperator()[op_code]
|
||||
<< " runtime_version.cc";
|
||||
<< " to runtime_version.cc";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user