Support Reverse_v2
PiperOrigin-RevId: 230461952
This commit is contained in:
parent
8160760620
commit
d5ee710a7f
@ -291,6 +291,7 @@ def generated_test_models():
|
||||
"relu6",
|
||||
"reshape",
|
||||
"resize_bilinear",
|
||||
"reverse_v2",
|
||||
"rsqrt",
|
||||
"shape",
|
||||
"sigmoid",
|
||||
|
@ -130,6 +130,7 @@ typedef enum {
|
||||
kTfLiteBuiltinSplitV = 102,
|
||||
kTfLiteBuiltinUnique = 103,
|
||||
kTfLiteBuiltinCeil = 104,
|
||||
kTfLiteBuiltinReverseV2 = 105,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -726,6 +726,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_FLOOR_MOD:
|
||||
case BuiltinOperator_RANGE:
|
||||
case BuiltinOperator_SQUARED_DIFFERENCE:
|
||||
case BuiltinOperator_REVERSE_V2:
|
||||
break;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -200,6 +200,7 @@ cc_library(
|
||||
"reshape.cc",
|
||||
"resize_bilinear.cc",
|
||||
"resize_nearest_neighbor.cc",
|
||||
"reverse.cc",
|
||||
"select.cc",
|
||||
"shape.cc",
|
||||
"skip_gram.cc",
|
||||
@ -1258,6 +1259,18 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "reverse_test",
|
||||
size = "small",
|
||||
srcs = ["reverse_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -4720,6 +4720,33 @@ void Fill(const RuntimeShape& value_shape, const T* value_data,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void Reverse(int axis, const RuntimeShape& input_shape,
|
||||
const Scalar* input_data, const RuntimeShape& output_shape,
|
||||
Scalar* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("Reverse");
|
||||
|
||||
int outer_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
outer_size *= input_shape.Dims(i);
|
||||
}
|
||||
|
||||
int copy_size = 1;
|
||||
for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
|
||||
copy_size *= input_shape.Dims(i);
|
||||
}
|
||||
|
||||
const int dims_at_axis = input_shape.Dims(axis);
|
||||
for (int i = 0; i < outer_size; ++i) {
|
||||
for (int j = 0; j < dims_at_axis; ++j) {
|
||||
const int start_pos = (i * dims_at_axis + j) * copy_size;
|
||||
Scalar* output_ptr = output_data + start_pos;
|
||||
int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
|
||||
memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -131,6 +131,7 @@ TfLiteRegistration* Register_SQUARED_DIFFERENCE();
|
||||
TfLiteRegistration* Register_FILL();
|
||||
TfLiteRegistration* Register_MIRROR_PAD();
|
||||
TfLiteRegistration* Register_UNIQUE();
|
||||
TfLiteRegistration* Register_REVERSE_V2();
|
||||
|
||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->ReportError(
|
||||
@ -288,6 +289,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_FILL, Register_FILL());
|
||||
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
|
||||
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
|
||||
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
127
tensorflow/lite/kernels/reverse.cc
Normal file
127
tensorflow/lite/kernels/reverse.cc
Normal file
@ -0,0 +1,127 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace reverse {
|
||||
namespace {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kAxisTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
|
||||
|
||||
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
||||
input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
|
||||
input->type != kTfLiteInt64) {
|
||||
context->ReportError(context, "Type '%s' is not supported by reverse.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (axis->type != kTfLiteInt32) {
|
||||
context->ReportError(context, "Axis Type '%s' is not supported by reverse.",
|
||||
TfLiteTypeGetName(axis->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// TODO(renjieliu): support multi-axis case.
|
||||
if (NumElements(axis) > 1) {
|
||||
context->ReportError(context, "Current does not support more than 1 axis.");
|
||||
}
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, input->type);
|
||||
|
||||
return context->ResizeTensor(context, output, output_shape);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
|
||||
int axis = GetTensorData<int32_t>(axis_tensor)[0];
|
||||
|
||||
TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input));
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
switch (output->type) {
|
||||
case kTfLiteFloat32: {
|
||||
reference_ops::Reverse<float>(
|
||||
axis, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
break;
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
reference_ops::Reverse<uint8_t>(
|
||||
axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||
GetTensorShape(output), GetTensorData<uint8_t>(output));
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt16: {
|
||||
reference_ops::Reverse<int16_t>(
|
||||
axis, GetTensorShape(input), GetTensorData<int16_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt32: {
|
||||
reference_ops::Reverse<int32_t>(
|
||||
axis, GetTensorShape(input), GetTensorData<int32_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int32_t>(output));
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt64: {
|
||||
reference_ops::Reverse<int64_t>(
|
||||
axis, GetTensorShape(input), GetTensorData<int64_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int64_t>(output));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
context->ReportError(context, "Type '%s' is not supported by reverse.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace reverse
|
||||
|
||||
TfLiteRegistration* Register_REVERSE_V2() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, reverse::Prepare,
|
||||
reverse::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
199
tensorflow/lite/kernels/reverse_test.cc
Normal file
199
tensorflow/lite/kernels/reverse_test.cc
Normal file
@ -0,0 +1,199 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename T>
|
||||
class ReverseOpModel : public SingleOpModel {
|
||||
public:
|
||||
ReverseOpModel(const TensorData& input, const TensorData& axis) {
|
||||
input_ = AddInput(input);
|
||||
axis_ = AddInput(axis);
|
||||
|
||||
output_ = AddOutput({input.type, {}});
|
||||
|
||||
SetBuiltinOp(BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options,
|
||||
CreateReverseV2Options(builder_).Union());
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
int axis() { return axis_; }
|
||||
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int axis_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
// float32 tests.
|
||||
TEST(ReverseOpTest, FloatOneDimension) {
|
||||
ReverseOpModel<float> model({TensorType_FLOAT32, {4}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<float>(model.input(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(ReverseOpTest, FloatMultiDimensions) {
|
||||
ReverseOpModel<float> model({TensorType_FLOAT32, {4, 3, 2}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<float>(model.input(),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
|
||||
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
||||
}
|
||||
|
||||
// int32 tests
|
||||
TEST(ReverseOpTest, Int32OneDimension) {
|
||||
ReverseOpModel<int32_t> model({TensorType_INT32, {4}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<int32_t>(model.input(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(ReverseOpTest, Int32MultiDimensions) {
|
||||
ReverseOpModel<int32_t> model({TensorType_INT32, {4, 3, 2}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<int32_t>(
|
||||
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
|
||||
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
||||
}
|
||||
|
||||
// int64 tests
|
||||
TEST(ReverseOpTest, Int64OneDimension) {
|
||||
ReverseOpModel<int64_t> model({TensorType_INT64, {4}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<int64_t>(model.input(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(ReverseOpTest, Int64MultiDimensions) {
|
||||
ReverseOpModel<int64_t> model({TensorType_INT64, {4, 3, 2}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<int64_t>(
|
||||
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
|
||||
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
||||
}
|
||||
|
||||
// uint8 tests
|
||||
TEST(ReverseOpTest, Uint8OneDimension) {
|
||||
ReverseOpModel<uint8_t> model({TensorType_UINT8, {4}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<uint8_t>(model.input(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(ReverseOpTest, Uint8MultiDimensions) {
|
||||
ReverseOpModel<uint8_t> model({TensorType_UINT8, {4, 3, 2}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<uint8_t>(
|
||||
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
|
||||
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
||||
}
|
||||
|
||||
// int16 tests
|
||||
TEST(ReverseOpTest, Int16OneDimension) {
|
||||
ReverseOpModel<int16_t> model({TensorType_INT16, {4}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<int16_t>(model.input(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(ReverseOpTest, Int16MultiDimensions) {
|
||||
ReverseOpModel<int16_t> model({TensorType_INT16, {4, 3, 2}},
|
||||
{TensorType_INT32, {1}});
|
||||
model.PopulateTensor<int16_t>(
|
||||
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
model.PopulateTensor<int32_t>(model.axis(), {1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
|
||||
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -663,6 +663,7 @@ TfLiteStatus AddOpsAndParams(
|
||||
case tflite::BuiltinOperator_SPLIT_V:
|
||||
case tflite::BuiltinOperator_UNIQUE:
|
||||
case tflite::BuiltinOperator_CEIL:
|
||||
case tflite::BuiltinOperator_REVERSE_V2:
|
||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||
return kTfLiteError;
|
||||
break;
|
||||
|
@ -218,6 +218,7 @@ enum BuiltinOperator : byte {
|
||||
SPLIT_V = 102,
|
||||
UNIQUE = 103,
|
||||
CEIL = 104,
|
||||
REVERSE_V2 = 105,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -302,6 +303,7 @@ union BuiltinOptions {
|
||||
AbsOptions,
|
||||
SplitVOptions,
|
||||
UniqueOptions,
|
||||
ReverseV2Options,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -719,6 +721,9 @@ table UniqueOptions {
|
||||
idx_out_type:TensorType = INT32;
|
||||
}
|
||||
|
||||
table ReverseV2Options {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -271,6 +271,9 @@ struct MirrorPadOptionsT;
|
||||
struct UniqueOptions;
|
||||
struct UniqueOptionsT;
|
||||
|
||||
struct ReverseV2Options;
|
||||
struct ReverseV2OptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -525,11 +528,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_SPLIT_V = 102,
|
||||
BuiltinOperator_UNIQUE = 103,
|
||||
BuiltinOperator_CEIL = 104,
|
||||
BuiltinOperator_REVERSE_V2 = 105,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_CEIL
|
||||
BuiltinOperator_MAX = BuiltinOperator_REVERSE_V2
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[104] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[105] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -634,7 +638,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[104] {
|
||||
BuiltinOperator_ABS,
|
||||
BuiltinOperator_SPLIT_V,
|
||||
BuiltinOperator_UNIQUE,
|
||||
BuiltinOperator_CEIL
|
||||
BuiltinOperator_CEIL,
|
||||
BuiltinOperator_REVERSE_V2
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -746,6 +751,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"SPLIT_V",
|
||||
"UNIQUE",
|
||||
"CEIL",
|
||||
"REVERSE_V2",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -838,11 +844,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_AbsOptions = 78,
|
||||
BuiltinOptions_SplitVOptions = 79,
|
||||
BuiltinOptions_UniqueOptions = 80,
|
||||
BuiltinOptions_ReverseV2Options = 81,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_UniqueOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_ReverseV2Options
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[82] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -924,7 +931,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] {
|
||||
BuiltinOptions_MirrorPadOptions,
|
||||
BuiltinOptions_AbsOptions,
|
||||
BuiltinOptions_SplitVOptions,
|
||||
BuiltinOptions_UniqueOptions
|
||||
BuiltinOptions_UniqueOptions,
|
||||
BuiltinOptions_ReverseV2Options
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1012,6 +1020,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"AbsOptions",
|
||||
"SplitVOptions",
|
||||
"UniqueOptions",
|
||||
"ReverseV2Options",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -1346,6 +1355,10 @@ template<> struct BuiltinOptionsTraits<UniqueOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<ReverseV2Options> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2017,6 +2030,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_UniqueOptions ?
|
||||
reinterpret_cast<const UniqueOptionsT *>(value) : nullptr;
|
||||
}
|
||||
ReverseV2OptionsT *AsReverseV2Options() {
|
||||
return type == BuiltinOptions_ReverseV2Options ?
|
||||
reinterpret_cast<ReverseV2OptionsT *>(value) : nullptr;
|
||||
}
|
||||
const ReverseV2OptionsT *AsReverseV2Options() const {
|
||||
return type == BuiltinOptions_ReverseV2Options ?
|
||||
reinterpret_cast<const ReverseV2OptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -7113,6 +7134,46 @@ inline flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(
|
||||
|
||||
flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct ReverseV2OptionsT : public flatbuffers::NativeTable {
|
||||
typedef ReverseV2Options TableType;
|
||||
ReverseV2OptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef ReverseV2OptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
ReverseV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<ReverseV2Options> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct ReverseV2OptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit ReverseV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
ReverseV2OptionsBuilder &operator=(const ReverseV2OptionsBuilder &);
|
||||
flatbuffers::Offset<ReverseV2Options> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<ReverseV2Options>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
ReverseV2OptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -7486,6 +7547,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const UniqueOptions *builtin_options_as_UniqueOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_UniqueOptions ? static_cast<const UniqueOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const ReverseV2Options *builtin_options_as_ReverseV2Options() const {
|
||||
return builtin_options_type() == BuiltinOptions_ReverseV2Options ? static_cast<const ReverseV2Options *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -7837,6 +7901,10 @@ template<> inline const UniqueOptions *Operator::builtin_options_as<UniqueOption
|
||||
return builtin_options_as_UniqueOptions();
|
||||
}
|
||||
|
||||
template<> inline const ReverseV2Options *Operator::builtin_options_as<ReverseV2Options>() const {
|
||||
return builtin_options_as_ReverseV2Options();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -10484,6 +10552,29 @@ inline flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatB
|
||||
_idx_out_type);
|
||||
}
|
||||
|
||||
inline ReverseV2OptionsT *ReverseV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new ReverseV2OptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void ReverseV2Options::UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<ReverseV2Options> ReverseV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateReverseV2Options(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReverseV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateReverseV2Options(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -11062,6 +11153,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const UniqueOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_ReverseV2Options: {
|
||||
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -11400,6 +11495,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const UniqueOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_ReverseV2Options: {
|
||||
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -11726,6 +11825,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const UniqueOptionsT *>(value);
|
||||
return CreateUniqueOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_ReverseV2Options: {
|
||||
auto ptr = reinterpret_cast<const ReverseV2OptionsT *>(value);
|
||||
return CreateReverseV2Options(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -12052,6 +12155,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new UniqueOptionsT(*reinterpret_cast<UniqueOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_ReverseV2Options: {
|
||||
value = new ReverseV2OptionsT(*reinterpret_cast<ReverseV2OptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -12459,6 +12566,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_ReverseV2Options: {
|
||||
auto ptr = reinterpret_cast<ReverseV2OptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -3895,6 +3895,37 @@ def make_unique_tests(zip_path):
|
||||
build_inputs,
|
||||
expected_tf_success=9)
|
||||
|
||||
|
||||
def make_reverse_v2_tests(zip_path):
|
||||
"""Make a set of tests to do reverse_v2."""
|
||||
|
||||
test_parameters = [{
|
||||
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
|
||||
"axis": [0, 1, 2, 3],
|
||||
}]
|
||||
|
||||
def get_valid_axis(parameters):
|
||||
"""Return a tweaked version of 'axis'."""
|
||||
axis = parameters["axis"]
|
||||
shape = parameters["base_shape"][:]
|
||||
while axis > len(shape) - 1:
|
||||
axis -= 1
|
||||
return axis
|
||||
|
||||
def build_graph(parameters):
|
||||
input_tensor = tf.placeholder(
|
||||
dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
|
||||
outs = tf.reverse(input_tensor, axis=[get_valid_axis(parameters)])
|
||||
return [input_tensor], [outs]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
|
||||
return [input_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
||||
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
# Toco binary path provided by the generate rule.
|
||||
bin_path = None
|
||||
|
||||
|
@ -2062,6 +2062,20 @@ void ConvertZerosLikeOperator(const Model& model,
|
||||
(*zeros_like_op->mutable_attr())["T"].set_type(data_type);
|
||||
}
|
||||
|
||||
void ConvertReverseV2Operator(const Model& model,
|
||||
const ReverseV2Operator& src_op,
|
||||
const char* op_name, GraphDef* tensorflow_graph) {
|
||||
tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
|
||||
reverse_v2_op->set_op(op_name);
|
||||
reverse_v2_op->set_name(src_op.outputs[0]);
|
||||
DCHECK_EQ(src_op.inputs.size(), 2);
|
||||
*reverse_v2_op->add_input() = src_op.inputs[0];
|
||||
*reverse_v2_op->add_input() = src_op.inputs[1];
|
||||
const tensorflow::DataType data_type =
|
||||
GetTensorFlowDataType(model, src_op.inputs[0]);
|
||||
(*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
|
||||
}
|
||||
|
||||
void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
GraphDef* tensorflow_graph) {
|
||||
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
|
||||
@ -2341,6 +2355,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
ConvertZerosLikeOperator(
|
||||
model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
|
||||
"ZerosLike", tensorflow_graph);
|
||||
} else if (src_op.type == OperatorType::kReverseV2) {
|
||||
ConvertReverseV2Operator(model,
|
||||
static_cast<const ReverseV2Operator&>(src_op),
|
||||
"Reverse_V2", tensorflow_graph);
|
||||
} else {
|
||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
|
||||
}
|
||||
|
@ -2009,6 +2009,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||
case OperatorType::kLogicalNot:
|
||||
case OperatorType::kLogicalOr:
|
||||
case OperatorType::kZerosLike:
|
||||
case OperatorType::kReverseV2:
|
||||
ProcessSimpleOperator(model, op, 0);
|
||||
break;
|
||||
case OperatorType::kGather:
|
||||
|
@ -2459,6 +2459,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
|
||||
{"ResizeBilinear", ConvertResizeBilinearOperator},
|
||||
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
|
||||
{"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
|
||||
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
|
||||
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
|
||||
{"Shape", ConvertShapeOperator},
|
||||
|
@ -161,7 +161,8 @@ enum class OperatorType : uint8 {
|
||||
kMirrorPad,
|
||||
kUnique,
|
||||
kUnidirectionalSequenceRnn,
|
||||
kBidirectionalSequenceLstm
|
||||
kBidirectionalSequenceLstm,
|
||||
kReverseV2
|
||||
};
|
||||
|
||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||
@ -1958,6 +1959,16 @@ struct TensorFlowZerosLikeOperator : Operator {
|
||||
TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
|
||||
};
|
||||
|
||||
// ReverseV2 operator:
|
||||
//
|
||||
// Inputs:
|
||||
// Inputs[0]: required: the input array.
|
||||
//
|
||||
// TensorFlow equivalent: ReverseV2.
|
||||
struct ReverseV2Operator : Operator {
|
||||
ReverseV2Operator() : Operator(OperatorType::kReverseV2) {}
|
||||
};
|
||||
|
||||
enum class MirrorPadMode { kNone, kSymmetric, kReflect };
|
||||
|
||||
// MirrorPad Operator:
|
||||
|
@ -2036,6 +2036,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
|
||||
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
|
||||
"REVERSE_V2", OperatorType::kReverseV2));
|
||||
return ops;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -151,6 +151,8 @@ TEST_F(OperatorTest, SimpleOperators) {
|
||||
CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);
|
||||
CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
|
||||
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
|
||||
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
|
||||
OperatorType::kReverseV2);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinAdd) {
|
||||
|
@ -420,6 +420,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Unique)
|
||||
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
|
||||
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled op type";
|
||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||
|
Loading…
Reference in New Issue
Block a user