Implement Rank.
PiperOrigin-RevId: 234068627
This commit is contained in:
parent
4b7ceaee0f
commit
b23578d605
@ -286,6 +286,7 @@ def generated_test_models():
|
||||
"prelu",
|
||||
"pow",
|
||||
"range",
|
||||
"rank",
|
||||
"reduce_any",
|
||||
"reduce_max",
|
||||
"reduce_min",
|
||||
|
@ -135,6 +135,7 @@ typedef enum {
|
||||
kTfLiteBuiltinGatherNd = 107,
|
||||
kTfLiteBuiltinCos = 108,
|
||||
kTfLiteBuiltinWhere = 109,
|
||||
kTfLiteBuiltinRank = 110,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -333,6 +333,9 @@ typedef struct {
|
||||
TfLiteType out_type;
|
||||
} TfLiteShapeParams;
|
||||
|
||||
typedef struct {
|
||||
} TfLiteRankParams;
|
||||
|
||||
typedef struct {
|
||||
// Parameters supported by version 1:
|
||||
float min;
|
||||
|
@ -71,6 +71,7 @@ TEST(IntArray, CanCompileStructs) {
|
||||
TfLiteTransposeConvParams transpose_conv_params;
|
||||
TfLiteSparseToDenseParams sparse_to_dense_params;
|
||||
TfLiteShapeParams shape_params;
|
||||
TfLiteRankParams rank_params;
|
||||
TfLiteFakeQuantParams fake_quant_params;
|
||||
TfLitePackParams pack_params;
|
||||
TfLiteOneHotParams one_hot_params;
|
||||
|
@ -731,6 +731,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_ADD_N:
|
||||
case BuiltinOperator_GATHER_ND:
|
||||
case BuiltinOperator_WHERE:
|
||||
case BuiltinOperator_RANK:
|
||||
break;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -179,6 +179,7 @@ class OpOptionData {
|
||||
op_to_option_["LOG"] = "";
|
||||
op_to_option_["SQRT"] = "";
|
||||
op_to_option_["RSQRT"] = "";
|
||||
op_to_option_["Rank"] = "";
|
||||
|
||||
// TODO(aselle): These are undesirable hacks. Consider changing C structs
|
||||
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
|
||||
|
@ -725,6 +725,17 @@ Options {
|
||||
}
|
||||
```
|
||||
|
||||
**RANK**
|
||||
|
||||
```
|
||||
Inputs {
|
||||
0: a tensor
|
||||
}
|
||||
Outputs {
|
||||
0: a 0-D int32 Tensor representing the rank of input
|
||||
}
|
||||
```
|
||||
|
||||
**RELU**
|
||||
|
||||
```
|
||||
|
@ -199,6 +199,7 @@ cc_library(
|
||||
"pooling.cc",
|
||||
"pow.cc",
|
||||
"range.cc",
|
||||
"rank.cc",
|
||||
"reduce.cc",
|
||||
"reshape.cc",
|
||||
"resize_bilinear.cc",
|
||||
@ -1096,6 +1097,19 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "rank_test",
|
||||
size = "small",
|
||||
srcs = ["rank_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "pow_test",
|
||||
size = "small",
|
||||
|
65
tensorflow/lite/kernels/rank.cc
Normal file
65
tensorflow/lite/kernels/rank.cc
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace rank {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
output->type = kTfLiteInt32;
|
||||
|
||||
// Rank produces a 0-D int32 Tensor representing the rank of input.
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCreate(0);
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0);
|
||||
|
||||
if (output->type == kTfLiteInt32) {
|
||||
int32_t* output_data = GetTensorData<int32_t>(output);
|
||||
*output_data = NumDimensions(input);
|
||||
} else {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace rank
|
||||
|
||||
TfLiteRegistration* Register_RANK() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, rank::Prepare, rank::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
91
tensorflow/lite/kernels/rank_test.cc
Normal file
91
tensorflow/lite/kernels/rank_test.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class RankOpModel : public SingleOpModel {
|
||||
public:
|
||||
RankOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
|
||||
TensorType output_type = TensorType_INT32;
|
||||
input_ = AddInput(input_type);
|
||||
output_ = AddOutput(output_type);
|
||||
SetBuiltinOp(BuiltinOperator_RANK, BuiltinOptions_RankOptions,
|
||||
CreateRankOptions(builder_).Union());
|
||||
BuildInterpreter({input_shape});
|
||||
}
|
||||
|
||||
TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
|
||||
|
||||
int input() { return input_; }
|
||||
|
||||
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(RankOpTest, InputTypeFloat) {
|
||||
RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32);
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
|
||||
EXPECT_TRUE(model.GetOutputShape().empty());
|
||||
}
|
||||
|
||||
TEST(RankOpTest, InputTypeInt) {
|
||||
RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32);
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
|
||||
EXPECT_TRUE(model.GetOutputShape().empty());
|
||||
}
|
||||
|
||||
TEST(RankOpTest, ScalarTensor) {
|
||||
RankOpModel model({}, TensorType_FLOAT32);
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
|
||||
EXPECT_TRUE(model.GetOutputShape().empty());
|
||||
}
|
||||
|
||||
TEST(RankOpTest, EmptyTensor) {
|
||||
RankOpModel model({1, 0}, TensorType_FLOAT32);
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
|
||||
EXPECT_TRUE(model.GetOutputShape().empty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -113,6 +113,7 @@ TfLiteRegistration* Register_NOT_EQUAL();
|
||||
TfLiteRegistration* Register_SQRT();
|
||||
TfLiteRegistration* Register_RSQRT();
|
||||
TfLiteRegistration* Register_SHAPE();
|
||||
TfLiteRegistration* Register_RANK();
|
||||
TfLiteRegistration* Register_POW();
|
||||
TfLiteRegistration* Register_FAKE_QUANT();
|
||||
TfLiteRegistration* Register_PACK();
|
||||
@ -336,6 +337,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
|
||||
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
|
||||
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
|
||||
AddBuiltin(BuiltinOperator_RANK, Register_RANK());
|
||||
AddBuiltin(BuiltinOperator_POW, Register_POW());
|
||||
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
|
||||
AddBuiltin(BuiltinOperator_PACK, Register_PACK(),
|
||||
|
@ -668,6 +668,7 @@ TfLiteStatus AddOpsAndParams(
|
||||
case tflite::BuiltinOperator_ADD_N:
|
||||
case tflite::BuiltinOperator_GATHER_ND:
|
||||
case tflite::BuiltinOperator_WHERE:
|
||||
case tflite::BuiltinOperator_RANK:
|
||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||
return kTfLiteError;
|
||||
break;
|
||||
|
@ -223,6 +223,7 @@ enum BuiltinOperator : byte {
|
||||
GATHER_ND = 107,
|
||||
COS = 108,
|
||||
WHERE = 109,
|
||||
RANK = 110,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -312,6 +313,7 @@ union BuiltinOptions {
|
||||
GatherNdOptions,
|
||||
CosOptions,
|
||||
WhereOptions,
|
||||
RankOptions,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -652,6 +654,9 @@ table ShapeOptions {
|
||||
out_type : TensorType;
|
||||
}
|
||||
|
||||
table RankOptions {
|
||||
}
|
||||
|
||||
table PowOptions {
|
||||
}
|
||||
|
||||
|
@ -217,6 +217,9 @@ struct NotEqualOptionsT;
|
||||
struct ShapeOptions;
|
||||
struct ShapeOptionsT;
|
||||
|
||||
struct RankOptions;
|
||||
struct RankOptionsT;
|
||||
|
||||
struct PowOptions;
|
||||
struct PowOptionsT;
|
||||
|
||||
@ -545,11 +548,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_GATHER_ND = 107,
|
||||
BuiltinOperator_COS = 108,
|
||||
BuiltinOperator_WHERE = 109,
|
||||
BuiltinOperator_RANK = 110,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_WHERE
|
||||
BuiltinOperator_MAX = BuiltinOperator_RANK
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[109] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[110] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -659,7 +663,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[109] {
|
||||
BuiltinOperator_ADD_N,
|
||||
BuiltinOperator_GATHER_ND,
|
||||
BuiltinOperator_COS,
|
||||
BuiltinOperator_WHERE
|
||||
BuiltinOperator_WHERE,
|
||||
BuiltinOperator_RANK
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -776,6 +781,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"GATHER_ND",
|
||||
"COS",
|
||||
"WHERE",
|
||||
"RANK",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -873,11 +879,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_GatherNdOptions = 83,
|
||||
BuiltinOptions_CosOptions = 84,
|
||||
BuiltinOptions_WhereOptions = 85,
|
||||
BuiltinOptions_RankOptions = 86,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_WhereOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_RankOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[86] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[87] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -964,7 +971,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[86] {
|
||||
BuiltinOptions_AddNOptions,
|
||||
BuiltinOptions_GatherNdOptions,
|
||||
BuiltinOptions_CosOptions,
|
||||
BuiltinOptions_WhereOptions
|
||||
BuiltinOptions_WhereOptions,
|
||||
BuiltinOptions_RankOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1057,6 +1065,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"GatherNdOptions",
|
||||
"CosOptions",
|
||||
"WhereOptions",
|
||||
"RankOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -1411,6 +1420,10 @@ template<> struct BuiltinOptionsTraits<WhereOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<RankOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_RankOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2122,6 +2135,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_WhereOptions ?
|
||||
reinterpret_cast<const WhereOptionsT *>(value) : nullptr;
|
||||
}
|
||||
RankOptionsT *AsRankOptions() {
|
||||
return type == BuiltinOptions_RankOptions ?
|
||||
reinterpret_cast<RankOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const RankOptionsT *AsRankOptions() const {
|
||||
return type == BuiltinOptions_RankOptions ?
|
||||
reinterpret_cast<const RankOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -6340,6 +6361,46 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(
|
||||
|
||||
flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct RankOptionsT : public flatbuffers::NativeTable {
|
||||
typedef RankOptions TableType;
|
||||
RankOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct RankOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef RankOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
RankOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<RankOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct RankOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit RankOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
RankOptionsBuilder &operator=(const RankOptionsBuilder &);
|
||||
flatbuffers::Offset<RankOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<RankOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<RankOptions> CreateRankOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
RankOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<RankOptions> CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct PowOptionsT : public flatbuffers::NativeTable {
|
||||
typedef PowOptions TableType;
|
||||
PowOptionsT() {
|
||||
@ -7806,6 +7867,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const WhereOptions *builtin_options_as_WhereOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_WhereOptions ? static_cast<const WhereOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const RankOptions *builtin_options_as_RankOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_RankOptions ? static_cast<const RankOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -8177,6 +8241,10 @@ template<> inline const WhereOptions *Operator::builtin_options_as<WhereOptions>
|
||||
return builtin_options_as_WhereOptions();
|
||||
}
|
||||
|
||||
template<> inline const RankOptions *Operator::builtin_options_as<RankOptions>() const {
|
||||
return builtin_options_as_RankOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -10374,6 +10442,29 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBuf
|
||||
_out_type);
|
||||
}
|
||||
|
||||
inline RankOptionsT *RankOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new RankOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void RankOptions::UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<RankOptions> RankOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateRankOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<RankOptions> CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RankOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateRankOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new PowOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -11537,6 +11628,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const WhereOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_RankOptions: {
|
||||
auto ptr = reinterpret_cast<const RankOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -11895,6 +11990,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const WhereOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_RankOptions: {
|
||||
auto ptr = reinterpret_cast<const RankOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -12241,6 +12340,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const WhereOptionsT *>(value);
|
||||
return CreateWhereOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_RankOptions: {
|
||||
auto ptr = reinterpret_cast<const RankOptionsT *>(value);
|
||||
return CreateRankOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -12587,6 +12690,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new WhereOptionsT(*reinterpret_cast<WhereOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_RankOptions: {
|
||||
value = new RankOptionsT(*reinterpret_cast<RankOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -13019,6 +13126,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_RankOptions: {
|
||||
auto ptr = reinterpret_cast<RankOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -2264,6 +2264,29 @@ def make_shape_tests(zip_path):
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_rank_tests(zip_path):
|
||||
"""Make a set of tests to do rank."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32, tf.int32],
|
||||
"input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the rank op testing graph."""
|
||||
input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input")
|
||||
out = tf.rank(input_value)
|
||||
return [input_value], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value = create_tensor_data(parameters["input_dtype"],
|
||||
parameters["input_shape"])
|
||||
return [input_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
||||
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_one_hot_tests(zip_path):
|
||||
"""Make a set of tests to do one_hot."""
|
||||
|
||||
|
@ -1305,7 +1305,8 @@ void ConvertTensorFlowShapeOperator(const Model& model,
|
||||
GetTensorFlowDataType(model, src_op.outputs[0]));
|
||||
}
|
||||
|
||||
void ConvertRankOperator(const Model& model, const RankOperator& src_op,
|
||||
void ConvertRankOperator(const Model& model,
|
||||
const TensorFlowRankOperator& src_op,
|
||||
GraphDef* tensorflow_graph) {
|
||||
tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
|
||||
rank_op->set_op("Rank");
|
||||
@ -2274,7 +2275,8 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
model, static_cast<const TensorFlowShapeOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
} else if (src_op.type == OperatorType::kRank) {
|
||||
ConvertRankOperator(model, static_cast<const RankOperator&>(src_op),
|
||||
ConvertRankOperator(model,
|
||||
static_cast<const TensorFlowRankOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
} else if (src_op.type == OperatorType::kRange) {
|
||||
ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
|
||||
|
@ -1517,7 +1517,7 @@ void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
|
||||
output_array.copy_shape(output_shape);
|
||||
}
|
||||
|
||||
void ProcessRankOperator(Model* model, RankOperator* op) {
|
||||
void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
|
||||
CHECK_GE(op->inputs.size(), 1);
|
||||
CHECK_EQ(op->outputs.size(), 1);
|
||||
auto& output_array = model->GetArray(op->outputs[0]);
|
||||
@ -2219,7 +2219,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||
ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
|
||||
break;
|
||||
case OperatorType::kRank:
|
||||
ProcessRankOperator(model, static_cast<RankOperator*>(op));
|
||||
ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
|
||||
break;
|
||||
case OperatorType::kShape:
|
||||
ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
|
||||
|
@ -2472,7 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
|
||||
{"RandomUniform", ConvertRandomUniform},
|
||||
{"Range", ConvertRangeOperator},
|
||||
{"Rank", ConvertSimpleOperator<RankOperator, 1, 1>},
|
||||
{"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
|
||||
{"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
|
||||
{"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
|
||||
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
|
||||
|
@ -24,11 +24,11 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
#include "tensorflow/lite/toco/runtime/types.h"
|
||||
#include "tensorflow/lite/toco/toco_port.h"
|
||||
#include "tensorflow/lite/toco/toco_types.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
@ -1259,13 +1259,12 @@ struct RangeOperator : Operator {
|
||||
// Inputs:
|
||||
// inputs[0]: required: the input array
|
||||
//
|
||||
// This operation outputs a 0-D integer tensor representing the rank of
|
||||
// the input.
|
||||
// This operation outputs a 0-D int32 Tensor representing the rank of input.
|
||||
//
|
||||
// TensorFlow equivalent: Rank. We currently assume that the output is int32
|
||||
// and not int64. The output type could be stored herein.
|
||||
struct RankOperator : Operator {
|
||||
RankOperator() : Operator(OperatorType::kRank) {}
|
||||
// TensorFlow equivalent: Rank.
|
||||
struct TensorFlowRankOperator : Operator {
|
||||
TensorFlowRankOperator() : Operator(OperatorType::kRank) {}
|
||||
ArrayDataType output_data_type = ArrayDataType::kInt32;
|
||||
};
|
||||
|
||||
// Element-wise negation (-x) operator.
|
||||
|
@ -2452,6 +2452,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
|
||||
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
|
||||
"REVERSE_V2", OperatorType::kReverseV2));
|
||||
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
|
||||
"RANK", OperatorType::kRank));
|
||||
return ops;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -154,6 +154,7 @@ TEST_F(OperatorTest, SimpleOperators) {
|
||||
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
|
||||
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
|
||||
OperatorType::kReverseV2);
|
||||
CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinAdd) {
|
||||
|
Loading…
Reference in New Issue
Block a user