parent
1345f18835
commit
845f970ca0
tensorflow/lite
@ -32,7 +32,6 @@ cc_library(
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/toco:graph_transformations",
|
||||
"//tensorflow/lite/toco:model",
|
||||
"//tensorflow/lite/tools/versioning:op_version",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@flatbuffers",
|
||||
],
|
||||
|
@ -36,8 +36,7 @@ class BuiltinOperator : public BaseOperator {
|
||||
using TfLiteOptions = T2;
|
||||
|
||||
BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type)
|
||||
: BaseOperator(::tflite::EnumNameBuiltinOperator(op), type),
|
||||
builtin_op_(op) {}
|
||||
: BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {}
|
||||
|
||||
// Build the configuration object in the given flatbuffer builder. Return
|
||||
// its offset.
|
||||
@ -66,16 +65,6 @@ class BuiltinOperator : public BaseOperator {
|
||||
}
|
||||
return std::unique_ptr<Operator>(op.release());
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return ::tflite::GetBuiltinOperatorVersion(
|
||||
GetVersioningOpSig(builtin_op_, op_signature));
|
||||
}
|
||||
|
||||
::tflite::BuiltinOperator builtin_op() const { return builtin_op_; }
|
||||
|
||||
private:
|
||||
const ::tflite::BuiltinOperator builtin_op_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "flatbuffers/flexbuffers.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
@ -94,9 +93,8 @@ class BaseOperator {
|
||||
// * The first version for each op should be 1 (to be consistent with the
|
||||
// default value in Flatbuffer. `return 1;` is okay for newly implemented
|
||||
// ops.
|
||||
// * When multiple versions are defined for an op, this function could be
|
||||
// overridden. (See example in `operator_test.cc` and
|
||||
// 'tools/versioning/op_version.cc`)
|
||||
// * When multiple versions are defined for an op, this function needs to be
|
||||
// overridden. (See example in `operator_test.cc`)
|
||||
virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
|
||||
|
||||
// Given a Toco `Operator`, return a list of booleans indicating the op
|
||||
@ -115,11 +113,6 @@ class BaseOperator {
|
||||
OperatorType type_;
|
||||
};
|
||||
|
||||
// Helper function to create ::tflite::OpSignature from the given
|
||||
// ::tflite::BuiltinOperator and OperatorSignature.
|
||||
::tflite::OpSignature GetVersioningOpSig(const ::tflite::BuiltinOperator op,
|
||||
const OperatorSignature& op_signature);
|
||||
|
||||
// Helper function to determine if a unsupported TensorFlow op should be
|
||||
// exported as an Flex op or a regular custom op.
|
||||
bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
|
||||
|
@ -32,11 +32,6 @@ template <typename T>
|
||||
class SimpleOperator : public BaseOperator {
|
||||
public:
|
||||
using BaseOperator::BaseOperator;
|
||||
|
||||
SimpleOperator(::tflite::BuiltinOperator op, OperatorType type)
|
||||
: BaseOperator(::tflite::EnumNameBuiltinOperator(op), type),
|
||||
builtin_op_(op) {}
|
||||
|
||||
Options Serialize(const Operator& op,
|
||||
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||
return Options();
|
||||
@ -48,14 +43,8 @@ class SimpleOperator : public BaseOperator {
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return ::tflite::GetBuiltinOperatorVersion(
|
||||
GetVersioningOpSig(builtin_op_, op_signature));
|
||||
return 1;
|
||||
}
|
||||
|
||||
::tflite::BuiltinOperator builtin_op() const { return builtin_op_; }
|
||||
|
||||
private:
|
||||
const ::tflite::BuiltinOperator builtin_op_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -1,33 +0,0 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "op_version",
|
||||
srcs = ["op_version.cc"],
|
||||
hdrs = [
|
||||
"op_version.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "op_version_test",
|
||||
srcs = ["op_version_test.cc"],
|
||||
deps = [
|
||||
":op_version",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
@ -1,283 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
switch (op_sig.op) {
|
||||
case BuiltinOperator_CONV_2D:
|
||||
// If the op has signed int8 op_sig.inputs and op_sig.outputs, its
|
||||
// version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_INT8) {
|
||||
return 3;
|
||||
}
|
||||
// If the op is a signed int8 hybrid operation, we need to return
|
||||
// version 2.
|
||||
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_FLOAT32) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_DEPTHWISE_CONV_2D:
|
||||
// If the op has signed int8 op_sig.inputs and op_sig.outputs, its
|
||||
// version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_INT8) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.options.depthwise_conv_2d.dilation_w_factor != 1 ||
|
||||
op_sig.options.depthwise_conv_2d.dilation_h_factor != 1) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_FAKE_QUANT:
|
||||
if (op_sig.options.fakequant.narrow_range) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_FULLY_CONNECTED:
|
||||
// +-----------------+--------------------+--------------------------+
|
||||
// | | Weight::Default | Weight::Shuffled4x16Int8 |
|
||||
// +-----------------+--------------------+--------------------------+
|
||||
// | Float | 1 | 2 |
|
||||
// | Quantized Uint8 | 1 | 2 |
|
||||
// | Hybrid | 3 | 3 |
|
||||
// | Quantized Int8 | 4 | 4 |
|
||||
// +-----------------+--------------------+--------------------------+
|
||||
// 2 op_sig.inputs (no bias) use case is supported starting from
|
||||
// version 6.
|
||||
if (op_sig.input_types.size() == 2) {
|
||||
return 6;
|
||||
}
|
||||
// `keep_num_dims` is supported at verison 5.
|
||||
if (op_sig.options.fully_connected.keep_num_dims) {
|
||||
return 5;
|
||||
}
|
||||
// Int8 fully fixed point kernel is at version 4.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_INT8) {
|
||||
return 4;
|
||||
}
|
||||
// If the op is a signed int8 hybrid operation, we need to return
|
||||
// version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_FLOAT32) {
|
||||
return 3;
|
||||
}
|
||||
// For float and uint8 fixed point kernels, if the weight is
|
||||
// Shuffled4x16Int8, is is version 2.
|
||||
if (op_sig.options.fully_connected.weights_format ==
|
||||
FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
|
||||
return 2;
|
||||
}
|
||||
// Otherwise (weight is default), the version is 1.
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_GATHER:
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_SVDF:
|
||||
// If the op is a signed int8 hybrid operation, we need to return
|
||||
// version 2.
|
||||
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
|
||||
op_sig.input_types.at(1) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_FLOAT32) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_MUL:
|
||||
// Version 3 supports have a rescale value greater than or equal to 1.
|
||||
if (op_sig.options.mul.input1_scale != 0 &&
|
||||
op_sig.options.mul.input2_scale != 0 &&
|
||||
op_sig.options.mul.output_scale != 0 &&
|
||||
(op_sig.options.mul.input1_scale * op_sig.options.mul.input2_scale /
|
||||
op_sig.options.mul.output_scale) >= 1.0) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_TRANSPOSE:
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_LSTM:
|
||||
// If the input tensor is float and a weight is int8, this is a version
|
||||
// 3 hybrid operation.
|
||||
if (op_sig.options.lstm.kernel_type == LSTMKernelType_FULL &&
|
||||
op_sig.input_types.at(0) == TensorType_FLOAT32 &&
|
||||
op_sig.input_types.at(2) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_FLOAT32) {
|
||||
return 3;
|
||||
}
|
||||
// KERNEL_BASIC was added in version 2.
|
||||
if (op_sig.options.lstm.kernel_type == LSTMKernelType_BASIC) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
|
||||
// If the input tensor is float and a weight is int8, this is a version
|
||||
// 2 hybrid operation.
|
||||
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
|
||||
op_sig.input_types.at(2) == TensorType_INT8 &&
|
||||
op_sig.output_types.at(0) == TensorType_FLOAT32) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_SPLIT:
|
||||
// If the op take int8 input, it is version 2, for int32 it's version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT32) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_SPARSE_TO_DENSE:
|
||||
// Version 3 supports Int8 and Uint8 type.
|
||||
if (op_sig.input_types.at(2) == TensorType_INT8 ||
|
||||
op_sig.input_types.at(2) == TensorType_UINT8) {
|
||||
return 3;
|
||||
}
|
||||
// Version 2 supports Int64 value type.
|
||||
if (op_sig.input_types.at(2) == TensorType_INT64) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_SLICE:
|
||||
// Version 3 supports string input types.
|
||||
if (op_sig.input_types.at(0) == TensorType_STRING) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_UNPACK:
|
||||
// If the op take int8/uint8 input, it is version 2.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8 ||
|
||||
op_sig.input_types.at(0) == TensorType_UINT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_DEQUANTIZE:
|
||||
// Version 3 supports signed int16 input types.
|
||||
if (op_sig.input_types.at(0) == TensorType_INT16 ||
|
||||
op_sig.input_types.at(0) == TensorType_FLOAT16) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_FLOOR_DIV:
|
||||
if (op_sig.input_types.at(0) == TensorType_FLOAT32) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_L2_NORMALIZATION:
|
||||
if (op_sig.output_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_AVERAGE_POOL_2D:
|
||||
case BuiltinOperator_ADD:
|
||||
case BuiltinOperator_SPACE_TO_BATCH_ND:
|
||||
case BuiltinOperator_SUB:
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||
case BuiltinOperator_CONCATENATION:
|
||||
case BuiltinOperator_MAX_POOL_2D:
|
||||
case BuiltinOperator_MAXIMUM:
|
||||
case BuiltinOperator_MINIMUM:
|
||||
case BuiltinOperator_PAD:
|
||||
case BuiltinOperator_PADV2:
|
||||
case BuiltinOperator_SOFTMAX:
|
||||
case BuiltinOperator_SPACE_TO_DEPTH:
|
||||
case BuiltinOperator_MEAN:
|
||||
case BuiltinOperator_SUM:
|
||||
case BuiltinOperator_REDUCE_MAX:
|
||||
case BuiltinOperator_REDUCE_MIN:
|
||||
case BuiltinOperator_RELU6:
|
||||
case BuiltinOperator_RESIZE_BILINEAR:
|
||||
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
|
||||
case BuiltinOperator_PACK:
|
||||
case BuiltinOperator_TANH:
|
||||
case BuiltinOperator_LOGISTIC:
|
||||
case BuiltinOperator_LOG_SOFTMAX:
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
case BuiltinOperator_TOPK_V2:
|
||||
case BuiltinOperator_ARG_MAX:
|
||||
case BuiltinOperator_ARG_MIN:
|
||||
case BuiltinOperator_EQUAL:
|
||||
case BuiltinOperator_NOT_EQUAL:
|
||||
case BuiltinOperator_GREATER:
|
||||
case BuiltinOperator_GREATER_EQUAL:
|
||||
case BuiltinOperator_LESS:
|
||||
case BuiltinOperator_LESS_EQUAL:
|
||||
case BuiltinOperator_SELECT:
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tflite
|
@ -1,57 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// OpSignature contains operator parameters for version functions.
|
||||
typedef struct {
|
||||
BuiltinOperator op;
|
||||
std::vector<TensorType> input_types;
|
||||
std::vector<TensorType> output_types;
|
||||
union {
|
||||
struct {
|
||||
int32_t dilation_w_factor;
|
||||
int32_t dilation_h_factor;
|
||||
} depthwise_conv_2d;
|
||||
struct {
|
||||
bool narrow_range;
|
||||
} fakequant;
|
||||
struct {
|
||||
bool keep_num_dims;
|
||||
FullyConnectedOptionsWeightsFormat weights_format;
|
||||
} fully_connected;
|
||||
struct {
|
||||
float input1_scale;
|
||||
float input2_scale;
|
||||
float output_scale;
|
||||
} mul;
|
||||
struct {
|
||||
LSTMKernelType kernel_type;
|
||||
} lstm;
|
||||
} options;
|
||||
} OpSignature;
|
||||
|
||||
// Returns version of builtin ops by the given signature.
|
||||
int GetBuiltinOperatorVersion(const OpSignature& op_sig);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
|
@ -1,334 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
|
||||
TEST(OpVersionTest, VersioningSpareToDense) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_SPARSE_TO_DENSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8,
|
||||
TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_SPARSE_TO_DENSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8,
|
||||
TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_SPARSE_TO_DENSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT64, TensorType_INT64,
|
||||
TensorType_INT64},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_SPARSE_TO_DENSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT32, TensorType_INT32,
|
||||
TensorType_INT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
// Test version for a simple Op with 2 versions and the input type controls the
|
||||
// version.
|
||||
void SimpleVersioningTest(BuiltinOperator op) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = op,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = op,
|
||||
.input_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
// Test version for a simple Op with 2 versions and the output type controls the
|
||||
void SimpleOutputVersioningTest(BuiltinOperator op) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = op,
|
||||
.input_types = std::vector<TensorType>{},
|
||||
.output_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = op,
|
||||
.input_types = std::vector<TensorType>{},
|
||||
.output_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_EQUAL);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningNotEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningLessTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_LESS);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningLessEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_LESS_EQUAL);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningGreaterTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_GREATER);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningGreaterEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_GREATER_EQUAL);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningSpaceToBatchNDTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningLogSoftmaxTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_LOG_SOFTMAX);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningPackTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_PACK);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningUnpackTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_UNPACK,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_UNPACK,
|
||||
.input_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_UNPACK,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningBatchToSpaceNDTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_BATCH_TO_SPACE_ND);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningTanhTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_TANH);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningStridedSliceTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_STRIDED_SLICE);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningSpaceToDepthTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningSliceTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_SLICE,
|
||||
.input_types = std::vector<TensorType>{TensorType_STRING},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_SLICE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_SLICE,
|
||||
.input_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningLogisticTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningL2NormTest) {
|
||||
SimpleOutputVersioningTest(BuiltinOperator_L2_NORMALIZATION);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMaxTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_MAXIMUM);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMinTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_MINIMUM);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMeanTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_MEAN);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningSumTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_SUM);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningAddTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_ADD);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningSubTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_SUB);
|
||||
}
|
||||
|
||||
void SimpleMulVersioningTest(TensorType data_type, float multiplier,
|
||||
int version) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_MUL,
|
||||
.input_types = std::vector<TensorType>{data_type, data_type},
|
||||
.output_types = std::vector<TensorType>{data_type},
|
||||
};
|
||||
fake_op_sig.options.mul = {1.0f, 1.0f, 1.0f / multiplier};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), version);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningMulTest) {
|
||||
SimpleMulVersioningTest(TensorType_UINT8, 0.5f, 1);
|
||||
SimpleMulVersioningTest(TensorType_INT8, 0.5f, 2);
|
||||
SimpleMulVersioningTest(TensorType_INT8, 2.0f, 3);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningPadTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_PAD);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningPadV2Test) {
|
||||
SimpleVersioningTest(BuiltinOperator_PADV2);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningConcatenationTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_CONCATENATION);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningSelectTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_SELECT);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningRelu6Test) {
|
||||
SimpleVersioningTest(BuiltinOperator_RELU6);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningFullyConnectedTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_FULLY_CONNECTED,
|
||||
.input_types =
|
||||
std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
fake_op_sig.options.fully_connected = {
|
||||
false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_FULLY_CONNECTED,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
fake_op_sig.options.fully_connected = {
|
||||
false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningDequantizeTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_DEQUANTIZE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT16},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_DEQUANTIZE,
|
||||
.input_types = std::vector<TensorType>{TensorType_FLOAT16},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_DEQUANTIZE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_DEQUANTIZE,
|
||||
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningConv2DTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_CONV_2D,
|
||||
.input_types =
|
||||
std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_CONV_2D,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_CONV_2D,
|
||||
.input_types =
|
||||
std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_FLOAT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningFloorDivOperatorTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_FLOOR_DIV,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_FLOOR_DIV,
|
||||
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user