Implement Rank.

PiperOrigin-RevId: 234068627
A. Unique TensorFlower 2019-02-14 18:48:49 -08:00 committed by TensorFlower Gardener
parent 4b7ceaee0f
commit b23578d605
21 changed files with 354 additions and 18 deletions

@@ -286,6 +286,7 @@ def generated_test_models():
"prelu",
"pow",
"range",
"rank",
"reduce_any",
"reduce_max",
"reduce_min",

@@ -135,6 +135,7 @@ typedef enum {
kTfLiteBuiltinGatherNd = 107,
kTfLiteBuiltinCos = 108,
kTfLiteBuiltinWhere = 109,
kTfLiteBuiltinRank = 110,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

@@ -333,6 +333,9 @@ typedef struct {
TfLiteType out_type;
} TfLiteShapeParams;
typedef struct {
} TfLiteRankParams;
typedef struct {
// Parameters supported by version 1:
float min;

@@ -71,6 +71,7 @@ TEST(IntArray, CanCompileStructs) {
TfLiteTransposeConvParams transpose_conv_params;
TfLiteSparseToDenseParams sparse_to_dense_params;
TfLiteShapeParams shape_params;
TfLiteRankParams rank_params;
TfLiteFakeQuantParams fake_quant_params;
TfLitePackParams pack_params;
TfLiteOneHotParams one_hot_params;

@@ -731,6 +731,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_ADD_N:
case BuiltinOperator_GATHER_ND:
case BuiltinOperator_WHERE:
case BuiltinOperator_RANK:
break;
}
return kTfLiteOk;

@@ -179,6 +179,7 @@ class OpOptionData {
op_to_option_["LOG"] = "";
op_to_option_["SQRT"] = "";
op_to_option_["RSQRT"] = "";
op_to_option_["Rank"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";

@@ -725,6 +725,17 @@ Options {
}
```
**RANK**
```
Inputs {
0: a tensor
}
Outputs {
0: a 0-D int32 Tensor representing the rank of input
}
```
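A quick way to internalize the contract: rank depends only on how many dimensions the input has, never on their sizes. A minimal standalone C++ sketch of the semantics (illustrative only, not part of this commit):
```
#include <cstdint>
#include <vector>

// Rank is just the length of the shape vector:
//   {}        -> 0  (scalar)
//   {1, 0}    -> 2  (an empty tensor still has two dimensions)
//   {2, 3, 4} -> 3
int32_t ComputeRank(const std::vector<int>& shape) {
  return static_cast<int32_t>(shape.size());
}
```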
**RELU**
```

@@ -199,6 +199,7 @@ cc_library(
"pooling.cc",
"pow.cc",
"range.cc",
"rank.cc",
"reduce.cc",
"reshape.cc",
"resize_bilinear.cc",
@@ -1096,6 +1097,19 @@ tf_cc_test(
],
)
tf_cc_test(
    name = "rank_test",
    size = "small",
    srcs = ["rank_test.cc"],
    deps = [
        ":builtin_ops",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest",
    ],
)

tf_cc_test(
    name = "pow_test",
    size = "small",

@@ -0,0 +1,65 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace rank {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = kTfLiteInt32;
// Rank produces a 0-D int32 Tensor representing the rank of input.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(0);
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0);
if (output->type == kTfLiteInt32) {
int32_t* output_data = GetTensorData<int32_t>(output);
*output_data = NumDimensions(input);
} else {
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace rank
TfLiteRegistration* Register_RANK() {
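  // A TfLiteRegistration's first four members are {init, free, prepare,
  // invoke}; Rank keeps no per-node state, so init and free stay null.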
  static TfLiteRegistration r = {nullptr, nullptr, rank::Prepare, rank::Eval};
  return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

@@ -0,0 +1,91 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <initializer_list>

#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

class RankOpModel : public SingleOpModel {
 public:
  RankOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
    TensorType output_type = TensorType_INT32;
    input_ = AddInput(input_type);
    output_ = AddOutput(output_type);
    SetBuiltinOp(BuiltinOperator_RANK, BuiltinOptions_RankOptions,
                 CreateRankOptions(builder_).Union());
    BuildInterpreter({input_shape});
  }

  TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }

  int input() { return input_; }

  std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int output_;
};

TEST(RankOpTest, InputTypeFloat) {
  RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32);
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
  EXPECT_TRUE(model.GetOutputShape().empty());
}

TEST(RankOpTest, InputTypeInt) {
  RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32);
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
  EXPECT_TRUE(model.GetOutputShape().empty());
}

TEST(RankOpTest, ScalarTensor) {
  RankOpModel model({}, TensorType_FLOAT32);
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
  EXPECT_TRUE(model.GetOutputShape().empty());
}

TEST(RankOpTest, EmptyTensor) {
  RankOpModel model({1, 0}, TensorType_FLOAT32);
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
  EXPECT_TRUE(model.GetOutputShape().empty());
}

}  // namespace
}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

@@ -113,6 +113,7 @@ TfLiteRegistration* Register_NOT_EQUAL();
TfLiteRegistration* Register_SQRT();
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_RANK();
TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK();
@@ -336,6 +337,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_RANK, Register_RANK());
AddBuiltin(BuiltinOperator_POW, Register_POW());
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
AddBuiltin(BuiltinOperator_PACK, Register_PACK(),
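With Register_RANK wired into BuiltinOpResolver above, a model containing the op loads through the standard interpreter path. A minimal sketch, assuming a .tflite model that uses RANK (the path is a hypothetical placeholder, not part of this commit):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

std::unique_ptr<tflite::Interpreter> MakeInterpreter(const char* model_path) {
  // model_path is a placeholder for any flatbuffer model using RANK.
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  tflite::ops::builtin::BuiltinOpResolver resolver;  // now resolves RANK
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}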

@@ -668,6 +668,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_ADD_N:
case tflite::BuiltinOperator_GATHER_ND:
case tflite::BuiltinOperator_WHERE:
case tflite::BuiltinOperator_RANK:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;

@@ -223,6 +223,7 @@ enum BuiltinOperator : byte {
GATHER_ND = 107,
COS = 108,
WHERE = 109,
RANK = 110,
}
// Options for the builtin operators.
@@ -312,6 +313,7 @@ union BuiltinOptions {
GatherNdOptions,
CosOptions,
WhereOptions,
RankOptions,
}
enum Padding : byte { SAME, VALID }
@@ -652,6 +654,9 @@ table ShapeOptions {
out_type : TensorType;
}
table RankOptions {
}
table PowOptions {
}
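Since RankOptions carries no fields, serializing it yields an empty table; the generated CreateRankOptions helper (in the generated header below) is all a producer needs. A minimal sketch, mirroring what rank_test.cc does:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Returns the options payload to store in Operator.builtin_options,
// tagged with type BuiltinOptions_RankOptions.
flatbuffers::Offset<void> BuildRankOptions(flatbuffers::FlatBufferBuilder& builder) {
  return tflite::CreateRankOptions(builder).Union();
}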

@@ -217,6 +217,9 @@ struct NotEqualOptionsT;
struct ShapeOptions;
struct ShapeOptionsT;
struct RankOptions;
struct RankOptionsT;
struct PowOptions;
struct PowOptionsT;
@@ -545,11 +548,12 @@ enum BuiltinOperator {
BuiltinOperator_GATHER_ND = 107,
BuiltinOperator_COS = 108,
BuiltinOperator_WHERE = 109,
BuiltinOperator_RANK = 110,
BuiltinOperator_MIN = BuiltinOperator_ADD,
-BuiltinOperator_MAX = BuiltinOperator_WHERE
+BuiltinOperator_MAX = BuiltinOperator_RANK
};
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[109] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[110] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -659,7 +663,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[109] {
BuiltinOperator_ADD_N,
BuiltinOperator_GATHER_ND,
BuiltinOperator_COS,
-BuiltinOperator_WHERE
+BuiltinOperator_WHERE,
+BuiltinOperator_RANK
};
return values;
}
@@ -776,6 +781,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"GATHER_ND",
"COS",
"WHERE",
"RANK",
nullptr
};
return names;
@@ -873,11 +879,12 @@ enum BuiltinOptions {
BuiltinOptions_GatherNdOptions = 83,
BuiltinOptions_CosOptions = 84,
BuiltinOptions_WhereOptions = 85,
BuiltinOptions_RankOptions = 86,
BuiltinOptions_MIN = BuiltinOptions_NONE,
-BuiltinOptions_MAX = BuiltinOptions_WhereOptions
+BuiltinOptions_MAX = BuiltinOptions_RankOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[86] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[87] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -964,7 +971,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[86] {
BuiltinOptions_AddNOptions,
BuiltinOptions_GatherNdOptions,
BuiltinOptions_CosOptions,
-BuiltinOptions_WhereOptions
+BuiltinOptions_WhereOptions,
+BuiltinOptions_RankOptions
};
return values;
}
@@ -1057,6 +1065,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"GatherNdOptions",
"CosOptions",
"WhereOptions",
"RankOptions",
nullptr
};
return names;
@@ -1411,6 +1420,10 @@ template<> struct BuiltinOptionsTraits<WhereOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions;
};
template<> struct BuiltinOptionsTraits<RankOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_RankOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -2122,6 +2135,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_WhereOptions ?
reinterpret_cast<const WhereOptionsT *>(value) : nullptr;
}
RankOptionsT *AsRankOptions() {
return type == BuiltinOptions_RankOptions ?
reinterpret_cast<RankOptionsT *>(value) : nullptr;
}
const RankOptionsT *AsRankOptions() const {
return type == BuiltinOptions_RankOptions ?
reinterpret_cast<const RankOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -6340,6 +6361,46 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(
flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct RankOptionsT : public flatbuffers::NativeTable {
typedef RankOptions TableType;
RankOptionsT() {
}
};
struct RankOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef RankOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
RankOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<RankOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct RankOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit RankOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
RankOptionsBuilder &operator=(const RankOptionsBuilder &);
flatbuffers::Offset<RankOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<RankOptions>(end);
return o;
}
};
inline flatbuffers::Offset<RankOptions> CreateRankOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
RankOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<RankOptions> CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct PowOptionsT : public flatbuffers::NativeTable {
typedef PowOptions TableType;
PowOptionsT() {
@@ -7806,6 +7867,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const WhereOptions *builtin_options_as_WhereOptions() const {
return builtin_options_type() == BuiltinOptions_WhereOptions ? static_cast<const WhereOptions *>(builtin_options()) : nullptr;
}
const RankOptions *builtin_options_as_RankOptions() const {
return builtin_options_type() == BuiltinOptions_RankOptions ? static_cast<const RankOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -8177,6 +8241,10 @@ template<> inline const WhereOptions *Operator::builtin_options_as<WhereOptions>
return builtin_options_as_WhereOptions();
}
template<> inline const RankOptions *Operator::builtin_options_as<RankOptions>() const {
return builtin_options_as_RankOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -10374,6 +10442,29 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBuf
_out_type);
}
inline RankOptionsT *RankOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new RankOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void RankOptions::UnPackTo(RankOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<RankOptions> RankOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateRankOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<RankOptions> CreateRankOptions(flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RankOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateRankOptions(
_fbb);
}
inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new PowOptionsT();
UnPackTo(_o, _resolver);
@@ -11537,6 +11628,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const WhereOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<const RankOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@@ -11895,6 +11990,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const WhereOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<const RankOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@@ -12241,6 +12340,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const WhereOptionsT *>(value);
return CreateWhereOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<const RankOptionsT *>(value);
return CreateRankOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@@ -12587,6 +12690,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new WhereOptionsT(*reinterpret_cast<WhereOptionsT *>(u.value));
break;
}
case BuiltinOptions_RankOptions: {
value = new RankOptionsT(*reinterpret_cast<RankOptionsT *>(u.value));
break;
}
default:
break;
}
@@ -13019,6 +13126,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_RankOptions: {
auto ptr = reinterpret_cast<RankOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

@@ -2264,6 +2264,29 @@ def make_shape_tests(zip_path):
  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_rank_tests(zip_path):
  """Make a set of tests to do rank."""

  test_parameters = [{
      "input_dtype": [tf.float32, tf.int32],
      "input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]],
  }]

  def build_graph(parameters):
    """Build the rank op testing graph."""
    input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input")
    out = tf.rank(input_value)
    return [input_value], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    input_value = create_tensor_data(parameters["input_dtype"],
                                     parameters["input_shape"])
    return [input_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_value])))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_one_hot_tests(zip_path):
  """Make a set of tests to do one_hot."""

@@ -1305,7 +1305,8 @@ void ConvertTensorFlowShapeOperator(const Model& model,
GetTensorFlowDataType(model, src_op.outputs[0]));
}
-void ConvertRankOperator(const Model& model, const RankOperator& src_op,
+void ConvertRankOperator(const Model& model,
+const TensorFlowRankOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
rank_op->set_op("Rank");
@@ -2274,7 +2275,8 @@ void ConvertOperator(const Model& model, const Operator& src_op,
model, static_cast<const TensorFlowShapeOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRank) {
-ConvertRankOperator(model, static_cast<const RankOperator&>(src_op),
+ConvertRankOperator(model,
+static_cast<const TensorFlowRankOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRange) {
ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),

@@ -1517,7 +1517,7 @@ void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
output_array.copy_shape(output_shape);
}
-void ProcessRankOperator(Model* model, RankOperator* op) {
+void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
@@ -2219,7 +2219,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
break;
case OperatorType::kRank:
-ProcessRankOperator(model, static_cast<RankOperator*>(op));
+ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
break;
case OperatorType::kShape:
ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));

@@ -2472,7 +2472,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
{"RandomUniform", ConvertRandomUniform},
{"Range", ConvertRangeOperator},
{"Rank", ConvertSimpleOperator<RankOperator, 1, 1>},
{"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
{"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
{"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},

@@ -24,11 +24,11 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
@@ -1259,13 +1259,12 @@ struct RangeOperator : Operator {
// Inputs:
// inputs[0]: required: the input array
//
-// This operation outputs a 0-D integer tensor representing the rank of
-// the input.
+// This operation outputs a 0-D int32 Tensor representing the rank of input.
//
-// TensorFlow equivalent: Rank. We currently assume that the output is int32
-// and not int64. The output type could be stored herein.
-struct RankOperator : Operator {
-RankOperator() : Operator(OperatorType::kRank) {}
+// TensorFlow equivalent: Rank.
+struct TensorFlowRankOperator : Operator {
+TensorFlowRankOperator() : Operator(OperatorType::kRank) {}
+ArrayDataType output_data_type = ArrayDataType::kInt32;
};
// Element-wise negation (-x) operator.

@@ -2452,6 +2452,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
"REVERSE_V2", OperatorType::kReverseV2));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
"RANK", OperatorType::kRank));
return ops;
}
} // namespace

@@ -154,6 +154,7 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
OperatorType::kReverseV2);
CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
}
TEST_F(OperatorTest, BuiltinAdd) {