Support ReverseV2

PiperOrigin-RevId: 230461952
This commit is contained in:
A. Unique TensorFlower 2019-01-22 20:19:06 -08:00 committed by TensorFlower Gardener
parent 8160760620
commit d5ee710a7f
19 changed files with 563 additions and 7 deletions

View File

@ -291,6 +291,7 @@ def generated_test_models():
"relu6", "relu6",
"reshape", "reshape",
"resize_bilinear", "resize_bilinear",
"reverse_v2",
"rsqrt", "rsqrt",
"shape", "shape",
"sigmoid", "sigmoid",

View File

@ -130,6 +130,7 @@ typedef enum {
kTfLiteBuiltinSplitV = 102, kTfLiteBuiltinSplitV = 102,
kTfLiteBuiltinUnique = 103, kTfLiteBuiltinUnique = 103,
kTfLiteBuiltinCeil = 104, kTfLiteBuiltinCeil = 104,
kTfLiteBuiltinReverseV2 = 105,
} TfLiteBuiltinOperator; } TfLiteBuiltinOperator;
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -726,6 +726,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_FLOOR_MOD: case BuiltinOperator_FLOOR_MOD:
case BuiltinOperator_RANGE: case BuiltinOperator_RANGE:
case BuiltinOperator_SQUARED_DIFFERENCE: case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_REVERSE_V2:
break; break;
} }
return kTfLiteOk; return kTfLiteOk;

View File

@ -200,6 +200,7 @@ cc_library(
"reshape.cc", "reshape.cc",
"resize_bilinear.cc", "resize_bilinear.cc",
"resize_nearest_neighbor.cc", "resize_nearest_neighbor.cc",
"reverse.cc",
"select.cc", "select.cc",
"shape.cc", "shape.cc",
"skip_gram.cc", "skip_gram.cc",
@ -1258,6 +1259,18 @@ tf_cc_test(
], ],
) )
# Unit test for the REVERSE_V2 builtin kernel (kernels/reverse.cc).
tf_cc_test(
    name = "reverse_test",
    size = "small",
    srcs = ["reverse_test.cc"],
    deps = [
        ":builtin_ops",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest",
    ],
)
filegroup( filegroup(
name = "all_files", name = "all_files",
srcs = glob( srcs = glob(

View File

@ -4720,6 +4720,33 @@ void Fill(const RuntimeShape& value_shape, const T* value_data,
} }
} }
// Reverses `input_data` along a single dimension `axis`, writing the result
// to `output_data`. Shapes of input and output must match.
template <typename Scalar>
void Reverse(int axis, const RuntimeShape& input_shape,
             const Scalar* input_data, const RuntimeShape& output_shape,
             Scalar* output_data) {
  gemmlowp::ScopedProfilingLabel label("Reverse");
  // Collapse the shape to [outer, axis_len, inner]; reversing along `axis`
  // then reduces to copying `inner`-element chunks in reverse order.
  int outer = 1;
  for (int d = 0; d < axis; ++d) {
    outer *= input_shape.Dims(d);
  }
  int inner = 1;
  const int num_dims = input_shape.DimensionsCount();
  for (int d = axis + 1; d < num_dims; ++d) {
    inner *= input_shape.Dims(d);
  }
  const int axis_len = input_shape.Dims(axis);
  const size_t chunk_bytes = inner * sizeof(Scalar);
  for (int o = 0; o < outer; ++o) {
    const Scalar* in_base = input_data + o * axis_len * inner;
    Scalar* out_base = output_data + o * axis_len * inner;
    for (int k = 0; k < axis_len; ++k) {
      // Output chunk k comes from input chunk (axis_len - 1 - k).
      memcpy(out_base + k * inner, in_base + (axis_len - 1 - k) * inner,
             chunk_bytes);
    }
  }
}
} // namespace reference_ops } // namespace reference_ops
} // namespace tflite } // namespace tflite

View File

@ -131,6 +131,7 @@ TfLiteRegistration* Register_SQUARED_DIFFERENCE();
TfLiteRegistration* Register_FILL(); TfLiteRegistration* Register_FILL();
TfLiteRegistration* Register_MIRROR_PAD(); TfLiteRegistration* Register_MIRROR_PAD();
TfLiteRegistration* Register_UNIQUE(); TfLiteRegistration* Register_UNIQUE();
TfLiteRegistration* Register_REVERSE_V2();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError( context->ReportError(
@ -288,6 +289,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_FILL, Register_FILL()); AddBuiltin(BuiltinOperator_FILL, Register_FILL());
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD()); AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE()); AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default. // custom ops aren't always included by default.

View File

@ -0,0 +1,127 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace reverse {
namespace {
constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;
// Validates inputs (tensor + 1-D axis vector) and resizes the output to the
// input's shape. Fails for unsupported element types, non-int32 axes, and
// (for now) anything other than exactly one axis element.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
  TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
  TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));

  if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
      input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
      input->type != kTfLiteInt64) {
    context->ReportError(context, "Type '%s' is not supported by reverse.",
                         TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  if (axis->type != kTfLiteInt32) {
    context->ReportError(context, "Axis Type '%s' is not supported by reverse.",
                         TfLiteTypeGetName(axis->type));
    return kTfLiteError;
  }

  // TODO(renjieliu): support multi-axis case.
  if (NumElements(axis) > 1) {
    context->ReportError(context, "Current does not support more than 1 axis.");
    // Bug fix: the original fell through after reporting and prepared the
    // node anyway; an unsupported axis count must fail preparation.
    return kTfLiteError;
  }
  // An empty axis tensor would make Eval read element [0] out of bounds.
  TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE_EQ(context, output->type, input->type);
  // Copy the shape only after every check has passed so a failing
  // TF_LITE_ENSURE above cannot leak the freshly allocated TfLiteIntArray.
  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_shape);
}
// Reads the single axis value and dispatches the type-parameterized
// reference kernel on the element type.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
  // Prepare() guarantees the axis tensor is int32 with one element.
  int axis = GetTensorData<int32_t>(axis_tensor)[0];
  const int rank = NumDimensions(input);
  // Generalization: accept negative axes as TensorFlow's ReverseV2 does,
  // where -1 denotes the last dimension (previously this was rejected).
  if (axis < 0) axis += rank;
  TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  switch (output->type) {
    case kTfLiteFloat32: {
      reference_ops::Reverse<float>(
          axis, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(output), GetTensorData<float>(output));
      break;
    }
    case kTfLiteUInt8: {
      reference_ops::Reverse<uint8_t>(
          axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
          GetTensorShape(output), GetTensorData<uint8_t>(output));
      break;
    }
    case kTfLiteInt16: {
      reference_ops::Reverse<int16_t>(
          axis, GetTensorShape(input), GetTensorData<int16_t>(input),
          GetTensorShape(output), GetTensorData<int16_t>(output));
      break;
    }
    case kTfLiteInt32: {
      reference_ops::Reverse<int32_t>(
          axis, GetTensorShape(input), GetTensorData<int32_t>(input),
          GetTensorShape(output), GetTensorData<int32_t>(output));
      break;
    }
    case kTfLiteInt64: {
      reference_ops::Reverse<int64_t>(
          axis, GetTensorShape(input), GetTensorData<int64_t>(input),
          GetTensorShape(output), GetTensorData<int64_t>(output));
      break;
    }
    default: {
      // Unreachable if Prepare() ran, but kept for defense in depth.
      context->ReportError(context, "Type '%s' is not supported by reverse.",
                           TfLiteTypeGetName(output->type));
      return kTfLiteError;
    }
  }

  return kTfLiteOk;
}
} // namespace
} // namespace reverse
// Returns the registration for the REVERSE_V2 builtin. init/free are nullptr
// because the kernel keeps no per-node state.
TfLiteRegistration* Register_REVERSE_V2() {
  static TfLiteRegistration r = {nullptr, nullptr, reverse::Prepare,
                                 reverse::Eval};
  return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,199 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
// Test harness: builds a single-op graph `output = reverse_v2(input, axis)`
// for element type T.
template <typename T>
class ReverseOpModel : public SingleOpModel {
 public:
  ReverseOpModel(const TensorData& input, const TensorData& axis) {
    input_ = AddInput(input);
    axis_ = AddInput(axis);
    // Output shape left dynamic ({}); Prepare() resizes it to the input
    // shape during interpreter allocation.
    output_ = AddOutput({input.type, {}});
    SetBuiltinOp(BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options,
                 CreateReverseV2Options(builder_).Union());
    // NOTE(review): only the first input's shape is passed here; the axis
    // tensor presumably keeps the shape from its TensorData — confirm
    // against SingleOpModel::BuildInterpreter.
    BuildInterpreter({GetShape(input_)});
  }

  // Tensor indices for PopulateTensor in the tests below.
  int input() { return input_; }
  int axis() { return axis_; }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int axis_;
  int output_;
};
// float32 tests.
// Reversing a 1-D float tensor along axis 0 flips the element order.
TEST(ReverseOpTest, FloatOneDimension) {
  ReverseOpModel<float> model({TensorType_FLOAT32, {4}},
                              {TensorType_INT32, {1}});
  model.PopulateTensor<float>(model.input(), {1, 2, 3, 4});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
// Reversing axis 1 of a 4x3x2 float tensor reverses the three 2-element rows
// within each of the four outer slices.
TEST(ReverseOpTest, FloatMultiDimensions) {
  ReverseOpModel<float> model({TensorType_FLOAT32, {4, 3, 2}},
                              {TensorType_INT32, {1}});
  model.PopulateTensor<float>(model.input(),
                              {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                               13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
                        17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int32 tests
// Same 1-D axis-0 reversal as the float test, for int32 elements.
TEST(ReverseOpTest, Int32OneDimension) {
  ReverseOpModel<int32_t> model({TensorType_INT32, {4}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<int32_t>(model.input(), {1, 2, 3, 4});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
// Axis-1 reversal of a 4x3x2 int32 tensor.
TEST(ReverseOpTest, Int32MultiDimensions) {
  ReverseOpModel<int32_t> model({TensorType_INT32, {4, 3, 2}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<int32_t>(
      model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
                        17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int64 tests
// Same 1-D axis-0 reversal, for int64 elements.
TEST(ReverseOpTest, Int64OneDimension) {
  ReverseOpModel<int64_t> model({TensorType_INT64, {4}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<int64_t>(model.input(), {1, 2, 3, 4});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
// Axis-1 reversal of a 4x3x2 int64 tensor.
TEST(ReverseOpTest, Int64MultiDimensions) {
  ReverseOpModel<int64_t> model({TensorType_INT64, {4, 3, 2}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<int64_t>(
      model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
                        17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// uint8 tests
// Same 1-D axis-0 reversal, for uint8 elements.
TEST(ReverseOpTest, Uint8OneDimension) {
  ReverseOpModel<uint8_t> model({TensorType_UINT8, {4}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<uint8_t>(model.input(), {1, 2, 3, 4});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
// Axis-1 reversal of a 4x3x2 uint8 tensor.
TEST(ReverseOpTest, Uint8MultiDimensions) {
  ReverseOpModel<uint8_t> model({TensorType_UINT8, {4, 3, 2}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<uint8_t>(
      model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
                        17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int16 tests
// Same 1-D axis-0 reversal, for int16 elements.
TEST(ReverseOpTest, Int16OneDimension) {
  ReverseOpModel<int16_t> model({TensorType_INT16, {4}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<int16_t>(model.input(), {1, 2, 3, 4});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
// Axis-1 reversal of a 4x3x2 int16 tensor.
TEST(ReverseOpTest, Int16MultiDimensions) {
  ReverseOpModel<int16_t> model({TensorType_INT16, {4, 3, 2}},
                                {TensorType_INT32, {1}});
  model.PopulateTensor<int16_t>(
      model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                      13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
                        17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
} // namespace
} // namespace tflite
// gtest entry point; LogToStderr keeps TF Lite log output visible in tests.
int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

View File

@ -663,6 +663,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_SPLIT_V: case tflite::BuiltinOperator_SPLIT_V:
case tflite::BuiltinOperator_UNIQUE: case tflite::BuiltinOperator_UNIQUE:
case tflite::BuiltinOperator_CEIL: case tflite::BuiltinOperator_CEIL:
case tflite::BuiltinOperator_REVERSE_V2:
logError("Op code %d is currently not delegated to NNAPI", builtin); logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError; return kTfLiteError;
break; break;

View File

@ -218,6 +218,7 @@ enum BuiltinOperator : byte {
SPLIT_V = 102, SPLIT_V = 102,
UNIQUE = 103, UNIQUE = 103,
CEIL = 104, CEIL = 104,
REVERSE_V2 = 105,
} }
// Options for the builtin operators. // Options for the builtin operators.
@ -302,6 +303,7 @@ union BuiltinOptions {
AbsOptions, AbsOptions,
SplitVOptions, SplitVOptions,
UniqueOptions, UniqueOptions,
ReverseV2Options,
} }
enum Padding : byte { SAME, VALID } enum Padding : byte { SAME, VALID }
@ -719,6 +721,9 @@ table UniqueOptions {
idx_out_type:TensorType = INT32; idx_out_type:TensorType = INT32;
} }
// ReverseV2 takes its axes as a tensor input, so the options table is empty.
table ReverseV2Options {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom. // builtin, or a string if the operator is custom.
table OperatorCode { table OperatorCode {

View File

@ -271,6 +271,9 @@ struct MirrorPadOptionsT;
struct UniqueOptions; struct UniqueOptions;
struct UniqueOptionsT; struct UniqueOptionsT;
struct ReverseV2Options;
struct ReverseV2OptionsT;
struct OperatorCode; struct OperatorCode;
struct OperatorCodeT; struct OperatorCodeT;
@ -525,11 +528,12 @@ enum BuiltinOperator {
BuiltinOperator_SPLIT_V = 102, BuiltinOperator_SPLIT_V = 102,
BuiltinOperator_UNIQUE = 103, BuiltinOperator_UNIQUE = 103,
BuiltinOperator_CEIL = 104, BuiltinOperator_CEIL = 104,
BuiltinOperator_REVERSE_V2 = 105,
BuiltinOperator_MIN = BuiltinOperator_ADD, BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_CEIL BuiltinOperator_MAX = BuiltinOperator_REVERSE_V2
}; };
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[104] { inline const BuiltinOperator (&EnumValuesBuiltinOperator())[105] {
static const BuiltinOperator values[] = { static const BuiltinOperator values[] = {
BuiltinOperator_ADD, BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D, BuiltinOperator_AVERAGE_POOL_2D,
@ -634,7 +638,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[104] {
BuiltinOperator_ABS, BuiltinOperator_ABS,
BuiltinOperator_SPLIT_V, BuiltinOperator_SPLIT_V,
BuiltinOperator_UNIQUE, BuiltinOperator_UNIQUE,
BuiltinOperator_CEIL BuiltinOperator_CEIL,
BuiltinOperator_REVERSE_V2
}; };
return values; return values;
} }
@ -746,6 +751,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"SPLIT_V", "SPLIT_V",
"UNIQUE", "UNIQUE",
"CEIL", "CEIL",
"REVERSE_V2",
nullptr nullptr
}; };
return names; return names;
@ -838,11 +844,12 @@ enum BuiltinOptions {
BuiltinOptions_AbsOptions = 78, BuiltinOptions_AbsOptions = 78,
BuiltinOptions_SplitVOptions = 79, BuiltinOptions_SplitVOptions = 79,
BuiltinOptions_UniqueOptions = 80, BuiltinOptions_UniqueOptions = 80,
BuiltinOptions_ReverseV2Options = 81,
BuiltinOptions_MIN = BuiltinOptions_NONE, BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_UniqueOptions BuiltinOptions_MAX = BuiltinOptions_ReverseV2Options
}; };
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] { inline const BuiltinOptions (&EnumValuesBuiltinOptions())[82] {
static const BuiltinOptions values[] = { static const BuiltinOptions values[] = {
BuiltinOptions_NONE, BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions, BuiltinOptions_Conv2DOptions,
@ -924,7 +931,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] {
BuiltinOptions_MirrorPadOptions, BuiltinOptions_MirrorPadOptions,
BuiltinOptions_AbsOptions, BuiltinOptions_AbsOptions,
BuiltinOptions_SplitVOptions, BuiltinOptions_SplitVOptions,
BuiltinOptions_UniqueOptions BuiltinOptions_UniqueOptions,
BuiltinOptions_ReverseV2Options
}; };
return values; return values;
} }
@ -1012,6 +1020,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"AbsOptions", "AbsOptions",
"SplitVOptions", "SplitVOptions",
"UniqueOptions", "UniqueOptions",
"ReverseV2Options",
nullptr nullptr
}; };
return names; return names;
@ -1346,6 +1355,10 @@ template<> struct BuiltinOptionsTraits<UniqueOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions; static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions;
}; };
template<> struct BuiltinOptionsTraits<ReverseV2Options> {
static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options;
};
struct BuiltinOptionsUnion { struct BuiltinOptionsUnion {
BuiltinOptions type; BuiltinOptions type;
void *value; void *value;
@ -2017,6 +2030,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_UniqueOptions ? return type == BuiltinOptions_UniqueOptions ?
reinterpret_cast<const UniqueOptionsT *>(value) : nullptr; reinterpret_cast<const UniqueOptionsT *>(value) : nullptr;
} }
ReverseV2OptionsT *AsReverseV2Options() {
return type == BuiltinOptions_ReverseV2Options ?
reinterpret_cast<ReverseV2OptionsT *>(value) : nullptr;
}
const ReverseV2OptionsT *AsReverseV2Options() const {
return type == BuiltinOptions_ReverseV2Options ?
reinterpret_cast<const ReverseV2OptionsT *>(value) : nullptr;
}
}; };
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7113,6 +7134,46 @@ inline flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(
flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
// Flatbuffers-generated native (object-API) counterpart of ReverseV2Options.
// The op has no options, so the table carries no fields.
struct ReverseV2OptionsT : public flatbuffers::NativeTable {
  typedef ReverseV2Options TableType;
  ReverseV2OptionsT() {
  }
};
// Flatbuffers-generated accessor for the (empty) ReverseV2Options table.
struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
  typedef ReverseV2OptionsT NativeTableType;
  // An empty table only needs its table start/end verified.
  bool Verify(flatbuffers::Verifier &verifier) const {
    return VerifyTableStart(verifier) &&
           verifier.EndTable();
  }
  ReverseV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  void UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
  static flatbuffers::Offset<ReverseV2Options> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
// Flatbuffers-generated builder for the field-less ReverseV2Options table.
struct ReverseV2OptionsBuilder {
  flatbuffers::FlatBufferBuilder &fbb_;
  flatbuffers::uoffset_t start_;
  explicit ReverseV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
        : fbb_(_fbb) {
    start_ = fbb_.StartTable();
  }
  // Copy-assignment declared but not defined (generated-code convention).
  ReverseV2OptionsBuilder &operator=(const ReverseV2OptionsBuilder &);
  flatbuffers::Offset<ReverseV2Options> Finish() {
    const auto end = fbb_.EndTable(start_);
    auto o = flatbuffers::Offset<ReverseV2Options>(end);
    return o;
  }
};
// Serializes an empty ReverseV2Options table into `_fbb`.
inline flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(
    flatbuffers::FlatBufferBuilder &_fbb) {
  ReverseV2OptionsBuilder builder_(_fbb);
  return builder_.Finish();
}
flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable { struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType; typedef OperatorCode TableType;
BuiltinOperator builtin_code; BuiltinOperator builtin_code;
@ -7486,6 +7547,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const UniqueOptions *builtin_options_as_UniqueOptions() const { const UniqueOptions *builtin_options_as_UniqueOptions() const {
return builtin_options_type() == BuiltinOptions_UniqueOptions ? static_cast<const UniqueOptions *>(builtin_options()) : nullptr; return builtin_options_type() == BuiltinOptions_UniqueOptions ? static_cast<const UniqueOptions *>(builtin_options()) : nullptr;
} }
const ReverseV2Options *builtin_options_as_ReverseV2Options() const {
return builtin_options_type() == BuiltinOptions_ReverseV2Options ? static_cast<const ReverseV2Options *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const { const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
} }
@ -7837,6 +7901,10 @@ template<> inline const UniqueOptions *Operator::builtin_options_as<UniqueOption
return builtin_options_as_UniqueOptions(); return builtin_options_as_UniqueOptions();
} }
template<> inline const ReverseV2Options *Operator::builtin_options_as<ReverseV2Options>() const {
return builtin_options_as_ReverseV2Options();
}
struct OperatorBuilder { struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_; flatbuffers::uoffset_t start_;
@ -10484,6 +10552,29 @@ inline flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatB
_idx_out_type); _idx_out_type);
} }
// Unpacks the flatbuffer table into a heap-allocated native object
// (caller owns the result).
inline ReverseV2OptionsT *ReverseV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
  auto _o = new ReverseV2OptionsT();
  UnPackTo(_o, _resolver);
  return _o;
}
// No fields to copy; parameters are voided only to silence unused warnings.
inline void ReverseV2Options::UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
}
// Packs a native ReverseV2OptionsT back into a flatbuffer table.
inline flatbuffers::Offset<ReverseV2Options> ReverseV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
  return CreateReverseV2Options(_fbb, _o, _rehasher);
}
// Object-API create: the table has no fields, so _o and _rehasher are unused.
inline flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
  (void)_rehasher;
  (void)_o;
  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReverseV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
  return tflite::CreateReverseV2Options(
      _fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT(); auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver); UnPackTo(_o, _resolver);
@ -11062,6 +11153,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const UniqueOptions *>(obj); auto ptr = reinterpret_cast<const UniqueOptions *>(obj);
return verifier.VerifyTable(ptr); return verifier.VerifyTable(ptr);
} }
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false; default: return false;
} }
} }
@ -11400,6 +11495,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const UniqueOptions *>(obj); auto ptr = reinterpret_cast<const UniqueOptions *>(obj);
return ptr->UnPack(resolver); return ptr->UnPack(resolver);
} }
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr; default: return nullptr;
} }
} }
@ -11726,6 +11825,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const UniqueOptionsT *>(value); auto ptr = reinterpret_cast<const UniqueOptionsT *>(value);
return CreateUniqueOptions(_fbb, ptr, _rehasher).Union(); return CreateUniqueOptions(_fbb, ptr, _rehasher).Union();
} }
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<const ReverseV2OptionsT *>(value);
return CreateReverseV2Options(_fbb, ptr, _rehasher).Union();
}
default: return 0; default: return 0;
} }
} }
@ -12052,6 +12155,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new UniqueOptionsT(*reinterpret_cast<UniqueOptionsT *>(u.value)); value = new UniqueOptionsT(*reinterpret_cast<UniqueOptionsT *>(u.value));
break; break;
} }
case BuiltinOptions_ReverseV2Options: {
value = new ReverseV2OptionsT(*reinterpret_cast<ReverseV2OptionsT *>(u.value));
break;
}
default: default:
break; break;
} }
@ -12459,6 +12566,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr; delete ptr;
break; break;
} }
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<ReverseV2OptionsT *>(value);
delete ptr;
break;
}
default: break; default: break;
} }
value = nullptr; value = nullptr;

View File

@ -3895,6 +3895,37 @@ def make_unique_tests(zip_path):
build_inputs, build_inputs,
expected_tf_success=9) expected_tf_success=9)
def make_reverse_v2_tests(zip_path):
  """Make a set of tests to do reverse_v2.

  Args:
    zip_path: output path for the generated test archive.
  """
  test_parameters = [{
      "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
      "axis": [0, 1, 2, 3],
  }]

  def get_valid_axis(parameters):
    """Clamp 'axis' to the last valid dimension of 'base_shape'."""
    # Idiom fix: equivalent to the original decrement loop — any axis beyond
    # the final dimension is clamped to len(shape) - 1.
    return min(parameters["axis"], len(parameters["base_shape"]) - 1)

  def build_graph(parameters):
    """Build a graph with a single ReverseV2 over one float32 placeholder."""
    input_tensor = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["base_shape"])
    outs = tf.reverse(input_tensor, axis=[get_valid_axis(parameters)])
    return [input_tensor], [outs]

  def build_inputs(parameters, sess, inputs, outputs):
    """Feed random float32 data matching base_shape and run the graph."""
    input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
    return [input_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_value])))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
# Toco binary path provided by the generate rule. # Toco binary path provided by the generate rule.
bin_path = None bin_path = None

View File

@ -2062,6 +2062,20 @@ void ConvertZerosLikeOperator(const Model& model,
(*zeros_like_op->mutable_attr())["T"].set_type(data_type); (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
} }
// Emits a TensorFlow GraphDef node for a toco ReverseV2 operator.
// `op_name` is the TF op type string to emit.
// NOTE(review): the visible caller passes "Reverse_V2", but the registered
// TensorFlow op type is "ReverseV2" — confirm the call site spelling.
void ConvertReverseV2Operator(const Model& model,
                              const ReverseV2Operator& src_op,
                              const char* op_name, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
  reverse_v2_op->set_op(op_name);
  // By convention the node is named after its single output.
  reverse_v2_op->set_name(src_op.outputs[0]);
  // Inputs: [0] tensor to reverse, [1] axis vector.
  DCHECK_EQ(src_op.inputs.size(), 2);
  *reverse_v2_op->add_input() = src_op.inputs[0];
  *reverse_v2_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertOperator(const Model& model, const Operator& src_op, void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) { GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@ -2341,6 +2355,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertZerosLikeOperator( ConvertZerosLikeOperator(
model, static_cast<const TensorFlowZerosLikeOperator&>(src_op), model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
"ZerosLike", tensorflow_graph); "ZerosLike", tensorflow_graph);
} else if (src_op.type == OperatorType::kReverseV2) {
ConvertReverseV2Operator(model,
static_cast<const ReverseV2Operator&>(src_op),
"Reverse_V2", tensorflow_graph);
} else { } else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
} }

View File

@ -2009,6 +2009,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
case OperatorType::kLogicalNot: case OperatorType::kLogicalNot:
case OperatorType::kLogicalOr: case OperatorType::kLogicalOr:
case OperatorType::kZerosLike: case OperatorType::kZerosLike:
case OperatorType::kReverseV2:
ProcessSimpleOperator(model, op, 0); ProcessSimpleOperator(model, op, 0);
break; break;
case OperatorType::kGather: case OperatorType::kGather:

View File

@ -2459,6 +2459,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>}, {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
{"ResizeBilinear", ConvertResizeBilinearOperator}, {"ResizeBilinear", ConvertResizeBilinearOperator},
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator}, {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
{"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>}, {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>}, {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
{"Shape", ConvertShapeOperator}, {"Shape", ConvertShapeOperator},

View File

@ -161,7 +161,8 @@ enum class OperatorType : uint8 {
kMirrorPad, kMirrorPad,
kUnique, kUnique,
kUnidirectionalSequenceRnn, kUnidirectionalSequenceRnn,
kBidirectionalSequenceLstm kBidirectionalSequenceLstm,
kReverseV2
}; };
// Helper to deal with TensorFlow arrays using a different ordering of // Helper to deal with TensorFlow arrays using a different ordering of
@ -1958,6 +1959,16 @@ struct TensorFlowZerosLikeOperator : Operator {
TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {} TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
}; };
// ReverseV2 operator:
//
// Inputs:
//   Inputs[0]: required: the input array.
//   Inputs[1]: required: the axes to reverse along (the importer maps
//              "ReverseV2" as a 2-input, 1-output operator).
//
// TensorFlow equivalent: ReverseV2.
struct ReverseV2Operator : Operator {
  ReverseV2Operator() : Operator(OperatorType::kReverseV2) {}
};
enum class MirrorPadMode { kNone, kSymmetric, kReflect }; enum class MirrorPadMode { kNone, kSymmetric, kReflect };
// MirrorPad Operator: // MirrorPad Operator:

View File

@ -2036,6 +2036,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs)); MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
ops.push_back( ops.push_back(
MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill)); MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
"REVERSE_V2", OperatorType::kReverseV2));
return ops; return ops;
} }
} // namespace } // namespace

View File

@ -151,6 +151,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod); CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);
CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange); CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill); CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
OperatorType::kReverseV2);
} }
TEST_F(OperatorTest, BuiltinAdd) { TEST_F(OperatorTest, BuiltinAdd) {

View File

@ -420,6 +420,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(MirrorPad) HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
HANDLE_OPERATORTYPENAME_CASE(Unique) HANDLE_OPERATORTYPENAME_CASE(Unique)
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn) HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
default: default:
LOG(FATAL) << "Unhandled op type"; LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE #undef HANDLE_OPERATORTYPENAME_CASE