Automated rollback of commit c8d3be804a

PiperOrigin-RevId: 269617801
Jared Duke 2019-09-17 11:42:05 -07:00 committed by TensorFlower Gardener
parent 1345f18835
commit 845f970ca0
9 changed files with 870 additions and 925 deletions


@@ -32,7 +32,6 @@ cc_library(
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/toco:graph_transformations",
"//tensorflow/lite/toco:model",
"//tensorflow/lite/tools/versioning:op_version",
"@com_google_absl//absl/memory",
"@flatbuffers",
],


@@ -36,8 +36,7 @@ class BuiltinOperator : public BaseOperator {
using TfLiteOptions = T2;
BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type)
: BaseOperator(::tflite::EnumNameBuiltinOperator(op), type),
builtin_op_(op) {}
: BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {}
// Build the configuration object in the given flatbuffer builder. Return
// its offset.
@@ -66,16 +65,6 @@ class BuiltinOperator : public BaseOperator {
}
return std::unique_ptr<Operator>(op.release());
}
int GetVersion(const OperatorSignature& op_signature) const override {
return ::tflite::GetBuiltinOperatorVersion(
GetVersioningOpSig(builtin_op_, op_signature));
}
::tflite::BuiltinOperator builtin_op() const { return builtin_op_; }
private:
const ::tflite::BuiltinOperator builtin_op_;
};
} // namespace tflite
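For reference, the GetVersion() override removed in this hunk is the core of the reverted change: BuiltinOperator remembered its ::tflite::BuiltinOperator value and delegated version selection to the shared GetBuiltinOperatorVersion() table that this commit also deletes. The snippet below is only a self-contained sketch of that delegation pattern, using simplified stand-in types rather than the real toco/TFLite classes:

// Sketch of the reverted "delegate to a shared version table" pattern.
// Every type and function here is a simplified stand-in for illustration.
#include <vector>

enum class TensorType { kFloat32, kInt8, kUInt8 };
enum class BuiltinOp { kConv2D, kFullyConnected };

struct OpSignature {  // stand-in for ::tflite::OpSignature
  BuiltinOp op;
  std::vector<TensorType> input_types;
  std::vector<TensorType> output_types;
};

// Stand-in for the shared ::tflite::GetBuiltinOperatorVersion() switch.
int GetBuiltinOperatorVersion(const OpSignature& sig) {
  if (sig.op == BuiltinOp::kConv2D && !sig.input_types.empty() &&
      sig.input_types[0] == TensorType::kInt8) {
    return 3;
  }
  return 1;
}

class BuiltinOperatorSketch {  // stand-in for toco's BuiltinOperator<...>
 public:
  explicit BuiltinOperatorSketch(BuiltinOp op) : builtin_op_(op) {}

  // The reverted override: stamp the stored enum into the signature and ask
  // the shared table instead of hard-coding a per-op rule here.
  int GetVersion(OpSignature sig) const {
    sig.op = builtin_op_;
    return GetBuiltinOperatorVersion(sig);
  }

 private:
  const BuiltinOp builtin_op_;
};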

File diff suppressed because it is too large.


@@ -19,7 +19,6 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
namespace toco {
@@ -94,9 +93,8 @@ class BaseOperator {
// * The first version for each op should be 1 (to be consistent with the
// default value in Flatbuffer). `return 1;` is okay for newly implemented
// ops.
// * When multiple versions are defined for an op, this function could be
// overridden. (See example in `operator_test.cc` and
// `tools/versioning/op_version.cc`)
// * When multiple versions are defined for an op, this function needs to be
// overridden. (See example in `operator_test.cc`)
virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
// Given a Toco `Operator`, return a list of booleans indicating the op
@@ -115,11 +113,6 @@ class BaseOperator {
OperatorType type_;
};
// Helper function to create ::tflite::OpSignature from the given
// ::tflite::BuiltinOperator and OperatorSignature.
::tflite::OpSignature GetVersioningOpSig(const ::tflite::BuiltinOperator op,
const OperatorSignature& op_signature);
// Helper function to determine if an unsupported TensorFlow op should be
// exported as a Flex op or a regular custom op.
bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
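After the rollback, the comment above is once again the contract: GetVersion() stays pure virtual on BaseOperator, every op with more than one version has to encode its rule in its own override, and a newly implemented op can simply return 1. A minimal sketch of that contract, with hypothetical stand-in types rather than the actual toco classes:

// Minimal sketch of the per-operator override contract; all names below are
// illustrative stand-ins, not the real toco API.
struct OperatorSignatureSketch {
  bool has_int8_input = false;  // hypothetical field used only for this sketch
};

class BaseOperatorSketch {
 public:
  virtual ~BaseOperatorSketch() = default;
  // Pure virtual, as in the real BaseOperator: every op must implement it.
  virtual int GetVersion(const OperatorSignatureSketch& sig) const = 0;
};

// A newly implemented op with a single version just returns 1.
class SingleVersionOpSketch : public BaseOperatorSketch {
 public:
  int GetVersion(const OperatorSignatureSketch&) const override { return 1; }
};

// An op with two versions spells out the rule locally in its override.
class TwoVersionOpSketch : public BaseOperatorSketch {
 public:
  int GetVersion(const OperatorSignatureSketch& sig) const override {
    return sig.has_int8_input ? 2 : 1;
  }
};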


@@ -32,11 +32,6 @@ template <typename T>
class SimpleOperator : public BaseOperator {
public:
using BaseOperator::BaseOperator;
SimpleOperator(::tflite::BuiltinOperator op, OperatorType type)
: BaseOperator(::tflite::EnumNameBuiltinOperator(op), type),
builtin_op_(op) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
return Options();
@@ -48,14 +43,8 @@ class SimpleOperator : public BaseOperator {
}
int GetVersion(const OperatorSignature& op_signature) const override {
return ::tflite::GetBuiltinOperatorVersion(
GetVersioningOpSig(builtin_op_, op_signature));
return 1;
}
::tflite::BuiltinOperator builtin_op() const { return builtin_op_; }
private:
const ::tflite::BuiltinOperator builtin_op_;
};
} // namespace tflite


@@ -1,33 +0,0 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "op_version",
srcs = ["op_version.cc"],
hdrs = [
"op_version.h",
],
deps = [
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
)
tf_cc_test(
name = "op_version_test",
srcs = ["op_version_test.cc"],
deps = [
":op_version",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest_main",
],
)


@@ -1,283 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/versioning/op_version.h"
#include <cstring>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
namespace tflite {
int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
switch (op_sig.op) {
case BuiltinOperator_CONV_2D:
// If the op has signed int8 inputs and outputs, it's version 3.
if (op_sig.input_types.at(0) == TensorType_INT8 &&
op_sig.input_types.at(1) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_INT8) {
return 3;
}
// If the op is a signed int8 hybrid operation, we need to return
// version 2.
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
op_sig.input_types.at(1) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_FLOAT32) {
return 2;
}
return 1;
case BuiltinOperator_DEPTHWISE_CONV_2D:
// If the op has signed int8 inputs and outputs, it's version 3.
if (op_sig.input_types.at(0) == TensorType_INT8 &&
op_sig.input_types.at(1) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_INT8) {
return 3;
}
if (op_sig.options.depthwise_conv_2d.dilation_w_factor != 1 ||
op_sig.options.depthwise_conv_2d.dilation_h_factor != 1) {
return 2;
}
return 1;
case BuiltinOperator_FAKE_QUANT:
if (op_sig.options.fakequant.narrow_range) {
return 2;
}
return 1;
case BuiltinOperator_FULLY_CONNECTED:
// +-----------------+--------------------+--------------------------+
// | | Weight::Default | Weight::Shuffled4x16Int8 |
// +-----------------+--------------------+--------------------------+
// | Float | 1 | 2 |
// | Quantized Uint8 | 1 | 2 |
// | Hybrid | 3 | 3 |
// | Quantized Int8 | 4 | 4 |
// +-----------------+--------------------+--------------------------+
// The 2-input (no bias) use case is supported starting from version 6.
if (op_sig.input_types.size() == 2) {
return 6;
}
// `keep_num_dims` is supported at version 5.
if (op_sig.options.fully_connected.keep_num_dims) {
return 5;
}
// Int8 fully fixed point kernel is at version 4.
if (op_sig.input_types.at(0) == TensorType_INT8 &&
op_sig.input_types.at(1) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_INT8) {
return 4;
}
// If the op is a signed int8 hybrid operation, we need to return
// version 3.
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
op_sig.input_types.at(1) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_FLOAT32) {
return 3;
}
// For float and uint8 fixed point kernels, if the weight is
// Shuffled4x16Int8, it is version 2.
if (op_sig.options.fully_connected.weights_format ==
FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
return 2;
}
// Otherwise (weight is default), the version is 1.
return 1;
case BuiltinOperator_GATHER:
// If the op takes bool input, it is version 3.
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_SVDF:
// If the op is a signed int8 hybrid operation, we need to return
// version 2.
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
op_sig.input_types.at(1) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_FLOAT32) {
return 2;
}
return 1;
case BuiltinOperator_MUL:
// Version 3 supports having a rescale value greater than or equal to 1.
if (op_sig.options.mul.input1_scale != 0 &&
op_sig.options.mul.input2_scale != 0 &&
op_sig.options.mul.output_scale != 0 &&
(op_sig.options.mul.input1_scale * op_sig.options.mul.input2_scale /
op_sig.options.mul.output_scale) >= 1.0) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_TRANSPOSE:
// If the op takes bool input, it is version 3.
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_LSTM:
// If the input tensor is float and a weight is int8, this is a version
// 3 hybrid operation.
if (op_sig.options.lstm.kernel_type == LSTMKernelType_FULL &&
op_sig.input_types.at(0) == TensorType_FLOAT32 &&
op_sig.input_types.at(2) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_FLOAT32) {
return 3;
}
// KERNEL_BASIC was added in version 2.
if (op_sig.options.lstm.kernel_type == LSTMKernelType_BASIC) {
return 2;
}
return 1;
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
// If the input tensor is float and a weight is int8, this is a version
// 2 hybrid operation.
if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
op_sig.input_types.at(2) == TensorType_INT8 &&
op_sig.output_types.at(0) == TensorType_FLOAT32) {
return 2;
}
return 1;
case BuiltinOperator_SPLIT:
// If the op takes int8 input, it is version 2; for int32, it is version 3.
if (op_sig.input_types.at(0) == TensorType_INT32) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_SPARSE_TO_DENSE:
// Version 3 supports Int8 and Uint8 types.
if (op_sig.input_types.at(2) == TensorType_INT8 ||
op_sig.input_types.at(2) == TensorType_UINT8) {
return 3;
}
// Version 2 supports Int64 value type.
if (op_sig.input_types.at(2) == TensorType_INT64) {
return 2;
}
return 1;
case BuiltinOperator_SLICE:
// Version 3 supports string input types.
if (op_sig.input_types.at(0) == TensorType_STRING) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_UNPACK:
// If the op takes int8/uint8 input, it is version 2.
if (op_sig.input_types.at(0) == TensorType_INT8 ||
op_sig.input_types.at(0) == TensorType_UINT8) {
return 2;
}
return 1;
case BuiltinOperator_DEQUANTIZE:
// Version 3 supports signed int16 and float16 input types.
if (op_sig.input_types.at(0) == TensorType_INT16 ||
op_sig.input_types.at(0) == TensorType_FLOAT16) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_FLOOR_DIV:
if (op_sig.input_types.at(0) == TensorType_FLOAT32) {
return 2;
}
return 1;
case BuiltinOperator_L2_NORMALIZATION:
if (op_sig.output_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_ADD:
case BuiltinOperator_SPACE_TO_BATCH_ND:
case BuiltinOperator_SUB:
case BuiltinOperator_BATCH_TO_SPACE_ND:
case BuiltinOperator_CONCATENATION:
case BuiltinOperator_MAX_POOL_2D:
case BuiltinOperator_MAXIMUM:
case BuiltinOperator_MINIMUM:
case BuiltinOperator_PAD:
case BuiltinOperator_PADV2:
case BuiltinOperator_SOFTMAX:
case BuiltinOperator_SPACE_TO_DEPTH:
case BuiltinOperator_MEAN:
case BuiltinOperator_SUM:
case BuiltinOperator_REDUCE_MAX:
case BuiltinOperator_REDUCE_MIN:
case BuiltinOperator_RELU6:
case BuiltinOperator_RESIZE_BILINEAR:
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
case BuiltinOperator_PACK:
case BuiltinOperator_TANH:
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_STRIDED_SLICE:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_ARG_MAX:
case BuiltinOperator_ARG_MIN:
case BuiltinOperator_EQUAL:
case BuiltinOperator_NOT_EQUAL:
case BuiltinOperator_GREATER:
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
case BuiltinOperator_SELECT:
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
default:
return 1;
}
}
} // namespace tflite


@@ -1,57 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
#define TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
#include <vector>
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
// OpSignature contains operator parameters for version functions.
typedef struct {
BuiltinOperator op;
std::vector<TensorType> input_types;
std::vector<TensorType> output_types;
union {
struct {
int32_t dilation_w_factor;
int32_t dilation_h_factor;
} depthwise_conv_2d;
struct {
bool narrow_range;
} fakequant;
struct {
bool keep_num_dims;
FullyConnectedOptionsWeightsFormat weights_format;
} fully_connected;
struct {
float input1_scale;
float input2_scale;
float output_scale;
} mul;
struct {
LSTMKernelType kernel_type;
} lstm;
} options;
} OpSignature;
// Returns the version of a builtin op for the given signature.
int GetBuiltinOperatorVersion(const OpSignature& op_sig);
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_VERSIONING_OP_VERSION_H_
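Taken together, the two deleted files above expose a small, side-effect-free API: fill in an OpSignature and pass it to GetBuiltinOperatorVersion() to obtain the minimum runtime version an op instance needs. A usage sketch against the deleted header (it only compiles while tensorflow/lite/tools/versioning/op_version.h is still in the tree, i.e. before this rollback):

// Usage sketch of the deleted API; mirrors the CONV_2D signed-int8 case,
// which the switch in op_version.cc above maps to version 3.
#include <vector>

#include "tensorflow/lite/tools/versioning/op_version.h"

int ExampleConv2DVersion() {
  tflite::OpSignature op_sig = {};
  op_sig.op = tflite::BuiltinOperator_CONV_2D;
  op_sig.input_types = std::vector<tflite::TensorType>{
      tflite::TensorType_INT8, tflite::TensorType_INT8};
  op_sig.output_types =
      std::vector<tflite::TensorType>{tflite::TensorType_INT8};
  return tflite::GetBuiltinOperatorVersion(op_sig);  // expected to return 3
}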


@@ -1,334 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/versioning/op_version.h"
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace tflite {
TEST(OpVersionTest, VersioningSparseToDense) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_SPARSE_TO_DENSE,
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8,
TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_SPARSE_TO_DENSE,
.input_types = std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8,
TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_SPARSE_TO_DENSE,
.input_types = std::vector<TensorType>{TensorType_INT64, TensorType_INT64,
TensorType_INT64},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_SPARSE_TO_DENSE,
.input_types = std::vector<TensorType>{TensorType_INT32, TensorType_INT32,
TensorType_INT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
// Test version for a simple Op with 2 versions and the input type controls the
// version.
void SimpleVersioningTest(BuiltinOperator op) {
OpSignature fake_op_sig = {
.op = op,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = op,
.input_types = std::vector<TensorType>{TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
// Test version for a simple Op with 2 versions and the output type controls the
// version.
void SimpleOutputVersioningTest(BuiltinOperator op) {
OpSignature fake_op_sig = {
.op = op,
.input_types = std::vector<TensorType>{},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = op,
.input_types = std::vector<TensorType>{},
.output_types = std::vector<TensorType>{TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
TEST(OpVersionTest, VersioningEqualTest) {
SimpleVersioningTest(BuiltinOperator_EQUAL);
}
TEST(OpVersionTest, VersioningNotEqualTest) {
SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
}
TEST(OpVersionTest, VersioningLessTest) {
SimpleVersioningTest(BuiltinOperator_LESS);
}
TEST(OpVersionTest, VersioningLessEqualTest) {
SimpleVersioningTest(BuiltinOperator_LESS_EQUAL);
}
TEST(OpVersionTest, VersioningGreaterTest) {
SimpleVersioningTest(BuiltinOperator_GREATER);
}
TEST(OpVersionTest, VersioningGreaterEqualTest) {
SimpleVersioningTest(BuiltinOperator_GREATER_EQUAL);
}
TEST(OpVersionTest, VersioningSpaceToBatchNDTest) {
SimpleVersioningTest(BuiltinOperator_SPACE_TO_BATCH_ND);
}
TEST(OpVersionTest, VersioningLogSoftmaxTest) {
SimpleVersioningTest(BuiltinOperator_LOG_SOFTMAX);
}
TEST(OpVersionTest, VersioningPackTest) {
SimpleVersioningTest(BuiltinOperator_PACK);
}
TEST(OpVersionTest, VersioningUnpackTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_UNPACK,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_UNPACK,
.input_types = std::vector<TensorType>{TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_UNPACK,
.input_types = std::vector<TensorType>{TensorType_INT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
TEST(OpVersionTest, VersioningBatchToSpaceNDTest) {
SimpleVersioningTest(BuiltinOperator_BATCH_TO_SPACE_ND);
}
TEST(OpVersionTest, VersioningTanhTest) {
SimpleVersioningTest(BuiltinOperator_TANH);
}
TEST(OpVersionTest, VersioningStridedSliceTest) {
SimpleVersioningTest(BuiltinOperator_STRIDED_SLICE);
}
TEST(OpVersionTest, VersioningSpaceToDepthTest) {
SimpleVersioningTest(BuiltinOperator_SPACE_TO_DEPTH);
}
TEST(OpVersionTest, VersioningSliceTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_STRING},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_SLICE,
.input_types = std::vector<TensorType>{TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
TEST(OpVersionTest, VersioningLogisticTest) {
SimpleVersioningTest(BuiltinOperator_LOGISTIC);
}
TEST(OpVersionTest, VersioningL2NormTest) {
SimpleOutputVersioningTest(BuiltinOperator_L2_NORMALIZATION);
}
TEST(OpVersionTest, VersioningMaxTest) {
SimpleVersioningTest(BuiltinOperator_MAXIMUM);
}
TEST(OpVersionTest, VersioningMinTest) {
SimpleVersioningTest(BuiltinOperator_MINIMUM);
}
TEST(OpVersionTest, VersioningMeanTest) {
SimpleVersioningTest(BuiltinOperator_MEAN);
}
TEST(OpVersionTest, VersioningSumTest) {
SimpleVersioningTest(BuiltinOperator_SUM);
}
TEST(OpVersionTest, VersioningAddTest) {
SimpleVersioningTest(BuiltinOperator_ADD);
}
TEST(OpVersionTest, VersioningSubTest) {
SimpleVersioningTest(BuiltinOperator_SUB);
}
void SimpleMulVersioningTest(TensorType data_type, float multiplier,
int version) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_MUL,
.input_types = std::vector<TensorType>{data_type, data_type},
.output_types = std::vector<TensorType>{data_type},
};
fake_op_sig.options.mul = {1.0f, 1.0f, 1.0f / multiplier};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), version);
}
TEST(OpVersionTest, VersioningMulTest) {
SimpleMulVersioningTest(TensorType_UINT8, 0.5f, 1);
SimpleMulVersioningTest(TensorType_INT8, 0.5f, 2);
SimpleMulVersioningTest(TensorType_INT8, 2.0f, 3);
}
TEST(OpVersionTest, VersioningPadTest) {
SimpleVersioningTest(BuiltinOperator_PAD);
}
TEST(OpVersionTest, VersioningPadV2Test) {
SimpleVersioningTest(BuiltinOperator_PADV2);
}
TEST(OpVersionTest, VersioningConcatenationTest) {
SimpleVersioningTest(BuiltinOperator_CONCATENATION);
}
TEST(OpVersionTest, VersioningSelectTest) {
SimpleVersioningTest(BuiltinOperator_SELECT);
}
TEST(OpVersionTest, VersioningRelu6Test) {
SimpleVersioningTest(BuiltinOperator_RELU6);
}
TEST(OpVersionTest, VersioningFullyConnectedTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_FULLY_CONNECTED,
.input_types =
std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8},
.output_types = std::vector<TensorType>{TensorType_UINT8},
};
fake_op_sig.options.fully_connected = {
false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
fake_op_sig = {
.op = BuiltinOperator_FULLY_CONNECTED,
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
fake_op_sig.options.fully_connected = {
false, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
}
TEST(OpVersionTest, VersioningDequantizeTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_DEQUANTIZE,
.input_types = std::vector<TensorType>{TensorType_INT16},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_DEQUANTIZE,
.input_types = std::vector<TensorType>{TensorType_FLOAT16},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_DEQUANTIZE,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_DEQUANTIZE,
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
TEST(OpVersionTest, VersioningConv2DTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_CONV_2D,
.input_types =
std::vector<TensorType>{TensorType_UINT8, TensorType_UINT8},
.output_types = std::vector<TensorType>{TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
fake_op_sig = {
.op = BuiltinOperator_CONV_2D,
.input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_CONV_2D,
.input_types =
std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_FLOAT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
}
TEST(OpVersionTest, VersioningFloorDivOperatorTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_FLOOR_DIV,
.input_types = std::vector<TensorType>{TensorType_INT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
fake_op_sig = {
.op = BuiltinOperator_FLOOR_DIV,
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
}
} // namespace tflite