Implement Rank.

PiperOrigin-RevId: 234068627
This commit is contained in:
A. Unique TensorFlower 2019-02-14 18:48:49 -08:00 committed by TensorFlower Gardener
parent 4b7ceaee0f
commit b23578d605
21 changed files with 354 additions and 18 deletions

View File

@ -286,6 +286,7 @@ def generated_test_models():
"prelu", "prelu",
"pow", "pow",
"range", "range",
"rank",
"reduce_any", "reduce_any",
"reduce_max", "reduce_max",
"reduce_min", "reduce_min",

View File

@ -135,6 +135,7 @@ typedef enum {
kTfLiteBuiltinGatherNd = 107, kTfLiteBuiltinGatherNd = 107,
kTfLiteBuiltinCos = 108, kTfLiteBuiltinCos = 108,
kTfLiteBuiltinWhere = 109, kTfLiteBuiltinWhere = 109,
kTfLiteBuiltinRank = 110,
} TfLiteBuiltinOperator; } TfLiteBuiltinOperator;
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -333,6 +333,9 @@ typedef struct {
TfLiteType out_type; TfLiteType out_type;
} TfLiteShapeParams; } TfLiteShapeParams;
typedef struct {
} TfLiteRankParams;
typedef struct { typedef struct {
// Parameters supported by version 1: // Parameters supported by version 1:
float min; float min;

View File

@ -71,6 +71,7 @@ TEST(IntArray, CanCompileStructs) {
TfLiteTransposeConvParams transpose_conv_params; TfLiteTransposeConvParams transpose_conv_params;
TfLiteSparseToDenseParams sparse_to_dense_params; TfLiteSparseToDenseParams sparse_to_dense_params;
TfLiteShapeParams shape_params; TfLiteShapeParams shape_params;
TfLiteRankParams rank_params;
TfLiteFakeQuantParams fake_quant_params; TfLiteFakeQuantParams fake_quant_params;
TfLitePackParams pack_params; TfLitePackParams pack_params;
TfLiteOneHotParams one_hot_params; TfLiteOneHotParams one_hot_params;

View File

@ -731,6 +731,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_ADD_N: case BuiltinOperator_ADD_N:
case BuiltinOperator_GATHER_ND: case BuiltinOperator_GATHER_ND:
case BuiltinOperator_WHERE: case BuiltinOperator_WHERE:
case BuiltinOperator_RANK:
break; break;
} }
return kTfLiteOk; return kTfLiteOk;

View File

@ -179,6 +179,7 @@ class OpOptionData {
op_to_option_["LOG"] = ""; op_to_option_["LOG"] = "";
op_to_option_["SQRT"] = ""; op_to_option_["SQRT"] = "";
op_to_option_["RSQRT"] = ""; op_to_option_["RSQRT"] = "";
op_to_option_["Rank"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs // TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";

View File

@ -725,6 +725,17 @@ Options {
} }
``` ```
**RANK**
```
Inputs {
0: a tensor
}
Outputs {
0: a 0-D int32 Tensor representing the rank of input
}
```
**RELU** **RELU**
``` ```

View File

@ -199,6 +199,7 @@ cc_library(
"pooling.cc", "pooling.cc",
"pow.cc", "pow.cc",
"range.cc", "range.cc",
"rank.cc",
"reduce.cc", "reduce.cc",
"reshape.cc", "reshape.cc",
"resize_bilinear.cc", "resize_bilinear.cc",
@ -1096,6 +1097,19 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "rank_test",
size = "small",
srcs = ["rank_test.cc"],
deps = [
":builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
tf_cc_test( tf_cc_test(
name = "pow_test", name = "pow_test",
size = "small", size = "small",

View File

@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace rank {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = kTfLiteInt32;
// Rank produces a 0-D int32 Tensor representing the rank of input.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(0);
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0);
if (output->type == kTfLiteInt32) {
int32_t* output_data = GetTensorData<int32_t>(output);
*output_data = NumDimensions(input);
} else {
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace rank
TfLiteRegistration* Register_RANK() {
static TfLiteRegistration r = {nullptr, nullptr, rank::Prepare, rank::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,91 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <initializer_list>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
class RankOpModel : public SingleOpModel {
public:
RankOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
TensorType output_type = TensorType_INT32;
input_ = AddInput(input_type);
output_ = AddOutput(output_type);
SetBuiltinOp(BuiltinOperator_RANK, BuiltinOptions_RankOptions,
CreateRankOptions(builder_).Union());
BuildInterpreter({input_shape});
}
TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
int input() { return input_; }
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int output_;
};
TEST(RankOpTest, InputTypeFloat) {
RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32);
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
EXPECT_TRUE(model.GetOutputShape().empty());
}
TEST(RankOpTest, InputTypeInt) {
RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32);
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
EXPECT_TRUE(model.GetOutputShape().empty());
}
TEST(RankOpTest, ScalarTensor) {
RankOpModel model({}, TensorType_FLOAT32);
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
EXPECT_TRUE(model.GetOutputShape().empty());
}
TEST(RankOpTest, EmptyTensor) {
RankOpModel model({1, 0}, TensorType_FLOAT32);
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
EXPECT_TRUE(model.GetOutputShape().empty());
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -113,6 +113,7 @@ TfLiteRegistration* Register_NOT_EQUAL();
TfLiteRegistration* Register_SQRT(); TfLiteRegistration* Register_SQRT();
TfLiteRegistration* Register_RSQRT(); TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SHAPE(); TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_RANK();
TfLiteRegistration* Register_POW(); TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT(); TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK(); TfLiteRegistration* Register_PACK();
@ -336,6 +337,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_RANK, Register_RANK());
AddBuiltin(BuiltinOperator_POW, Register_POW()); AddBuiltin(BuiltinOperator_POW, Register_POW());
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
AddBuiltin(BuiltinOperator_PACK, Register_PACK(), AddBuiltin(BuiltinOperator_PACK, Register_PACK(),

View File

@ -668,6 +668,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_ADD_N: case tflite::BuiltinOperator_ADD_N:
case tflite::BuiltinOperator_GATHER_ND: case tflite::BuiltinOperator_GATHER_ND:
case tflite::BuiltinOperator_WHERE: case tflite::BuiltinOperator_WHERE:
case tflite::BuiltinOperator_RANK:
logError("Op code %d is currently not delegated to NNAPI", builtin); logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError; return kTfLiteError;
break; break;

View File

@ -223,6 +223,7 @@ enum BuiltinOperator : byte {
GATHER_ND = 107, GATHER_ND = 107,
COS = 108, COS = 108,
WHERE = 109, WHERE = 109,
RANK = 110,
} }
// Options for the builtin operators. // Options for the builtin operators.
@ -312,6 +313,7 @@ union BuiltinOptions {
GatherNdOptions, GatherNdOptions,
CosOptions, CosOptions,
WhereOptions, WhereOptions,
RankOptions,
} }
enum Padding : byte { SAME, VALID } enum Padding : byte { SAME, VALID }
@ -652,6 +654,9 @@ table ShapeOptions {
out_type : TensorType; out_type : TensorType;
} }
table RankOptions {
}
table PowOptions { table PowOptions {
} }

View File

@ -217,6 +217,9 @@ struct NotEqualOptionsT;
struct ShapeOptions; struct ShapeOptions;
struct ShapeOptionsT; struct ShapeOptionsT;
struct RankOptions;
struct RankOptionsT;
struct PowOptions; struct PowOptions;
struct PowOptionsT; struct PowOptionsT;
@ -545,11 +548,12 @@ enum BuiltinOperator {
BuiltinOperator_GATHER_ND = 107, BuiltinOperator_GATHER_ND = 107,
BuiltinOperator_COS = 108, BuiltinOperator_COS = 108,
BuiltinOperator_WHERE = 109, BuiltinOperator_WHERE = 109,
BuiltinOperator_RANK = 110,
BuiltinOperator_MIN = BuiltinOperator_ADD, BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_WHERE BuiltinOperator_MAX = BuiltinOperator_RANK
}; };
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[109] { inline const BuiltinOperator (&EnumValuesBuiltinOperator())[110] {
static const BuiltinOperator values[] = { static const BuiltinOperator values[] = {
BuiltinOperator_ADD, BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D, BuiltinOperator_AVERAGE_POOL_2D,
@ -659,7 +663,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[109] {
BuiltinOperator_ADD_N, BuiltinOperator_ADD_N,
BuiltinOperator_GATHER_ND, BuiltinOperator_GATHER_ND,
BuiltinOperator_COS, BuiltinOperator_COS,
BuiltinOperator_WHERE BuiltinOperator_WHERE,
BuiltinOperator_RANK
}; };
return values; return values;
} }
@ -776,6 +781,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"GATHER_ND", "GATHER_ND",
"COS", "COS",
"WHERE", "WHERE",
"RANK",
nullptr nullptr
}; };
return names; return names;
@ -873,11 +879,12 @@ enum BuiltinOptions {
BuiltinOptions_GatherNdOptions = 83, BuiltinOptions_GatherNdOptions = 83,
BuiltinOptions_CosOptions = 84, BuiltinOptions_CosOptions = 84,
BuiltinOptions_WhereOptions = 85, BuiltinOptions_WhereOptions = 85,
BuiltinOptions_RankOptions = 86,
BuiltinOptions_MIN = BuiltinOptions_NONE, BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_WhereOptions BuiltinOptions_MAX = BuiltinOptions_RankOptions
}; };
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[86] { inline const BuiltinOptions (&EnumValuesBuiltinOptions())[87] {
static const BuiltinOptions values[] = { static const BuiltinOptions values[] = {
BuiltinOptions_NONE, BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions, BuiltinOptions_Conv2DOptions,
@ -964,7 +971,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[86] {
BuiltinOptions_AddNOptions, BuiltinOptions_AddNOptions,
BuiltinOptions_GatherNdOptions, BuiltinOptions_GatherNdOptions,
BuiltinOptions_CosOptions, BuiltinOptions_CosOptions,
BuiltinOptions_WhereOptions BuiltinOptions_WhereOptions,
BuiltinOptions_RankOptions
}; };
return values; return values;
} }
@ -1057,6 +1065,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"GatherNdOptions", "GatherNdOptions",
"CosOptions", "CosOptions",
"WhereOptions", "WhereOptions",
"RankOptions",
nullptr nullptr
}; };
return names; return names;
@ -1411,6 +1420,10 @@ template<> struct BuiltinOptionsTraits<WhereOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions; static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions;
}; };
template<> struct BuiltinOptionsTraits<RankOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_RankOptions;
};
struct BuiltinOptionsUnion { struct BuiltinOptionsUnion {
BuiltinOptions type; BuiltinOptions type;
void *value; void *value;
@ -2122,6 +2135,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_WhereOptions ? return type == BuiltinOptions_WhereOptions ?
reinterpret_cast<const WhereOptionsT *>(value) : nullptr; reinterpret_cast<const WhereOptionsT *>(value) : nullptr;
} }
RankOptionsT *AsRankOptions() {
return type == BuiltinOptions_RankOptions ?
reinterpret_cast<RankOptionsT *>(value) : nullptr;
}
const RankOptionsT *AsRankOptions() const {
return type == BuiltinOptions_RankOptions ?
reinterpret_cast<const RankOptionsT *>(value) : nullptr;
}
}; };
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -6340,6 +6361,46 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(
flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct RankOptionsT : public flatbuffers::NativeTable {
typedef RankOptions TableType;
RankOptionsT() {
}
};
struct RankOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef RankOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
RankOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<RankOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct RankOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit RankOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
RankOptionsBuilder &operator=(const RankOptionsBuilder &);
flatbuffers::Offset<RankOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<RankOptions>(end);
return o;
}
};
inline flatbuffers::Offset<RankOptions> CreateRankOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
RankOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<RankOptions> CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct PowOptionsT : public flatbuffers::NativeTable { struct PowOptionsT : public flatbuffers::NativeTable {
typedef PowOptions TableType; typedef PowOptions TableType;
PowOptionsT() { PowOptionsT() {
@ -7806,6 +7867,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const WhereOptions *builtin_options_as_WhereOptions() const { const WhereOptions *builtin_options_as_WhereOptions() const {
return builtin_options_type() == BuiltinOptions_WhereOptions ? static_cast<const WhereOptions *>(builtin_options()) : nullptr; return builtin_options_type() == BuiltinOptions_WhereOptions ? static_cast<const WhereOptions *>(builtin_options()) : nullptr;
} }
const RankOptions *builtin_options_as_RankOptions() const {
return builtin_options_type() == BuiltinOptions_RankOptions ? static_cast<const RankOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const { const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
} }
@ -8177,6 +8241,10 @@ template<> inline const WhereOptions *Operator::builtin_options_as<WhereOptions>
return builtin_options_as_WhereOptions(); return builtin_options_as_WhereOptions();
} }
template<> inline const RankOptions *Operator::builtin_options_as<RankOptions>() const {
return builtin_options_as_RankOptions();
}
struct OperatorBuilder { struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_; flatbuffers::uoffset_t start_;
@ -10374,6 +10442,29 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBuf
_out_type); _out_type);
} }
inline RankOptionsT *RankOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new RankOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void RankOptions::UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<RankOptions> RankOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateRankOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<RankOptions> CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RankOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateRankOptions(
_fbb);
}
inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new PowOptionsT(); auto _o = new PowOptionsT();
UnPackTo(_o, _resolver); UnPackTo(_o, _resolver);
@ -11537,6 +11628,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const WhereOptions *>(obj); auto ptr = reinterpret_cast<const WhereOptions *>(obj);
return verifier.VerifyTable(ptr); return verifier.VerifyTable(ptr);
} }
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<const RankOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false; default: return false;
} }
} }
@ -11895,6 +11990,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const WhereOptions *>(obj); auto ptr = reinterpret_cast<const WhereOptions *>(obj);
return ptr->UnPack(resolver); return ptr->UnPack(resolver);
} }
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<const RankOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr; default: return nullptr;
} }
} }
@ -12241,6 +12340,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const WhereOptionsT *>(value); auto ptr = reinterpret_cast<const WhereOptionsT *>(value);
return CreateWhereOptions(_fbb, ptr, _rehasher).Union(); return CreateWhereOptions(_fbb, ptr, _rehasher).Union();
} }
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<const RankOptionsT *>(value);
return CreateRankOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0; default: return 0;
} }
} }
@ -12587,6 +12690,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new WhereOptionsT(*reinterpret_cast<WhereOptionsT *>(u.value)); value = new WhereOptionsT(*reinterpret_cast<WhereOptionsT *>(u.value));
break; break;
} }
case BuiltinOptions_RankOptions: {
value = new RankOptionsT(*reinterpret_cast<RankOptionsT *>(u.value));
break;
}
default: default:
break; break;
} }
@ -13019,6 +13126,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr; delete ptr;
break; break;
} }
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<RankOptionsT *>(value);
delete ptr;
break;
}
default: break; default: break;
} }
value = nullptr; value = nullptr;

View File

@ -2264,6 +2264,29 @@ def make_shape_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_rank_tests(zip_path):
"""Make a set of tests to do rank."""
test_parameters = [{
"input_dtype": [tf.float32, tf.int32],
"input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]],
}]
def build_graph(parameters):
"""Build the rank op testing graph."""
input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input")
out = tf.rank(input_value)
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
return [input_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value])))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_one_hot_tests(zip_path): def make_one_hot_tests(zip_path):
"""Make a set of tests to do one_hot.""" """Make a set of tests to do one_hot."""

View File

@ -1305,7 +1305,8 @@ void ConvertTensorFlowShapeOperator(const Model& model,
GetTensorFlowDataType(model, src_op.outputs[0])); GetTensorFlowDataType(model, src_op.outputs[0]));
} }
void ConvertRankOperator(const Model& model, const RankOperator& src_op, void ConvertRankOperator(const Model& model,
const TensorFlowRankOperator& src_op,
GraphDef* tensorflow_graph) { GraphDef* tensorflow_graph) {
tensorflow::NodeDef* rank_op = tensorflow_graph->add_node(); tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
rank_op->set_op("Rank"); rank_op->set_op("Rank");
@ -2274,7 +2275,8 @@ void ConvertOperator(const Model& model, const Operator& src_op,
model, static_cast<const TensorFlowShapeOperator&>(src_op), model, static_cast<const TensorFlowShapeOperator&>(src_op),
tensorflow_graph); tensorflow_graph);
} else if (src_op.type == OperatorType::kRank) { } else if (src_op.type == OperatorType::kRank) {
ConvertRankOperator(model, static_cast<const RankOperator&>(src_op), ConvertRankOperator(model,
static_cast<const TensorFlowRankOperator&>(src_op),
tensorflow_graph); tensorflow_graph);
} else if (src_op.type == OperatorType::kRange) { } else if (src_op.type == OperatorType::kRange) {
ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op), ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),

View File

@ -1517,7 +1517,7 @@ void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
output_array.copy_shape(output_shape); output_array.copy_shape(output_shape);
} }
void ProcessRankOperator(Model* model, RankOperator* op) { void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
CHECK_GE(op->inputs.size(), 1); CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1); CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]); auto& output_array = model->GetArray(op->outputs[0]);
@ -2219,7 +2219,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
ProcessRangeOperator(model, static_cast<RangeOperator*>(op)); ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
break; break;
case OperatorType::kRank: case OperatorType::kRank:
ProcessRankOperator(model, static_cast<RankOperator*>(op)); ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
break; break;
case OperatorType::kShape: case OperatorType::kShape:
ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op)); ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));

View File

@ -2472,7 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Prod", ConvertReduceOperator<TensorFlowProdOperator>}, {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
{"RandomUniform", ConvertRandomUniform}, {"RandomUniform", ConvertRandomUniform},
{"Range", ConvertRangeOperator}, {"Range", ConvertRangeOperator},
{"Rank", ConvertSimpleOperator<RankOperator, 1, 1>}, {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
{"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>}, {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
{"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>}, {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>}, {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},

View File

@ -24,11 +24,11 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/core/platform/logging.h"
namespace toco { namespace toco {
@ -1259,13 +1259,12 @@ struct RangeOperator : Operator {
// Inputs: // Inputs:
// inputs[0]: required: the input array // inputs[0]: required: the input array
// //
// This operation outputs a 0-D integer tensor representing the rank of // This operation outputs a 0-D int32 Tensor representing the rank of input.
// the input.
// //
// TensorFlow equivalent: Rank. We currently assume that the output is int32 // TensorFlow equivalent: Rank.
// and not int64. The output type could be stored herein. struct TensorFlowRankOperator : Operator {
struct RankOperator : Operator { TensorFlowRankOperator() : Operator(OperatorType::kRank) {}
RankOperator() : Operator(OperatorType::kRank) {} ArrayDataType output_data_type = ArrayDataType::kInt32;
}; };
// Element-wise negation (-x) operator. // Element-wise negation (-x) operator.

View File

@ -2452,6 +2452,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill)); MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>( ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
"REVERSE_V2", OperatorType::kReverseV2)); "REVERSE_V2", OperatorType::kReverseV2));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
"RANK", OperatorType::kRank));
return ops; return ops;
} }
} // namespace } // namespace

View File

@ -154,6 +154,7 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill); CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2", CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
OperatorType::kReverseV2); OperatorType::kReverseV2);
CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
} }
TEST_F(OperatorTest, BuiltinAdd) { TEST_F(OperatorTest, BuiltinAdd) {