Added ScatterND op
PiperOrigin-RevId: 278286268 Change-Id: Iacceb4ac18b7eb2f8da2faa385d0ca69159ffa1c
This commit is contained in:
parent
e6f23d3e45
commit
998eadd8b5
@ -148,6 +148,7 @@ typedef enum {
|
||||
kTfLiteBuiltinWhile = 119,
|
||||
kTfLiteBuiltinNonMaxSuppressionV4 = 120,
|
||||
kTfLiteBuiltinNonMaxSuppressionV5 = 121,
|
||||
kTfLiteBuiltinScatterNd = 122,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -818,6 +818,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_QUANTIZE:
|
||||
case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
|
||||
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
|
||||
case BuiltinOperator_SCATTER_ND:
|
||||
break;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -482,6 +482,7 @@ cc_library(
|
||||
"reverse.cc",
|
||||
"reverse_sequence.cc",
|
||||
"round.cc",
|
||||
"scatter_nd.cc",
|
||||
"select.cc",
|
||||
"shape.cc",
|
||||
"skip_gram.cc",
|
||||
@ -1151,6 +1152,20 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "scatter_nd_test",
|
||||
size = "small",
|
||||
srcs = ["scatter_nd_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "topk_v2_test",
|
||||
size = "small",
|
||||
|
@ -147,6 +147,7 @@ TfLiteRegistration* Register_IF();
|
||||
TfLiteRegistration* Register_WHILE();
|
||||
TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4();
|
||||
TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5();
|
||||
TfLiteRegistration* Register_SCATTER_ND();
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
|
@ -2187,6 +2187,48 @@ inline void GatherNd(const RuntimeShape& params_shape,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IndicesT, typename UpdatesT>
|
||||
inline void ScatterNd(const RuntimeShape& indices_shape,
|
||||
const IndicesT* indices_data,
|
||||
const RuntimeShape& updates_shape,
|
||||
const UpdatesT* updates_data,
|
||||
const RuntimeShape& output_shape, UpdatesT* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("ScatterNd");
|
||||
|
||||
int n_slices = 1;
|
||||
int slice_size = 1;
|
||||
const int outer_dims = indices_shape.DimensionsCount() - 1;
|
||||
const int indices_nd = indices_shape.Dims(outer_dims);
|
||||
const int updates_dims = updates_shape.DimensionsCount();
|
||||
for (int i = 0; i < outer_dims; ++i) {
|
||||
n_slices *= indices_shape.Dims(i);
|
||||
}
|
||||
for (int i = outer_dims; i < updates_dims; ++i) {
|
||||
slice_size *= updates_shape.Dims(i);
|
||||
}
|
||||
|
||||
int output_flat_size = output_shape.FlatSize();
|
||||
int remain_flat_size = output_flat_size;
|
||||
std::vector<int> dims_to_count(indices_nd, 0);
|
||||
for (int i = 0; i < indices_nd; ++i) {
|
||||
dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
|
||||
remain_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
|
||||
for (int i = 0; i < n_slices; ++i) {
|
||||
int to_pos = 0;
|
||||
for (int j = 0; j < indices_nd; ++j) {
|
||||
IndicesT idx = indices_data[i * indices_nd + j];
|
||||
TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
|
||||
to_pos += idx * dims_to_count[j];
|
||||
}
|
||||
for (int j = 0; j < slice_size; j++) {
|
||||
output_data[to_pos + j] += updates_data[i * slice_size + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
|
@ -273,6 +273,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
Register_NON_MAX_SUPPRESSION_V4());
|
||||
AddBuiltin(BuiltinOperator_NON_MAX_SUPPRESSION_V5,
|
||||
Register_NON_MAX_SUPPRESSION_V5());
|
||||
AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
190
tensorflow/lite/kernels/scatter_nd.cc
Normal file
190
tensorflow/lite/kernels/scatter_nd.cc
Normal file
@ -0,0 +1,190 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace scatter_nd {
|
||||
constexpr int kIndices = 0;
|
||||
constexpr int kUpdates = 1;
|
||||
constexpr int kShape = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
template <typename IndicesT>
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* shape,
|
||||
TfLiteTensor* output) {
|
||||
const int shape_rank = SizeOfDimension(shape, 0);
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape_rank);
|
||||
const auto* shape_data = GetTensorData<IndicesT>(shape);
|
||||
|
||||
for (int i = 0; i < shape_rank; i++) {
|
||||
output_shape->data[i] = shape_data[i];
|
||||
}
|
||||
return context->ResizeTensor(context, output, output_shape);
|
||||
}
|
||||
|
||||
template <typename IndicesT>
|
||||
TfLiteStatus CheckShapes(TfLiteContext* context, const RuntimeShape& indices,
|
||||
const RuntimeShape& updates,
|
||||
const RuntimeShape& shape_shape,
|
||||
const IndicesT* shape_data) {
|
||||
TF_LITE_ENSURE(context, (indices.DimensionsCount() >= 1) &&
|
||||
(updates.DimensionsCount() >= 1) &&
|
||||
(shape_shape.DimensionsCount() == 1));
|
||||
|
||||
const int outer_dims = indices.DimensionsCount() - 1;
|
||||
for (int i = 0; i < outer_dims; ++i) {
|
||||
TF_LITE_ENSURE_EQ(context, indices.Dims(i), updates.Dims(i));
|
||||
}
|
||||
|
||||
const int ix = indices.Dims(outer_dims);
|
||||
TF_LITE_ENSURE_EQ(context, updates.DimensionsCount() - outer_dims,
|
||||
shape_shape.Dims(0) - ix);
|
||||
for (int i = 0; i + outer_dims < updates.DimensionsCount(); ++i) {
|
||||
TF_LITE_ENSURE_EQ(context, updates.Dims(i + outer_dims),
|
||||
shape_data[ix + i]);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* indices = GetInput(context, node, kIndices);
|
||||
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
|
||||
const TfLiteTensor* shape = GetInput(context, node, kShape);
|
||||
|
||||
switch (updates->type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteInt64:
|
||||
case kTfLiteInt32:
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Updates of type '%s' are not supported by scatter_nd.",
|
||||
TfLiteTypeGetName(updates->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (indices->type != shape->type) {
|
||||
context->ReportError(context, "Indices and shape must have the same type.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
output->type = updates->type;
|
||||
|
||||
if (IsConstantTensor(shape)) {
|
||||
switch (indices->type) {
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
CheckShapes<int32_t>(context, GetTensorShape(indices),
|
||||
GetTensorShape(updates), GetTensorShape(shape),
|
||||
GetTensorData<int32_t>(shape)));
|
||||
return ResizeOutputTensor<int32_t>(context, shape, output);
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Indices of type '%s' are not supported by scatter_nd.",
|
||||
TfLiteTypeGetName(indices->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
SetTensorToDynamic(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IndicesT, typename UpdatesT>
|
||||
TfLiteStatus ScatterNd(const TfLiteTensor* indices, const TfLiteTensor* updates,
|
||||
TfLiteTensor* output) {
|
||||
reference_ops::ScatterNd(
|
||||
GetTensorShape(indices), GetTensorData<IndicesT>(indices),
|
||||
GetTensorShape(updates), GetTensorData<UpdatesT>(updates),
|
||||
GetTensorShape(output), GetTensorData<UpdatesT>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename IndicesT>
|
||||
TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
|
||||
const TfLiteTensor* updates,
|
||||
const TfLiteTensor* shape, TfLiteTensor* output) {
|
||||
if (IsDynamicTensor(output)) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CheckShapes<IndicesT>(
|
||||
context, GetTensorShape(indices), GetTensorShape(updates),
|
||||
GetTensorShape(shape), GetTensorData<IndicesT>(shape)));
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
ResizeOutputTensor<IndicesT>(context, shape, output));
|
||||
}
|
||||
|
||||
switch (updates->type) {
|
||||
case kTfLiteFloat32:
|
||||
return ScatterNd<IndicesT, float>(indices, updates, output);
|
||||
case kTfLiteUInt8:
|
||||
return ScatterNd<IndicesT, uint8_t>(indices, updates, output);
|
||||
case kTfLiteInt8:
|
||||
return ScatterNd<IndicesT, int8_t>(indices, updates, output);
|
||||
case kTfLiteInt32:
|
||||
return ScatterNd<IndicesT, int32_t>(indices, updates, output);
|
||||
case kTfLiteInt64:
|
||||
return ScatterNd<IndicesT, int64_t>(indices, updates, output);
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Updates of type '%s' are not supported by scatter_nd.",
|
||||
TfLiteTypeGetName(updates->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* indices = GetInput(context, node, kIndices);
|
||||
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
|
||||
const TfLiteTensor* shape = GetInput(context, node, kShape);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
switch (indices->type) {
|
||||
case kTfLiteInt32:
|
||||
return EvalScatterNd<int32_t>(context, indices, updates, shape, output);
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Indices of type '%s' are not supported by scatter_nd.",
|
||||
TfLiteTypeGetName(indices->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace scatter_nd
|
||||
|
||||
TfLiteRegistration* Register_SCATTER_ND() {
|
||||
static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
|
||||
scatter_nd::Prepare, scatter_nd::Eval};
|
||||
return &r;
|
||||
}
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
349
tensorflow/lite/kernels/scatter_nd_test.cc
Normal file
349
tensorflow/lite/kernels/scatter_nd_test.cc
Normal file
@ -0,0 +1,349 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class ScatterNdOpModel : public SingleOpModel {
|
||||
public:
|
||||
ScatterNdOpModel(const TensorData& indices, const TensorData& updates,
|
||||
const TensorData& shape) {
|
||||
indices_ = AddInput(indices);
|
||||
updates_ = AddInput(updates);
|
||||
shape_ = AddInput(shape);
|
||||
output_ = AddOutput(updates.type);
|
||||
SetBuiltinOp(BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions,
|
||||
CreateScatterNdOptions(builder_).Union());
|
||||
BuildInterpreter(
|
||||
{GetShape(indices_), GetShape(updates_), GetShape(shape_)});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SetIndices(std::initializer_list<T> data) {
|
||||
PopulateTensor<T>(indices_, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SetUpdates(std::initializer_list<T> data) {
|
||||
PopulateTensor<T>(updates_, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SetShape(std::initializer_list<T> data) {
|
||||
PopulateTensor<T>(shape_, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
protected:
|
||||
int indices_;
|
||||
int updates_;
|
||||
int shape_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(ScatterNdOpTest, ScatterElementIntoVector) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 1}}, {TensorType_FLOAT32, {4}},
|
||||
{TensorType_INT32, {1}});
|
||||
m.SetIndices<int32_t>({4, 3, 1, 7});
|
||||
m.SetUpdates<float>({9, 10, 11, 12});
|
||||
m.SetShape<int32_t>({8});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray({0, 11, 0, 10, 9, 0, 0, 12}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, ScatterMatrixIntoRank3Tensor) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {2, 1}},
|
||||
{TensorType_FLOAT32, {2, 4, 4}}, {TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({0, 2});
|
||||
m.SetUpdates<float>({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
|
||||
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8});
|
||||
m.SetShape<int32_t>({4, 4, 4});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 4, 4}));
|
||||
EXPECT_THAT(
|
||||
m.GetOutput<float>(),
|
||||
ElementsAreArray({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, ScatterVectorIntoMatrix) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 1}}, {TensorType_FLOAT32, {4, 4}},
|
||||
{TensorType_INT32, {2}});
|
||||
m.SetIndices<int32_t>({/*0*/ 9, /*1*/ 8, /*2*/ 0, /*3*/ 1});
|
||||
m.SetUpdates<float>({/*0*/ 1, 2, 3, 4,
|
||||
/*1*/ 5, 6, 7, 8,
|
||||
/*2*/ 9, 10, 11, 12,
|
||||
/*3*/ 13, 14, 15, 16});
|
||||
m.SetShape<int32_t>({10, 4});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({10, 4}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray({/*0*/ 9, 10, 11, 12,
|
||||
/*1*/ 13, 14, 15, 16,
|
||||
/*2*/ 0, 0, 0, 0,
|
||||
/*3*/ 0, 0, 0, 0,
|
||||
/*4*/ 0, 0, 0, 0,
|
||||
/*5*/ 0, 0, 0, 0,
|
||||
/*6*/ 0, 0, 0, 0,
|
||||
/*7*/ 0, 0, 0, 0,
|
||||
/*8*/ 5, 6, 7, 8,
|
||||
/*9*/ 1, 2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, ScatterMatricesIntoRank4Tensor) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {2, 2, 2}},
|
||||
{TensorType_FLOAT32, {2, 2, 2, 2}},
|
||||
{TensorType_INT32, {4}});
|
||||
m.SetIndices<int32_t>(
|
||||
{/*0,0*/ 1, 1, /*0,1*/ 0, 1, /*1,0*/ 0, 0, /*1,1*/ 1, 0});
|
||||
m.SetUpdates<float>({/*0,0*/ 1, 2, 3, 4, /*0,1*/ 5, 6, 7, 8,
|
||||
/*1,0*/ 9, 10, 11, 12, /*1,1*/ 13, 14, 15, 16});
|
||||
m.SetShape<int32_t>({2, 2, 2, 2});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
|
||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({/*0, 0*/ 9, 10, 11, 12,
|
||||
/*0, 1*/ 5, 6, 7, 8,
|
||||
/*1, 0*/ 13, 14, 15, 16,
|
||||
/*1, 1*/ 1, 2, 3, 4}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, ScatterVectorIntoRank4Tensor) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {2, 2, 3}},
|
||||
{TensorType_FLOAT32, {2, 2, 5}}, {TensorType_INT32, {4}});
|
||||
m.SetIndices<int32_t>(
|
||||
{/*0,0*/ 2, 2, 2, /*0,1*/ 1, 0, 1, /*1,0*/ 0, 2, 0, /*1,0*/ 2, 2, 0});
|
||||
m.SetUpdates<float>(
|
||||
{/*0,0*/ 1, 2, 3, 4, 5, /*0,1*/ 6, 7, 8, 9, 10,
|
||||
/*1,0*/ 11, 12, 13, 14, 15, /*1,1*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({3, 3, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray({
|
||||
/*0, 0, 0*/ 0, 0, 0, 0, 0,
|
||||
/*0, 0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 0, 2*/ 0, 0, 0, 0, 0,
|
||||
/*0, 1, 0*/ 0, 0, 0, 0, 0,
|
||||
/*0, 1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 1, 2*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2, 0*/ 11, 12, 13, 14, 15,
|
||||
/*0, 2, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2, 2*/ 0, 0, 0, 0, 0,
|
||||
/*1, 0, 0*/ 0, 0, 0, 0, 0,
|
||||
/*1, 0, 1*/ 6, 7, 8, 9, 10,
|
||||
/*1, 0, 2*/ 0, 0, 0, 0, 0,
|
||||
/*1, 1, 0*/ 0, 0, 0, 0, 0,
|
||||
/*1, 1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 1, 2*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2, 0*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2, 2*/ 0, 0, 0, 0, 0,
|
||||
/*2, 0, 0*/ 0, 0, 0, 0, 0,
|
||||
/*2, 0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*2, 0, 2*/ 0, 0, 0, 0, 0,
|
||||
/*2, 1, 0*/ 0, 0, 0, 0, 0,
|
||||
/*2, 1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*2, 1, 2*/ 0, 0, 0, 0, 0,
|
||||
/*2, 2, 0*/ 16, 17, 18, 19, 20,
|
||||
/*2, 2, 1*/ 0, 0, 0, 0, 0,
|
||||
/*2, 2, 2*/ 1, 2, 3, 4, 5,
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, ScatterVectorIntoRank3Tensor) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_FLOAT32, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
|
||||
m.SetUpdates<float>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 11, 12, 13, 14, 15,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, OverlappedIndicesSummed) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_FLOAT32, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 1, 0, /*1*/ 0, 2, /*2*/ 0, 2, /*3*/ 1, 0});
|
||||
m.SetUpdates<float>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray({/*0, 0*/ 0, 0, 0, 0, 0,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 17, 19, 21, 23, 25,
|
||||
/*1, 0*/ 17, 19, 21, 23, 25,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 0, 0, 0, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, Int32IndicesUint8Updates) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_UINT8, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
|
||||
m.SetUpdates<uint8_t>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<uint8_t>(),
|
||||
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 11, 12, 13, 14, 15,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, Int32IndicesInt8Updates) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT8, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
|
||||
m.SetUpdates<int8_t>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 11, 12, 13, 14, 15,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, Int32IndicesInt32Updates) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT32, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
|
||||
m.SetUpdates<int32_t>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<int32_t>(),
|
||||
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 11, 12, 13, 14, 15,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, Int32IndicesInt64Updates) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT64, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
|
||||
m.SetUpdates<int64_t>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<int64_t>(),
|
||||
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 11, 12, 13, 14, 15,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20}));
|
||||
}
|
||||
|
||||
TEST(ScatterNdOpTest, DynamicShape) {
|
||||
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT64, {4, 5}},
|
||||
{TensorType_INT32, {3}});
|
||||
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
|
||||
m.SetUpdates<int64_t>(
|
||||
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
|
||||
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
|
||||
m.SetShape<int32_t>({2, 3, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
|
||||
EXPECT_THAT(m.GetOutput<int64_t>(),
|
||||
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 11, 12, 13, 14, 15,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20}));
|
||||
|
||||
m.SetIndices<int32_t>({/*0*/ 2, 3, /*1*/ 1, 0, /*2*/ 2, 0, /*3*/ 1, 2});
|
||||
m.SetShape<int32_t>({3, 4, 5});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 4, 5}));
|
||||
EXPECT_THAT(m.GetOutput<int64_t>(),
|
||||
ElementsAreArray({/*0, 0*/ 0, 0, 0, 0, 0,
|
||||
/*0, 1*/ 0, 0, 0, 0, 0,
|
||||
/*0, 2*/ 0, 0, 0, 0, 0,
|
||||
/*0, 3*/ 0, 0, 0, 0, 0,
|
||||
/*1, 0*/ 6, 7, 8, 9, 10,
|
||||
/*1, 1*/ 0, 0, 0, 0, 0,
|
||||
/*1, 2*/ 16, 17, 18, 19, 20,
|
||||
/*1, 3*/ 0, 0, 0, 0, 0,
|
||||
/*2, 0*/ 11, 12, 13, 14, 15,
|
||||
/*2, 1*/ 0, 0, 0, 0, 0,
|
||||
/*2, 2*/ 0, 0, 0, 0, 0,
|
||||
/*2, 3*/ 1, 2, 3, 4, 5}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
@ -235,6 +235,7 @@ enum BuiltinOperator : byte {
|
||||
WHILE = 119,
|
||||
NON_MAX_SUPPRESSION_V4 = 120,
|
||||
NON_MAX_SUPPRESSION_V5 = 121,
|
||||
SCATTER_ND = 122
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -334,7 +335,8 @@ union BuiltinOptions {
|
||||
WhileOptions,
|
||||
DepthToSpaceOptions,
|
||||
NonMaxSuppressionV4Options,
|
||||
NonMaxSuppressionV5Options
|
||||
NonMaxSuppressionV5Options,
|
||||
ScatterNdOptions
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -812,6 +814,9 @@ table NonMaxSuppressionV4Options {
|
||||
table NonMaxSuppressionV5Options {
|
||||
}
|
||||
|
||||
table ScatterNdOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -319,6 +319,9 @@ struct NonMaxSuppressionV4OptionsT;
|
||||
struct NonMaxSuppressionV5Options;
|
||||
struct NonMaxSuppressionV5OptionsT;
|
||||
|
||||
struct ScatterNdOptions;
|
||||
struct ScatterNdOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -597,11 +600,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_WHILE = 119,
|
||||
BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120,
|
||||
BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121,
|
||||
BuiltinOperator_SCATTER_ND = 122,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_NON_MAX_SUPPRESSION_V5
|
||||
BuiltinOperator_MAX = BuiltinOperator_SCATTER_ND
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[122] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[123] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -724,7 +728,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[122] {
|
||||
BuiltinOperator_IF,
|
||||
BuiltinOperator_WHILE,
|
||||
BuiltinOperator_NON_MAX_SUPPRESSION_V4,
|
||||
BuiltinOperator_NON_MAX_SUPPRESSION_V5
|
||||
BuiltinOperator_NON_MAX_SUPPRESSION_V5,
|
||||
BuiltinOperator_SCATTER_ND
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -853,13 +858,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"WHILE",
|
||||
"NON_MAX_SUPPRESSION_V4",
|
||||
"NON_MAX_SUPPRESSION_V5",
|
||||
"SCATTER_ND",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
||||
if (e < BuiltinOperator_ADD || e > BuiltinOperator_NON_MAX_SUPPRESSION_V5) return "";
|
||||
if (e < BuiltinOperator_ADD || e > BuiltinOperator_SCATTER_ND) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOperator()[index];
|
||||
}
|
||||
@ -962,11 +968,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_DepthToSpaceOptions = 94,
|
||||
BuiltinOptions_NonMaxSuppressionV4Options = 95,
|
||||
BuiltinOptions_NonMaxSuppressionV5Options = 96,
|
||||
BuiltinOptions_ScatterNdOptions = 97,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_NonMaxSuppressionV5Options
|
||||
BuiltinOptions_MAX = BuiltinOptions_ScatterNdOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[97] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[98] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -1064,7 +1071,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[97] {
|
||||
BuiltinOptions_WhileOptions,
|
||||
BuiltinOptions_DepthToSpaceOptions,
|
||||
BuiltinOptions_NonMaxSuppressionV4Options,
|
||||
BuiltinOptions_NonMaxSuppressionV5Options
|
||||
BuiltinOptions_NonMaxSuppressionV5Options,
|
||||
BuiltinOptions_ScatterNdOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1168,13 +1176,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"DepthToSpaceOptions",
|
||||
"NonMaxSuppressionV4Options",
|
||||
"NonMaxSuppressionV5Options",
|
||||
"ScatterNdOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
||||
if (e < BuiltinOptions_NONE || e > BuiltinOptions_NonMaxSuppressionV5Options) return "";
|
||||
if (e < BuiltinOptions_NONE || e > BuiltinOptions_ScatterNdOptions) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOptions()[index];
|
||||
}
|
||||
@ -1567,6 +1576,10 @@ template<> struct BuiltinOptionsTraits<NonMaxSuppressionV5Options> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<ScatterNdOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2367,6 +2380,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_NonMaxSuppressionV5Options ?
|
||||
reinterpret_cast<const NonMaxSuppressionV5OptionsT *>(value) : nullptr;
|
||||
}
|
||||
ScatterNdOptionsT *AsScatterNdOptions() {
|
||||
return type == BuiltinOptions_ScatterNdOptions ?
|
||||
reinterpret_cast<ScatterNdOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const ScatterNdOptionsT *AsScatterNdOptions() const {
|
||||
return type == BuiltinOptions_ScatterNdOptions ?
|
||||
reinterpret_cast<const ScatterNdOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -8226,6 +8247,46 @@ inline flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5
|
||||
|
||||
flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct ScatterNdOptionsT : public flatbuffers::NativeTable {
|
||||
typedef ScatterNdOptions TableType;
|
||||
ScatterNdOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef ScatterNdOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
ScatterNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(ScatterNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<ScatterNdOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct ScatterNdOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit ScatterNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
ScatterNdOptionsBuilder &operator=(const ScatterNdOptionsBuilder &);
|
||||
flatbuffers::Offset<ScatterNdOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<ScatterNdOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
ScatterNdOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -8650,6 +8711,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const {
|
||||
return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV5Options ? static_cast<const NonMaxSuppressionV5Options *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const ScatterNdOptions *builtin_options_as_ScatterNdOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_ScatterNdOptions ? static_cast<const ScatterNdOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -9070,6 +9134,10 @@ template<> inline const NonMaxSuppressionV5Options *Operator::builtin_options_as
|
||||
return builtin_options_as_NonMaxSuppressionV5Options();
|
||||
}
|
||||
|
||||
template<> inline const ScatterNdOptions *Operator::builtin_options_as<ScatterNdOptions>() const {
|
||||
return builtin_options_as_ScatterNdOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -12225,6 +12293,29 @@ inline flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline ScatterNdOptionsT *ScatterNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new ScatterNdOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void ScatterNdOptions::UnPackTo(ScatterNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<ScatterNdOptions> ScatterNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateScatterNdOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ScatterNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateScatterNdOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -12902,6 +12993,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const NonMaxSuppressionV5Options *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_ScatterNdOptions: {
|
||||
auto ptr = reinterpret_cast<const ScatterNdOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -13304,6 +13399,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const NonMaxSuppressionV5Options *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_ScatterNdOptions: {
|
||||
auto ptr = reinterpret_cast<const ScatterNdOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -13694,6 +13793,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const NonMaxSuppressionV5OptionsT *>(value);
|
||||
return CreateNonMaxSuppressionV5Options(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_ScatterNdOptions: {
|
||||
auto ptr = reinterpret_cast<const ScatterNdOptionsT *>(value);
|
||||
return CreateScatterNdOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -14084,6 +14187,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new NonMaxSuppressionV5OptionsT(*reinterpret_cast<NonMaxSuppressionV5OptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_ScatterNdOptions: {
|
||||
value = new ScatterNdOptionsT(*reinterpret_cast<ScatterNdOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -14571,6 +14678,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_ScatterNdOptions: {
|
||||
auto ptr = reinterpret_cast<ScatterNdOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user