Added ScatterND op

PiperOrigin-RevId: 278286268
Change-Id: Iacceb4ac18b7eb2f8da2faa385d0ca69159ffa1c
This commit is contained in:
A. Unique TensorFlower 2019-11-03 19:40:56 -08:00 committed by TensorFlower Gardener
parent e6f23d3e45
commit 998eadd8b5
10 changed files with 726 additions and 9 deletions

View File

@ -148,6 +148,7 @@ typedef enum {
kTfLiteBuiltinWhile = 119,
kTfLiteBuiltinNonMaxSuppressionV4 = 120,
kTfLiteBuiltinNonMaxSuppressionV5 = 121,
kTfLiteBuiltinScatterNd = 122,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -818,6 +818,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_QUANTIZE:
case BuiltinOperator_NON_MAX_SUPPRESSION_V4:
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
case BuiltinOperator_SCATTER_ND:
break;
}
return kTfLiteOk;

View File

@ -482,6 +482,7 @@ cc_library(
"reverse.cc",
"reverse_sequence.cc",
"round.cc",
"scatter_nd.cc",
"select.cc",
"shape.cc",
"skip_gram.cc",
@ -1151,6 +1152,20 @@ cc_test(
],
)
cc_test(
name = "scatter_nd_test",
size = "small",
srcs = ["scatter_nd_test.cc"],
deps = [
":builtin_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "topk_v2_test",
size = "small",

View File

@ -147,6 +147,7 @@ TfLiteRegistration* Register_IF();
TfLiteRegistration* Register_WHILE();
TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4();
TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5();
TfLiteRegistration* Register_SCATTER_ND();
} // namespace builtin
} // namespace ops

View File

@ -2187,6 +2187,48 @@ inline void GatherNd(const RuntimeShape& params_shape,
}
}
template <typename IndicesT, typename UpdatesT>
inline void ScatterNd(const RuntimeShape& indices_shape,
const IndicesT* indices_data,
const RuntimeShape& updates_shape,
const UpdatesT* updates_data,
const RuntimeShape& output_shape, UpdatesT* output_data) {
gemmlowp::ScopedProfilingLabel label("ScatterNd");
int n_slices = 1;
int slice_size = 1;
const int outer_dims = indices_shape.DimensionsCount() - 1;
const int indices_nd = indices_shape.Dims(outer_dims);
const int updates_dims = updates_shape.DimensionsCount();
for (int i = 0; i < outer_dims; ++i) {
n_slices *= indices_shape.Dims(i);
}
for (int i = outer_dims; i < updates_dims; ++i) {
slice_size *= updates_shape.Dims(i);
}
int output_flat_size = output_shape.FlatSize();
int remain_flat_size = output_flat_size;
std::vector<int> dims_to_count(indices_nd, 0);
for (int i = 0; i < indices_nd; ++i) {
dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
remain_flat_size = dims_to_count[i];
}
memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
for (int i = 0; i < n_slices; ++i) {
int to_pos = 0;
for (int j = 0; j < indices_nd; ++j) {
IndicesT idx = indices_data[i * indices_nd + j];
TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
to_pos += idx * dims_to_count[j];
}
for (int j = 0; j < slice_size; j++) {
output_data[to_pos + j] += updates_data[i * slice_size + j];
}
}
}
template <typename T>
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,

View File

@ -273,6 +273,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
Register_NON_MAX_SUPPRESSION_V4());
AddBuiltin(BuiltinOperator_NON_MAX_SUPPRESSION_V5,
Register_NON_MAX_SUPPRESSION_V5());
AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -0,0 +1,190 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace scatter_nd {
constexpr int kIndices = 0;
constexpr int kUpdates = 1;
constexpr int kShape = 2;
constexpr int kOutputTensor = 0;
template <typename IndicesT>
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const TfLiteTensor* shape,
TfLiteTensor* output) {
const int shape_rank = SizeOfDimension(shape, 0);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape_rank);
const auto* shape_data = GetTensorData<IndicesT>(shape);
for (int i = 0; i < shape_rank; i++) {
output_shape->data[i] = shape_data[i];
}
return context->ResizeTensor(context, output, output_shape);
}
template <typename IndicesT>
TfLiteStatus CheckShapes(TfLiteContext* context, const RuntimeShape& indices,
const RuntimeShape& updates,
const RuntimeShape& shape_shape,
const IndicesT* shape_data) {
TF_LITE_ENSURE(context, (indices.DimensionsCount() >= 1) &&
(updates.DimensionsCount() >= 1) &&
(shape_shape.DimensionsCount() == 1));
const int outer_dims = indices.DimensionsCount() - 1;
for (int i = 0; i < outer_dims; ++i) {
TF_LITE_ENSURE_EQ(context, indices.Dims(i), updates.Dims(i));
}
const int ix = indices.Dims(outer_dims);
TF_LITE_ENSURE_EQ(context, updates.DimensionsCount() - outer_dims,
shape_shape.Dims(0) - ix);
for (int i = 0; i + outer_dims < updates.DimensionsCount(); ++i) {
TF_LITE_ENSURE_EQ(context, updates.Dims(i + outer_dims),
shape_data[ix + i]);
}
return kTfLiteOk;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* indices = GetInput(context, node, kIndices);
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
const TfLiteTensor* shape = GetInput(context, node, kShape);
switch (updates->type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
case kTfLiteInt64:
case kTfLiteInt32:
break;
default:
context->ReportError(
context, "Updates of type '%s' are not supported by scatter_nd.",
TfLiteTypeGetName(updates->type));
return kTfLiteError;
}
if (indices->type != shape->type) {
context->ReportError(context, "Indices and shape must have the same type.");
return kTfLiteError;
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = updates->type;
if (IsConstantTensor(shape)) {
switch (indices->type) {
case kTfLiteInt32:
TF_LITE_ENSURE_OK(
context,
CheckShapes<int32_t>(context, GetTensorShape(indices),
GetTensorShape(updates), GetTensorShape(shape),
GetTensorData<int32_t>(shape)));
return ResizeOutputTensor<int32_t>(context, shape, output);
default:
context->ReportError(
context, "Indices of type '%s' are not supported by scatter_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
} else {
SetTensorToDynamic(output);
return kTfLiteOk;
}
}
template <typename IndicesT, typename UpdatesT>
TfLiteStatus ScatterNd(const TfLiteTensor* indices, const TfLiteTensor* updates,
TfLiteTensor* output) {
reference_ops::ScatterNd(
GetTensorShape(indices), GetTensorData<IndicesT>(indices),
GetTensorShape(updates), GetTensorData<UpdatesT>(updates),
GetTensorShape(output), GetTensorData<UpdatesT>(output));
return kTfLiteOk;
}
template <typename IndicesT>
TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
const TfLiteTensor* updates,
const TfLiteTensor* shape, TfLiteTensor* output) {
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(
context, CheckShapes<IndicesT>(
context, GetTensorShape(indices), GetTensorShape(updates),
GetTensorShape(shape), GetTensorData<IndicesT>(shape)));
TF_LITE_ENSURE_OK(context,
ResizeOutputTensor<IndicesT>(context, shape, output));
}
switch (updates->type) {
case kTfLiteFloat32:
return ScatterNd<IndicesT, float>(indices, updates, output);
case kTfLiteUInt8:
return ScatterNd<IndicesT, uint8_t>(indices, updates, output);
case kTfLiteInt8:
return ScatterNd<IndicesT, int8_t>(indices, updates, output);
case kTfLiteInt32:
return ScatterNd<IndicesT, int32_t>(indices, updates, output);
case kTfLiteInt64:
return ScatterNd<IndicesT, int64_t>(indices, updates, output);
default:
context->ReportError(
context, "Updates of type '%s' are not supported by scatter_nd.",
TfLiteTypeGetName(updates->type));
return kTfLiteError;
}
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndices);
const TfLiteTensor* updates = GetInput(context, node, kUpdates);
const TfLiteTensor* shape = GetInput(context, node, kShape);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (indices->type) {
case kTfLiteInt32:
return EvalScatterNd<int32_t>(context, indices, updates, shape, output);
default:
context->ReportError(
context, "Indices of type '%s' are not supported by scatter_nd.",
TfLiteTypeGetName(indices->type));
return kTfLiteError;
}
}
} // namespace scatter_nd
TfLiteRegistration* Register_SCATTER_ND() {
static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
scatter_nd::Prepare, scatter_nd::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,349 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
class ScatterNdOpModel : public SingleOpModel {
public:
ScatterNdOpModel(const TensorData& indices, const TensorData& updates,
const TensorData& shape) {
indices_ = AddInput(indices);
updates_ = AddInput(updates);
shape_ = AddInput(shape);
output_ = AddOutput(updates.type);
SetBuiltinOp(BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions,
CreateScatterNdOptions(builder_).Union());
BuildInterpreter(
{GetShape(indices_), GetShape(updates_), GetShape(shape_)});
}
template <typename T>
void SetIndices(std::initializer_list<T> data) {
PopulateTensor<T>(indices_, data);
}
template <typename T>
void SetUpdates(std::initializer_list<T> data) {
PopulateTensor<T>(updates_, data);
}
template <typename T>
void SetShape(std::initializer_list<T> data) {
PopulateTensor<T>(shape_, data);
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int indices_;
int updates_;
int shape_;
int output_;
};
TEST(ScatterNdOpTest, ScatterElementIntoVector) {
ScatterNdOpModel m({TensorType_INT32, {4, 1}}, {TensorType_FLOAT32, {4}},
{TensorType_INT32, {1}});
m.SetIndices<int32_t>({4, 3, 1, 7});
m.SetUpdates<float>({9, 10, 11, 12});
m.SetShape<int32_t>({8});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({0, 11, 0, 10, 9, 0, 0, 12}));
}
TEST(ScatterNdOpTest, ScatterMatrixIntoRank3Tensor) {
ScatterNdOpModel m({TensorType_INT32, {2, 1}},
{TensorType_FLOAT32, {2, 4, 4}}, {TensorType_INT32, {3}});
m.SetIndices<int32_t>({0, 2});
m.SetUpdates<float>({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8});
m.SetShape<int32_t>({4, 4, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 4, 4}));
EXPECT_THAT(
m.GetOutput<float>(),
ElementsAreArray({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
}
TEST(ScatterNdOpTest, ScatterVectorIntoMatrix) {
ScatterNdOpModel m({TensorType_INT32, {4, 1}}, {TensorType_FLOAT32, {4, 4}},
{TensorType_INT32, {2}});
m.SetIndices<int32_t>({/*0*/ 9, /*1*/ 8, /*2*/ 0, /*3*/ 1});
m.SetUpdates<float>({/*0*/ 1, 2, 3, 4,
/*1*/ 5, 6, 7, 8,
/*2*/ 9, 10, 11, 12,
/*3*/ 13, 14, 15, 16});
m.SetShape<int32_t>({10, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({10, 4}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({/*0*/ 9, 10, 11, 12,
/*1*/ 13, 14, 15, 16,
/*2*/ 0, 0, 0, 0,
/*3*/ 0, 0, 0, 0,
/*4*/ 0, 0, 0, 0,
/*5*/ 0, 0, 0, 0,
/*6*/ 0, 0, 0, 0,
/*7*/ 0, 0, 0, 0,
/*8*/ 5, 6, 7, 8,
/*9*/ 1, 2, 3, 4}));
}
TEST(ScatterNdOpTest, ScatterMatricesIntoRank4Tensor) {
ScatterNdOpModel m({TensorType_INT32, {2, 2, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2}},
{TensorType_INT32, {4}});
m.SetIndices<int32_t>(
{/*0,0*/ 1, 1, /*0,1*/ 0, 1, /*1,0*/ 0, 0, /*1,1*/ 1, 0});
m.SetUpdates<float>({/*0,0*/ 1, 2, 3, 4, /*0,1*/ 5, 6, 7, 8,
/*1,0*/ 9, 10, 11, 12, /*1,1*/ 13, 14, 15, 16});
m.SetShape<int32_t>({2, 2, 2, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({/*0, 0*/ 9, 10, 11, 12,
/*0, 1*/ 5, 6, 7, 8,
/*1, 0*/ 13, 14, 15, 16,
/*1, 1*/ 1, 2, 3, 4}));
}
TEST(ScatterNdOpTest, ScatterVectorIntoRank4Tensor) {
ScatterNdOpModel m({TensorType_INT32, {2, 2, 3}},
{TensorType_FLOAT32, {2, 2, 5}}, {TensorType_INT32, {4}});
m.SetIndices<int32_t>(
{/*0,0*/ 2, 2, 2, /*0,1*/ 1, 0, 1, /*1,0*/ 0, 2, 0, /*1,0*/ 2, 2, 0});
m.SetUpdates<float>(
{/*0,0*/ 1, 2, 3, 4, 5, /*0,1*/ 6, 7, 8, 9, 10,
/*1,0*/ 11, 12, 13, 14, 15, /*1,1*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({3, 3, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3, 5}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({
/*0, 0, 0*/ 0, 0, 0, 0, 0,
/*0, 0, 1*/ 0, 0, 0, 0, 0,
/*0, 0, 2*/ 0, 0, 0, 0, 0,
/*0, 1, 0*/ 0, 0, 0, 0, 0,
/*0, 1, 1*/ 0, 0, 0, 0, 0,
/*0, 1, 2*/ 0, 0, 0, 0, 0,
/*0, 2, 0*/ 11, 12, 13, 14, 15,
/*0, 2, 1*/ 0, 0, 0, 0, 0,
/*0, 2, 2*/ 0, 0, 0, 0, 0,
/*1, 0, 0*/ 0, 0, 0, 0, 0,
/*1, 0, 1*/ 6, 7, 8, 9, 10,
/*1, 0, 2*/ 0, 0, 0, 0, 0,
/*1, 1, 0*/ 0, 0, 0, 0, 0,
/*1, 1, 1*/ 0, 0, 0, 0, 0,
/*1, 1, 2*/ 0, 0, 0, 0, 0,
/*1, 2, 0*/ 0, 0, 0, 0, 0,
/*1, 2, 1*/ 0, 0, 0, 0, 0,
/*1, 2, 2*/ 0, 0, 0, 0, 0,
/*2, 0, 0*/ 0, 0, 0, 0, 0,
/*2, 0, 1*/ 0, 0, 0, 0, 0,
/*2, 0, 2*/ 0, 0, 0, 0, 0,
/*2, 1, 0*/ 0, 0, 0, 0, 0,
/*2, 1, 1*/ 0, 0, 0, 0, 0,
/*2, 1, 2*/ 0, 0, 0, 0, 0,
/*2, 2, 0*/ 16, 17, 18, 19, 20,
/*2, 2, 1*/ 0, 0, 0, 0, 0,
/*2, 2, 2*/ 1, 2, 3, 4, 5,
}));
}
TEST(ScatterNdOpTest, ScatterVectorIntoRank3Tensor) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_FLOAT32, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
m.SetUpdates<float>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 11, 12, 13, 14, 15,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20}));
}
TEST(ScatterNdOpTest, OverlappedIndicesSummed) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_FLOAT32, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 1, 0, /*1*/ 0, 2, /*2*/ 0, 2, /*3*/ 1, 0});
m.SetUpdates<float>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({/*0, 0*/ 0, 0, 0, 0, 0,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 17, 19, 21, 23, 25,
/*1, 0*/ 17, 19, 21, 23, 25,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 0, 0, 0, 0, 0}));
}
TEST(ScatterNdOpTest, Int32IndicesUint8Updates) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_UINT8, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
m.SetUpdates<uint8_t>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 11, 12, 13, 14, 15,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20}));
}
TEST(ScatterNdOpTest, Int32IndicesInt8Updates) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT8, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
m.SetUpdates<int8_t>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 11, 12, 13, 14, 15,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20}));
}
TEST(ScatterNdOpTest, Int32IndicesInt32Updates) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT32, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
m.SetUpdates<int32_t>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<int32_t>(),
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 11, 12, 13, 14, 15,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20}));
}
TEST(ScatterNdOpTest, Int32IndicesInt64Updates) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT64, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
m.SetUpdates<int64_t>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<int64_t>(),
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 11, 12, 13, 14, 15,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20}));
}
TEST(ScatterNdOpTest, DynamicShape) {
ScatterNdOpModel m({TensorType_INT32, {4, 2}}, {TensorType_INT64, {4, 5}},
{TensorType_INT32, {3}});
m.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
m.SetUpdates<int64_t>(
{/*0*/ 1, 2, 3, 4, 5, /*1*/ 6, 7, 8, 9, 10,
/*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
m.SetShape<int32_t>({2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 5}));
EXPECT_THAT(m.GetOutput<int64_t>(),
ElementsAreArray({/*0, 0*/ 1, 2, 3, 4, 5,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 11, 12, 13, 14, 15,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20}));
m.SetIndices<int32_t>({/*0*/ 2, 3, /*1*/ 1, 0, /*2*/ 2, 0, /*3*/ 1, 2});
m.SetShape<int32_t>({3, 4, 5});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 4, 5}));
EXPECT_THAT(m.GetOutput<int64_t>(),
ElementsAreArray({/*0, 0*/ 0, 0, 0, 0, 0,
/*0, 1*/ 0, 0, 0, 0, 0,
/*0, 2*/ 0, 0, 0, 0, 0,
/*0, 3*/ 0, 0, 0, 0, 0,
/*1, 0*/ 6, 7, 8, 9, 10,
/*1, 1*/ 0, 0, 0, 0, 0,
/*1, 2*/ 16, 17, 18, 19, 20,
/*1, 3*/ 0, 0, 0, 0, 0,
/*2, 0*/ 11, 12, 13, 14, 15,
/*2, 1*/ 0, 0, 0, 0, 0,
/*2, 2*/ 0, 0, 0, 0, 0,
/*2, 3*/ 1, 2, 3, 4, 5}));
}
} // namespace
} // namespace tflite

View File

@ -235,6 +235,7 @@ enum BuiltinOperator : byte {
WHILE = 119,
NON_MAX_SUPPRESSION_V4 = 120,
NON_MAX_SUPPRESSION_V5 = 121,
SCATTER_ND = 122
}
// Options for the builtin operators.
@ -334,7 +335,8 @@ union BuiltinOptions {
WhileOptions,
DepthToSpaceOptions,
NonMaxSuppressionV4Options,
NonMaxSuppressionV5Options
NonMaxSuppressionV5Options,
ScatterNdOptions
}
enum Padding : byte { SAME, VALID }
@ -812,6 +814,9 @@ table NonMaxSuppressionV4Options {
table NonMaxSuppressionV5Options {
}
table ScatterNdOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -319,6 +319,9 @@ struct NonMaxSuppressionV4OptionsT;
struct NonMaxSuppressionV5Options;
struct NonMaxSuppressionV5OptionsT;
struct ScatterNdOptions;
struct ScatterNdOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -597,11 +600,12 @@ enum BuiltinOperator {
BuiltinOperator_WHILE = 119,
BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120,
BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121,
BuiltinOperator_SCATTER_ND = 122,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_NON_MAX_SUPPRESSION_V5
BuiltinOperator_MAX = BuiltinOperator_SCATTER_ND
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[122] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[123] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -724,7 +728,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[122] {
BuiltinOperator_IF,
BuiltinOperator_WHILE,
BuiltinOperator_NON_MAX_SUPPRESSION_V4,
BuiltinOperator_NON_MAX_SUPPRESSION_V5
BuiltinOperator_NON_MAX_SUPPRESSION_V5,
BuiltinOperator_SCATTER_ND
};
return values;
}
@ -853,13 +858,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"WHILE",
"NON_MAX_SUPPRESSION_V4",
"NON_MAX_SUPPRESSION_V5",
"SCATTER_ND",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (e < BuiltinOperator_ADD || e > BuiltinOperator_NON_MAX_SUPPRESSION_V5) return "";
if (e < BuiltinOperator_ADD || e > BuiltinOperator_SCATTER_ND) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -962,11 +968,12 @@ enum BuiltinOptions {
BuiltinOptions_DepthToSpaceOptions = 94,
BuiltinOptions_NonMaxSuppressionV4Options = 95,
BuiltinOptions_NonMaxSuppressionV5Options = 96,
BuiltinOptions_ScatterNdOptions = 97,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_NonMaxSuppressionV5Options
BuiltinOptions_MAX = BuiltinOptions_ScatterNdOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[97] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[98] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1064,7 +1071,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[97] {
BuiltinOptions_WhileOptions,
BuiltinOptions_DepthToSpaceOptions,
BuiltinOptions_NonMaxSuppressionV4Options,
BuiltinOptions_NonMaxSuppressionV5Options
BuiltinOptions_NonMaxSuppressionV5Options,
BuiltinOptions_ScatterNdOptions
};
return values;
}
@ -1168,13 +1176,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"DepthToSpaceOptions",
"NonMaxSuppressionV4Options",
"NonMaxSuppressionV5Options",
"ScatterNdOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (e < BuiltinOptions_NONE || e > BuiltinOptions_NonMaxSuppressionV5Options) return "";
if (e < BuiltinOptions_NONE || e > BuiltinOptions_ScatterNdOptions) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -1567,6 +1576,10 @@ template<> struct BuiltinOptionsTraits<NonMaxSuppressionV5Options> {
static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options;
};
template<> struct BuiltinOptionsTraits<ScatterNdOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2367,6 +2380,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_NonMaxSuppressionV5Options ?
reinterpret_cast<const NonMaxSuppressionV5OptionsT *>(value) : nullptr;
}
ScatterNdOptionsT *AsScatterNdOptions() {
return type == BuiltinOptions_ScatterNdOptions ?
reinterpret_cast<ScatterNdOptionsT *>(value) : nullptr;
}
const ScatterNdOptionsT *AsScatterNdOptions() const {
return type == BuiltinOptions_ScatterNdOptions ?
reinterpret_cast<const ScatterNdOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -8226,6 +8247,46 @@ inline flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5
flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct ScatterNdOptionsT : public flatbuffers::NativeTable {
typedef ScatterNdOptions TableType;
ScatterNdOptionsT() {
}
};
struct ScatterNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ScatterNdOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
ScatterNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(ScatterNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<ScatterNdOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct ScatterNdOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit ScatterNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
ScatterNdOptionsBuilder &operator=(const ScatterNdOptionsBuilder &);
flatbuffers::Offset<ScatterNdOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<ScatterNdOptions>(end);
return o;
}
};
inline flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
ScatterNdOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -8650,6 +8711,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const {
return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV5Options ? static_cast<const NonMaxSuppressionV5Options *>(builtin_options()) : nullptr;
}
const ScatterNdOptions *builtin_options_as_ScatterNdOptions() const {
return builtin_options_type() == BuiltinOptions_ScatterNdOptions ? static_cast<const ScatterNdOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -9070,6 +9134,10 @@ template<> inline const NonMaxSuppressionV5Options *Operator::builtin_options_as
return builtin_options_as_NonMaxSuppressionV5Options();
}
template<> inline const ScatterNdOptions *Operator::builtin_options_as<ScatterNdOptions>() const {
return builtin_options_as_ScatterNdOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -12225,6 +12293,29 @@ inline flatbuffers::Offset<NonMaxSuppressionV5Options> CreateNonMaxSuppressionV5
_fbb);
}
inline ScatterNdOptionsT *ScatterNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ScatterNdOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void ScatterNdOptions::UnPackTo(ScatterNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<ScatterNdOptions> ScatterNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateScatterNdOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<ScatterNdOptions> CreateScatterNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ScatterNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateScatterNdOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -12902,6 +12993,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const NonMaxSuppressionV5Options *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_ScatterNdOptions: {
auto ptr = reinterpret_cast<const ScatterNdOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -13304,6 +13399,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const NonMaxSuppressionV5Options *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_ScatterNdOptions: {
auto ptr = reinterpret_cast<const ScatterNdOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -13694,6 +13793,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const NonMaxSuppressionV5OptionsT *>(value);
return CreateNonMaxSuppressionV5Options(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_ScatterNdOptions: {
auto ptr = reinterpret_cast<const ScatterNdOptionsT *>(value);
return CreateScatterNdOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -14084,6 +14187,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new NonMaxSuppressionV5OptionsT(*reinterpret_cast<NonMaxSuppressionV5OptionsT *>(u.value));
break;
}
case BuiltinOptions_ScatterNdOptions: {
value = new ScatterNdOptionsT(*reinterpret_cast<ScatterNdOptionsT *>(u.value));
break;
}
default:
break;
}
@ -14571,6 +14678,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_ScatterNdOptions: {
auto ptr = reinterpret_cast<ScatterNdOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;