Support Reverse_v2

PiperOrigin-RevId: 230461952
This commit is contained in:
A. Unique TensorFlower 2019-01-22 20:19:06 -08:00 committed by TensorFlower Gardener
parent 8160760620
commit d5ee710a7f
19 changed files with 563 additions and 7 deletions

View File

@ -291,6 +291,7 @@ def generated_test_models():
"relu6",
"reshape",
"resize_bilinear",
"reverse_v2",
"rsqrt",
"shape",
"sigmoid",

View File

@ -130,6 +130,7 @@ typedef enum {
kTfLiteBuiltinSplitV = 102,
kTfLiteBuiltinUnique = 103,
kTfLiteBuiltinCeil = 104,
kTfLiteBuiltinReverseV2 = 105,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -726,6 +726,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_FLOOR_MOD:
case BuiltinOperator_RANGE:
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_REVERSE_V2:
break;
}
return kTfLiteOk;

View File

@ -200,6 +200,7 @@ cc_library(
"reshape.cc",
"resize_bilinear.cc",
"resize_nearest_neighbor.cc",
"reverse.cc",
"select.cc",
"shape.cc",
"skip_gram.cc",
@ -1258,6 +1259,18 @@ tf_cc_test(
],
)
tf_cc_test(
name = "reverse_test",
size = "small",
srcs = ["reverse_test.cc"],
deps = [
":builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -4720,6 +4720,33 @@ void Fill(const RuntimeShape& value_shape, const T* value_data,
}
}
template <typename Scalar>
void Reverse(int axis, const RuntimeShape& input_shape,
const Scalar* input_data, const RuntimeShape& output_shape,
Scalar* output_data) {
gemmlowp::ScopedProfilingLabel label("Reverse");
int outer_size = 1;
for (int i = 0; i < axis; ++i) {
outer_size *= input_shape.Dims(i);
}
int copy_size = 1;
for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
copy_size *= input_shape.Dims(i);
}
const int dims_at_axis = input_shape.Dims(axis);
for (int i = 0; i < outer_size; ++i) {
for (int j = 0; j < dims_at_axis; ++j) {
const int start_pos = (i * dims_at_axis + j) * copy_size;
Scalar* output_ptr = output_data + start_pos;
int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
}
}
}
} // namespace reference_ops
} // namespace tflite

View File

@ -131,6 +131,7 @@ TfLiteRegistration* Register_SQUARED_DIFFERENCE();
TfLiteRegistration* Register_FILL();
TfLiteRegistration* Register_MIRROR_PAD();
TfLiteRegistration* Register_UNIQUE();
TfLiteRegistration* Register_REVERSE_V2();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@ -288,6 +289,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_FILL, Register_FILL());
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -0,0 +1,127 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace reverse {
namespace {
constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
input->type != kTfLiteInt64) {
context->ReportError(context, "Type '%s' is not supported by reverse.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
if (axis->type != kTfLiteInt32) {
context->ReportError(context, "Axis Type '%s' is not supported by reverse.",
TfLiteTypeGetName(axis->type));
return kTfLiteError;
}
// TODO(renjieliu): support multi-axis case.
if (NumElements(axis) > 1) {
context->ReportError(context, "Current does not support more than 1 axis.");
}
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
TF_LITE_ENSURE_EQ(context, output->type, input->type);
return context->ResizeTensor(context, output, output_shape);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
int axis = GetTensorData<int32_t>(axis_tensor)[0];
TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (output->type) {
case kTfLiteFloat32: {
reference_ops::Reverse<float>(
axis, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));
break;
}
case kTfLiteUInt8: {
reference_ops::Reverse<uint8_t>(
axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));
break;
}
case kTfLiteInt16: {
reference_ops::Reverse<int16_t>(
axis, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(output), GetTensorData<int16_t>(output));
break;
}
case kTfLiteInt32: {
reference_ops::Reverse<int32_t>(
axis, GetTensorShape(input), GetTensorData<int32_t>(input),
GetTensorShape(output), GetTensorData<int32_t>(output));
break;
}
case kTfLiteInt64: {
reference_ops::Reverse<int64_t>(
axis, GetTensorShape(input), GetTensorData<int64_t>(input),
GetTensorShape(output), GetTensorData<int64_t>(output));
break;
}
default: {
context->ReportError(context, "Type '%s' is not supported by reverse.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}
return kTfLiteOk;
}
} // namespace
} // namespace reverse
TfLiteRegistration* Register_REVERSE_V2() {
static TfLiteRegistration r = {nullptr, nullptr, reverse::Prepare,
reverse::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,199 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
template <typename T>
class ReverseOpModel : public SingleOpModel {
public:
ReverseOpModel(const TensorData& input, const TensorData& axis) {
input_ = AddInput(input);
axis_ = AddInput(axis);
output_ = AddOutput({input.type, {}});
SetBuiltinOp(BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options,
CreateReverseV2Options(builder_).Union());
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
int axis() { return axis_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int axis_;
int output_;
};
// float32 tests.
TEST(ReverseOpTest, FloatOneDimension) {
ReverseOpModel<float> model({TensorType_FLOAT32, {4}},
{TensorType_INT32, {1}});
model.PopulateTensor<float>(model.input(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.axis(), {0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
TEST(ReverseOpTest, FloatMultiDimensions) {
ReverseOpModel<float> model({TensorType_FLOAT32, {4, 3, 2}},
{TensorType_INT32, {1}});
model.PopulateTensor<float>(model.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
model.PopulateTensor<int32_t>(model.axis(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int32 tests
TEST(ReverseOpTest, Int32OneDimension) {
ReverseOpModel<int32_t> model({TensorType_INT32, {4}},
{TensorType_INT32, {1}});
model.PopulateTensor<int32_t>(model.input(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.axis(), {0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
TEST(ReverseOpTest, Int32MultiDimensions) {
ReverseOpModel<int32_t> model({TensorType_INT32, {4, 3, 2}},
{TensorType_INT32, {1}});
model.PopulateTensor<int32_t>(
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
model.PopulateTensor<int32_t>(model.axis(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int64 tests
TEST(ReverseOpTest, Int64OneDimension) {
ReverseOpModel<int64_t> model({TensorType_INT64, {4}},
{TensorType_INT32, {1}});
model.PopulateTensor<int64_t>(model.input(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.axis(), {0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
TEST(ReverseOpTest, Int64MultiDimensions) {
ReverseOpModel<int64_t> model({TensorType_INT64, {4, 3, 2}},
{TensorType_INT32, {1}});
model.PopulateTensor<int64_t>(
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
model.PopulateTensor<int32_t>(model.axis(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// uint8 tests
TEST(ReverseOpTest, Uint8OneDimension) {
ReverseOpModel<uint8_t> model({TensorType_UINT8, {4}},
{TensorType_INT32, {1}});
model.PopulateTensor<uint8_t>(model.input(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.axis(), {0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
TEST(ReverseOpTest, Uint8MultiDimensions) {
ReverseOpModel<uint8_t> model({TensorType_UINT8, {4, 3, 2}},
{TensorType_INT32, {1}});
model.PopulateTensor<uint8_t>(
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
model.PopulateTensor<int32_t>(model.axis(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int16 tests
TEST(ReverseOpTest, Int16OneDimension) {
ReverseOpModel<int16_t> model({TensorType_INT16, {4}},
{TensorType_INT32, {1}});
model.PopulateTensor<int16_t>(model.input(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.axis(), {0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 3, 2, 1}));
}
TEST(ReverseOpTest, Int16MultiDimensions) {
ReverseOpModel<int16_t> model({TensorType_INT16, {4, 3, 2}},
{TensorType_INT32, {1}});
model.PopulateTensor<int16_t>(
model.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
model.PopulateTensor<int32_t>(model.axis(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({5, 6, 3, 4, 1, 2, 11, 12, 9, 10, 7, 8,
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -663,6 +663,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_SPLIT_V:
case tflite::BuiltinOperator_UNIQUE:
case tflite::BuiltinOperator_CEIL:
case tflite::BuiltinOperator_REVERSE_V2:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;

View File

@ -218,6 +218,7 @@ enum BuiltinOperator : byte {
SPLIT_V = 102,
UNIQUE = 103,
CEIL = 104,
REVERSE_V2 = 105,
}
// Options for the builtin operators.
@ -302,6 +303,7 @@ union BuiltinOptions {
AbsOptions,
SplitVOptions,
UniqueOptions,
ReverseV2Options,
}
enum Padding : byte { SAME, VALID }
@ -719,6 +721,9 @@ table UniqueOptions {
idx_out_type:TensorType = INT32;
}
table ReverseV2Options {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -271,6 +271,9 @@ struct MirrorPadOptionsT;
struct UniqueOptions;
struct UniqueOptionsT;
struct ReverseV2Options;
struct ReverseV2OptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -525,11 +528,12 @@ enum BuiltinOperator {
BuiltinOperator_SPLIT_V = 102,
BuiltinOperator_UNIQUE = 103,
BuiltinOperator_CEIL = 104,
BuiltinOperator_REVERSE_V2 = 105,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_CEIL
BuiltinOperator_MAX = BuiltinOperator_REVERSE_V2
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[104] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[105] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -634,7 +638,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[104] {
BuiltinOperator_ABS,
BuiltinOperator_SPLIT_V,
BuiltinOperator_UNIQUE,
BuiltinOperator_CEIL
BuiltinOperator_CEIL,
BuiltinOperator_REVERSE_V2
};
return values;
}
@ -746,6 +751,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"SPLIT_V",
"UNIQUE",
"CEIL",
"REVERSE_V2",
nullptr
};
return names;
@ -838,11 +844,12 @@ enum BuiltinOptions {
BuiltinOptions_AbsOptions = 78,
BuiltinOptions_SplitVOptions = 79,
BuiltinOptions_UniqueOptions = 80,
BuiltinOptions_ReverseV2Options = 81,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_UniqueOptions
BuiltinOptions_MAX = BuiltinOptions_ReverseV2Options
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[82] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -924,7 +931,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[81] {
BuiltinOptions_MirrorPadOptions,
BuiltinOptions_AbsOptions,
BuiltinOptions_SplitVOptions,
BuiltinOptions_UniqueOptions
BuiltinOptions_UniqueOptions,
BuiltinOptions_ReverseV2Options
};
return values;
}
@ -1012,6 +1020,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"AbsOptions",
"SplitVOptions",
"UniqueOptions",
"ReverseV2Options",
nullptr
};
return names;
@ -1346,6 +1355,10 @@ template<> struct BuiltinOptionsTraits<UniqueOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions;
};
template<> struct BuiltinOptionsTraits<ReverseV2Options> {
static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2017,6 +2030,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_UniqueOptions ?
reinterpret_cast<const UniqueOptionsT *>(value) : nullptr;
}
ReverseV2OptionsT *AsReverseV2Options() {
return type == BuiltinOptions_ReverseV2Options ?
reinterpret_cast<ReverseV2OptionsT *>(value) : nullptr;
}
const ReverseV2OptionsT *AsReverseV2Options() const {
return type == BuiltinOptions_ReverseV2Options ?
reinterpret_cast<const ReverseV2OptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7113,6 +7134,46 @@ inline flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(
flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct ReverseV2OptionsT : public flatbuffers::NativeTable {
typedef ReverseV2Options TableType;
ReverseV2OptionsT() {
}
};
struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ReverseV2OptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
ReverseV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<ReverseV2Options> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct ReverseV2OptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit ReverseV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
ReverseV2OptionsBuilder &operator=(const ReverseV2OptionsBuilder &);
flatbuffers::Offset<ReverseV2Options> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<ReverseV2Options>(end);
return o;
}
};
inline flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(
flatbuffers::FlatBufferBuilder &_fbb) {
ReverseV2OptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -7486,6 +7547,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const UniqueOptions *builtin_options_as_UniqueOptions() const {
return builtin_options_type() == BuiltinOptions_UniqueOptions ? static_cast<const UniqueOptions *>(builtin_options()) : nullptr;
}
const ReverseV2Options *builtin_options_as_ReverseV2Options() const {
return builtin_options_type() == BuiltinOptions_ReverseV2Options ? static_cast<const ReverseV2Options *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -7837,6 +7901,10 @@ template<> inline const UniqueOptions *Operator::builtin_options_as<UniqueOption
return builtin_options_as_UniqueOptions();
}
template<> inline const ReverseV2Options *Operator::builtin_options_as<ReverseV2Options>() const {
return builtin_options_as_ReverseV2Options();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -10484,6 +10552,29 @@ inline flatbuffers::Offset<UniqueOptions> CreateUniqueOptions(flatbuffers::FlatB
_idx_out_type);
}
inline ReverseV2OptionsT *ReverseV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ReverseV2OptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void ReverseV2Options::UnPackTo(ReverseV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<ReverseV2Options> ReverseV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateReverseV2Options(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReverseV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateReverseV2Options(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -11062,6 +11153,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const UniqueOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -11400,6 +11495,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const UniqueOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -11726,6 +11825,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const UniqueOptionsT *>(value);
return CreateUniqueOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<const ReverseV2OptionsT *>(value);
return CreateReverseV2Options(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -12052,6 +12155,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new UniqueOptionsT(*reinterpret_cast<UniqueOptionsT *>(u.value));
break;
}
case BuiltinOptions_ReverseV2Options: {
value = new ReverseV2OptionsT(*reinterpret_cast<ReverseV2OptionsT *>(u.value));
break;
}
default:
break;
}
@ -12459,6 +12566,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_ReverseV2Options: {
auto ptr = reinterpret_cast<ReverseV2OptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -3895,6 +3895,37 @@ def make_unique_tests(zip_path):
build_inputs,
expected_tf_success=9)
def make_reverse_v2_tests(zip_path):
"""Make a set of tests to do reverse_v2."""
test_parameters = [{
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
"axis": [0, 1, 2, 3],
}]
def get_valid_axis(parameters):
"""Return a tweaked version of 'axis'."""
axis = parameters["axis"]
shape = parameters["base_shape"][:]
while axis > len(shape) - 1:
axis -= 1
return axis
def build_graph(parameters):
input_tensor = tf.placeholder(
dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
outs = tf.reverse(input_tensor, axis=[get_valid_axis(parameters)])
return [input_tensor], [outs]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
return [input_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value])))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
# Toco binary path provided by the generate rule.
bin_path = None

View File

@ -2062,6 +2062,20 @@ void ConvertZerosLikeOperator(const Model& model,
(*zeros_like_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertReverseV2Operator(const Model& model,
const ReverseV2Operator& src_op,
const char* op_name, GraphDef* tensorflow_graph) {
tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
reverse_v2_op->set_op(op_name);
reverse_v2_op->set_name(src_op.outputs[0]);
DCHECK_EQ(src_op.inputs.size(), 2);
*reverse_v2_op->add_input() = src_op.inputs[0];
*reverse_v2_op->add_input() = src_op.inputs[1];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
(*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@ -2341,6 +2355,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertZerosLikeOperator(
model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
"ZerosLike", tensorflow_graph);
} else if (src_op.type == OperatorType::kReverseV2) {
ConvertReverseV2Operator(model,
static_cast<const ReverseV2Operator&>(src_op),
"Reverse_V2", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}

View File

@ -2009,6 +2009,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
case OperatorType::kLogicalNot:
case OperatorType::kLogicalOr:
case OperatorType::kZerosLike:
case OperatorType::kReverseV2:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:

View File

@ -2459,6 +2459,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
{"ResizeBilinear", ConvertResizeBilinearOperator},
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
{"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
{"Shape", ConvertShapeOperator},

View File

@ -161,7 +161,8 @@ enum class OperatorType : uint8 {
kMirrorPad,
kUnique,
kUnidirectionalSequenceRnn,
kBidirectionalSequenceLstm
kBidirectionalSequenceLstm,
kReverseV2
};
// Helper to deal with TensorFlow arrays using a different ordering of
@ -1958,6 +1959,16 @@ struct TensorFlowZerosLikeOperator : Operator {
TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
};
// ReverseV2 operator:
//
// Inputs:
// Inputs[0]: required: the input array.
//
// TensorFlow equivalent: ReverseV2.
struct ReverseV2Operator : Operator {
ReverseV2Operator() : Operator(OperatorType::kReverseV2) {}
};
enum class MirrorPadMode { kNone, kSymmetric, kReflect };
// MirrorPad Operator:

View File

@ -2036,6 +2036,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
ops.push_back(
MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
"REVERSE_V2", OperatorType::kReverseV2));
return ops;
}
} // namespace

View File

@ -151,6 +151,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);
CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
OperatorType::kReverseV2);
}
TEST_F(OperatorTest, BuiltinAdd) {

View File

@ -420,6 +420,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
HANDLE_OPERATORTYPENAME_CASE(Unique)
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE