1274 lines
48 KiB
C++
1274 lines
48 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
#include "tensorflow/lite/toco/tflite/operator.h"
|
|
|
|
#include "flatbuffers/flexbuffers.h"
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
|
#include "tensorflow/core/framework/node_def.pb.h"
|
|
#include "tensorflow/lite/toco/model.h"
|
|
#include "tensorflow/lite/toco/tooling_util.h"
|
|
|
|
namespace toco {
|
|
|
|
namespace tflite {
|
|
namespace {
|
|
|
|
class OperatorTest : public ::testing::Test {
|
|
protected:
|
|
// Return the operator for the given name and type.
|
|
const BaseOperator& GetOperator(const std::string& name, OperatorType type) {
|
|
using OpsByName = std::map<std::string, std::unique_ptr<BaseOperator>>;
|
|
using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
|
|
|
|
static auto* by_name = new OpsByName(BuildOperatorByNameMap());
|
|
static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
|
|
|
|
// Make sure the two maps were consistently built.
|
|
CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
|
|
BaseOperator* op1 = by_name->at(name).get();
|
|
CHECK(op1->type() == type) << "while verifying '" << name << "'.";
|
|
|
|
CHECK(by_type->count(type))
|
|
<< "No operator for '" << OperatorTypeName(type) << "'.";
|
|
BaseOperator* op2 = by_type->at(type).get();
|
|
CHECK(op2->name() == name)
|
|
<< "while verifying '" << OperatorTypeName(type) << "'.";
|
|
|
|
return *op1;
|
|
}
|
|
|
|
// Use the given BaseOperator to serialize the tf.mini operator into a set of
|
|
// TF Lite options. Proceed to deserialize the options back into a new
|
|
// tf.mini operator, which is then returned. If `options` is given, it will
|
|
// be populated with the serialized options.
|
|
template <typename T>
|
|
std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
|
|
const T& toco_op,
|
|
Options* options = nullptr) {
|
|
flatbuffers::FlatBufferBuilder builder;
|
|
Options input_options = op.Serialize(toco_op, &builder);
|
|
|
|
if (options) {
|
|
*options = input_options;
|
|
}
|
|
|
|
builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
|
|
input_options.builtin, input_options.custom,
|
|
::tflite::CustomOptionsFormat_FLEXBUFFERS));
|
|
auto* output_options =
|
|
flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
|
|
auto new_toco_op = op.Deserialize(output_options->builtin_options(),
|
|
output_options->custom_options());
|
|
|
|
CHECK(new_toco_op->type == toco_op.type)
|
|
<< "The type of the serialized and deserialized"
|
|
<< HelpfulOperatorTypeName(*new_toco_op)
|
|
<< " does not match the type of the original "
|
|
<< HelpfulOperatorTypeName(toco_op);
|
|
|
|
return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
|
|
}
|
|
|
|
// Verify serialization and deserialization of simple operators (those
|
|
// that don't have any configuration parameters).
|
|
template <typename T>
|
|
void CheckSimpleOperator(const std::string& name, OperatorType type) {
|
|
Options options;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator(name, type), T(), &options);
|
|
|
|
ASSERT_EQ(0, options.builtin.o);
|
|
ASSERT_EQ(0, options.custom.o);
|
|
ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
|
|
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
}
|
|
|
|
template <typename T>
|
|
void CheckReducerOperator(const std::string& name, OperatorType type) {
|
|
T op;
|
|
|
|
op.keep_dims = false;
|
|
|
|
auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
|
|
EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
|
|
}
|
|
};
|
|
|
|
TEST_F(OperatorTest, SimpleOperators) {
|
|
CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
|
|
CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil);
|
|
CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu);
|
|
CheckSimpleOperator<RoundOperator>("ROUND", OperatorType::kRound);
|
|
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
|
|
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
|
|
CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
|
|
CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
|
|
CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
|
|
CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
|
|
CheckSimpleOperator<CosOperator>("COS", OperatorType::kCos);
|
|
CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX",
|
|
OperatorType::kLogSoftmax);
|
|
CheckSimpleOperator<TensorFlowMaximumOperator>(
|
|
"MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum
|
|
CheckSimpleOperator<TensorFlowMinimumOperator>(
|
|
"MINIMUM", OperatorType::kMinimum); // Element-wise Minimum
|
|
CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess);
|
|
CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
|
|
CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
|
|
CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
|
|
CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
|
|
CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual);
|
|
CheckSimpleOperator<TensorFlowNotEqualOperator>("NOT_EQUAL",
|
|
OperatorType::kNotEqual);
|
|
CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
|
|
CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
|
|
CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
|
|
CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
|
|
CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR",
|
|
OperatorType::kLogicalOr);
|
|
CheckSimpleOperator<LogicalAndOperator>("LOGICAL_AND",
|
|
OperatorType::kLogicalAnd);
|
|
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
|
|
OperatorType::kLogicalNot);
|
|
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
|
|
CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
|
|
OperatorType::kSquare);
|
|
CheckSimpleOperator<TensorFlowZerosLikeOperator>("ZEROS_LIKE",
|
|
OperatorType::kZerosLike);
|
|
CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);
|
|
CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
|
|
CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
|
|
CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
|
|
OperatorType::kReverseV2);
|
|
CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinAdd) {
|
|
AddOperator op;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinAddN) {
|
|
AddNOperator op;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("ADD_N", OperatorType::kAddN), op);
|
|
ASSERT_NE(output_toco_op.get(), nullptr);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinReducerOps) {
|
|
CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
|
|
CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
|
|
CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD",
|
|
OperatorType::kReduceProd);
|
|
CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX",
|
|
OperatorType::kReduceMax);
|
|
CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN",
|
|
OperatorType::kReduceMin);
|
|
CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinCast) {
|
|
CastOperator op;
|
|
op.src_data_type = ArrayDataType::kFloat;
|
|
op.dst_data_type = ArrayDataType::kUint8;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op);
|
|
EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type);
|
|
EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomConcatenation) {
|
|
ConcatenationOperator op;
|
|
op.axis = 123;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("CONCATENATION", OperatorType::kConcatenation), op);
|
|
EXPECT_EQ(op.axis, output_toco_op->axis);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomDepthToSpace) {
|
|
DepthToSpaceOperator op;
|
|
op.block_size = 123;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op);
|
|
EXPECT_EQ(op.block_size, output_toco_op->block_size);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomFakeQuant) {
|
|
FakeQuantOperator op;
|
|
auto* minmax = new MinMax;
|
|
minmax->min = -10;
|
|
minmax->max = 200;
|
|
op.minmax.reset(minmax);
|
|
op.num_bits = 16;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
|
|
EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
|
|
EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
|
|
EXPECT_EQ(op.num_bits, output_toco_op->num_bits);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomFullyConnected) {
|
|
FullyConnectedOperator op;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinGather) {
|
|
GatherOperator op;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op);
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinGatherNd) {
|
|
GatherNdOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("GATHER_ND", OperatorType::kGatherNd), op);
|
|
ASSERT_NE(output_toco_op.get(), nullptr);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinWhere) {
|
|
WhereOperator op;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op);
|
|
ASSERT_NE(output_toco_op.get(), nullptr);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinL2Pool) {
|
|
L2PoolOperator op;
|
|
op.stride_width = 123;
|
|
op.stride_height = 124;
|
|
op.padding.type = PaddingType::kValid;
|
|
op.kwidth = 480;
|
|
op.kheight = 1080;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op);
|
|
EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
|
|
EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
|
|
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
|
|
EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
|
|
EXPECT_EQ(op.kheight, output_toco_op->kheight);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
|
|
LocalResponseNormalizationOperator op;
|
|
op.range = 123;
|
|
op.bias = 1.23;
|
|
op.alpha = 12.3;
|
|
op.beta = .123;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("LOCAL_RESPONSE_NORMALIZATION",
|
|
OperatorType::kLocalResponseNormalization),
|
|
op);
|
|
EXPECT_EQ(op.range, output_toco_op->range);
|
|
EXPECT_EQ(op.bias, output_toco_op->bias);
|
|
EXPECT_EQ(op.alpha, output_toco_op->alpha);
|
|
EXPECT_EQ(op.beta, output_toco_op->beta);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinMaxPool) {
|
|
MaxPoolOperator op;
|
|
op.stride_width = 123;
|
|
op.stride_height = 124;
|
|
op.padding.type = PaddingType::kValid;
|
|
op.kwidth = 480;
|
|
op.kheight = 1080;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op);
|
|
EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
|
|
EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
|
|
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
|
|
EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
|
|
EXPECT_EQ(op.kheight, output_toco_op->kheight);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinReshape) {
|
|
TensorFlowReshapeOperator op;
|
|
op.shape = {1, 2, 4, 5, 8};
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("RESHAPE", OperatorType::kReshape), op);
|
|
EXPECT_EQ(op.shape, output_toco_op->shape);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomSoftmax) {
|
|
SoftmaxOperator op;
|
|
op.beta = 123.1;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SOFTMAX", OperatorType::kSoftmax), op);
|
|
EXPECT_EQ(op.beta, output_toco_op->beta);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinSpaceToDepth) {
|
|
SpaceToDepthOperator op;
|
|
op.block_size = 123;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op);
|
|
EXPECT_EQ(op.block_size, output_toco_op->block_size);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomSplit) {
|
|
TensorFlowSplitOperator op;
|
|
op.num_split = 123;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op);
|
|
EXPECT_EQ(op.num_split, output_toco_op->num_split);
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomSplitV) {
|
|
TensorFlowSplitVOperator op;
|
|
op.num_split = 123;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SPLIT_V", OperatorType::kSplitV), op);
|
|
EXPECT_EQ(op.num_split, output_toco_op->num_split);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinAveragePool) {
|
|
AveragePoolOperator op;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
op.stride_width = 123;
|
|
op.stride_height = 124;
|
|
op.padding.type = PaddingType::kValid;
|
|
op.kwidth = 480;
|
|
op.kheight = 1080;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
|
|
EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
|
|
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
|
|
EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
|
|
EXPECT_EQ(op.kheight, output_toco_op->kheight);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinConvolution) {
|
|
ConvOperator op;
|
|
op.stride_width = 123;
|
|
op.stride_height = 124;
|
|
op.padding.type = PaddingType::kValid;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op);
|
|
EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
|
|
EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
|
|
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
|
|
DepthwiseConvOperator op;
|
|
op.stride_width = 123;
|
|
op.stride_height = 124;
|
|
op.padding.type = PaddingType::kValid;
|
|
op.depth_multiplier = 6;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op);
|
|
EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
|
|
EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
|
|
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
|
|
EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinL2Norm) {
|
|
L2NormalizationOperator op;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinMul) {
|
|
MulOperator op;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu6;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
}
|
|
|
|
TEST_F(OperatorTest, ResizeBilinear) {
|
|
ResizeBilinearOperator op;
|
|
op.align_corners = true;
|
|
op.half_pixel_centers = false;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
|
|
EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
|
|
EXPECT_EQ(op.half_pixel_centers, output_toco_op->half_pixel_centers);
|
|
}
|
|
|
|
TEST_F(OperatorTest, ResizeBilinear_HalfPixelCenters) {
|
|
ResizeBilinearOperator op;
|
|
op.align_corners = true;
|
|
op.half_pixel_centers = true;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
|
|
EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
|
|
EXPECT_EQ(op.half_pixel_centers, output_toco_op->half_pixel_centers);
|
|
}
|
|
|
|
TEST_F(OperatorTest, ResizeNearestNeighbor) {
|
|
ResizeNearestNeighborOperator op;
|
|
op.align_corners = true;
|
|
op.half_pixel_centers = false;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("RESIZE_NEAREST_NEIGHBOR",
|
|
OperatorType::kResizeNearestNeighbor),
|
|
op);
|
|
EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
|
|
EXPECT_EQ(op.half_pixel_centers, output_toco_op->half_pixel_centers);
|
|
}
|
|
|
|
TEST_F(OperatorTest, ResizeNearestNeighbor_HalfPixelCenters) {
|
|
ResizeNearestNeighborOperator op;
|
|
op.align_corners = true;
|
|
op.half_pixel_centers = true;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("RESIZE_NEAREST_NEIGHBOR",
|
|
OperatorType::kResizeNearestNeighbor),
|
|
op);
|
|
EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
|
|
EXPECT_EQ(op.half_pixel_centers, output_toco_op->half_pixel_centers);
|
|
}
|
|
|
|
TEST_F(OperatorTest, Svdf) {
|
|
SvdfOperator op;
|
|
op.fused_activation_function = FusedActivationFunctionType::kRelu;
|
|
op.rank = 1;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op);
|
|
EXPECT_EQ(op.fused_activation_function,
|
|
output_toco_op->fused_activation_function);
|
|
EXPECT_EQ(op.rank, output_toco_op->rank);
|
|
}
|
|
|
|
TEST_F(OperatorTest, Squeeze) {
|
|
SqueezeOperator op;
|
|
op.squeeze_dims = {-2, -3, 4, 1, 4};
|
|
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SQUEEZE", OperatorType::kSqueeze), op);
|
|
EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
|
|
}
|
|
|
|
TEST_F(OperatorTest, StridedSlice) {
|
|
StridedSliceOperator op;
|
|
|
|
op.begin_mask = 1;
|
|
op.end_mask = 2;
|
|
op.ellipsis_mask = 1;
|
|
op.new_axis_mask = 1;
|
|
op.shrink_axis_mask = 2;
|
|
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
|
|
EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
|
|
EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
|
|
EXPECT_EQ(op.strides, output_toco_op->strides);
|
|
EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
|
|
EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
|
|
EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
|
|
EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
|
|
EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
|
|
EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinTopKV2) {
|
|
TopKV2Operator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("TOPK_V2", OperatorType::kTopK_V2), op);
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinArgMax) {
|
|
ArgMaxOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("ARG_MAX", OperatorType::kArgMax), op);
|
|
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinArgMin) {
|
|
ArgMinOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("ARG_MIN", OperatorType::kArgMin), op);
|
|
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinDequantize) {
|
|
DequantizeOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinTransposeConv) {
|
|
TransposeConvOperator op;
|
|
op.stride_width = 123;
|
|
op.stride_height = 124;
|
|
op.padding.type = PaddingType::kValid;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv), op);
|
|
EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
|
|
EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
|
|
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinShape) {
|
|
TensorFlowShapeOperator op;
|
|
op.output_data_type = ArrayDataType::kInt64;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op);
|
|
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinSparseToDense) {
|
|
SparseToDenseOperator op;
|
|
op.validate_indices = false;
|
|
std::unique_ptr<toco::SparseToDenseOperator> output_toco_op =
|
|
SerializeAndDeserialize(
|
|
GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op);
|
|
EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSpareToDense) {
|
|
SparseToDenseOperator op;
|
|
op.inputs = {"indices", "output_shape", "input_values", "default_value"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model int32_model;
|
|
Array& int32_array = int32_model.GetOrCreateArray(op.inputs[2]);
|
|
int32_array.data_type = ArrayDataType::kInt32;
|
|
OperatorSignature int32_signature = {.op = &op, .model = &int32_model};
|
|
EXPECT_EQ(base_op->GetVersion(int32_signature), 1);
|
|
|
|
// Expect version 2 for int64 input.
|
|
Model int64_model;
|
|
Array& int64_array = int64_model.GetOrCreateArray(op.inputs[2]);
|
|
int64_array.data_type = ArrayDataType::kInt64;
|
|
OperatorSignature int64_signature = {.op = &op, .model = &int64_model};
|
|
EXPECT_EQ(base_op->GetVersion(int64_signature), 2);
|
|
|
|
// Expect version 3 for int8 and uint8 input.
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[2]);
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 3);
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[2]);
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 3);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinPack) {
|
|
PackOperator op;
|
|
op.values_count = 3;
|
|
op.axis = 1;
|
|
std::unique_ptr<toco::PackOperator> output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("PACK", OperatorType::kPack), op);
|
|
EXPECT_EQ(op.values_count, output_toco_op->values_count);
|
|
EXPECT_EQ(op.axis, output_toco_op->axis);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinOneHot) {
|
|
OneHotOperator op;
|
|
op.axis = 2;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("ONE_HOT", OperatorType::kOneHot), op);
|
|
EXPECT_EQ(op.axis, output_toco_op->axis);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinUnpack) {
|
|
UnpackOperator op;
|
|
op.num = 5;
|
|
op.axis = 2;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op);
|
|
EXPECT_EQ(op.num, output_toco_op->num);
|
|
EXPECT_EQ(op.axis, output_toco_op->axis);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinLeakyRelu) {
|
|
LeakyReluOperator op;
|
|
op.alpha = 3;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu), op);
|
|
EXPECT_EQ(op.alpha, output_toco_op->alpha);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinSquaredDifference) {
|
|
SquaredDifferenceOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SQUARED_DIFFERENCE", OperatorType::kSquaredDifference), op);
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinScatterNd) {
|
|
ScatterNdOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SCATTER_ND", OperatorType::kScatterNd), op);
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
}
|
|
|
|
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
|
|
CTCBeamSearchDecoderOperator op;
|
|
op.beam_width = 3;
|
|
op.top_paths = 2;
|
|
op.merge_repeated = false;
|
|
std::unique_ptr<toco::CTCBeamSearchDecoderOperator> output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER",
|
|
OperatorType::kCTCBeamSearchDecoder),
|
|
op);
|
|
EXPECT_EQ(op.beam_width, output_toco_op->beam_width);
|
|
EXPECT_EQ(op.top_paths, output_toco_op->top_paths);
|
|
EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated);
|
|
}
|
|
|
|
TEST_F(OperatorTest, TensorFlowUnsupported) {
|
|
TensorFlowUnsupportedOperator op;
|
|
op.tensorflow_op = "MyCustomUnsupportedOp";
|
|
|
|
::tensorflow::NodeDef node_def;
|
|
auto attr = node_def.mutable_attr();
|
|
(*attr)["float_attr"].set_f(2.0);
|
|
(*attr)["str_attr"].set_s("Hello World");
|
|
(*attr)["int_attr"].set_i(17);
|
|
(*attr)["bool_attr"].set_b(true);
|
|
{
|
|
auto* list = (*attr)["list_string_attr"].mutable_list();
|
|
list->add_s("abcde");
|
|
list->add_s("1234");
|
|
list->add_s("");
|
|
list->add_s("zyxwv");
|
|
list->add_s("!-.");
|
|
}
|
|
{
|
|
auto* list = (*attr)["list_float_attr"].mutable_list();
|
|
list->add_f(std::numeric_limits<float>::min());
|
|
list->add_f(2.0);
|
|
list->add_f(-std::numeric_limits<float>::max());
|
|
}
|
|
{
|
|
auto* list = (*attr)["list_int_attr"].mutable_list();
|
|
list->add_i(1);
|
|
list->add_i(20);
|
|
list->add_i(1LL << 40);
|
|
list->add_i(-(1LL << 40));
|
|
}
|
|
node_def.SerializeToString(&op.tensorflow_node_def);
|
|
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
|
|
|
|
::tensorflow::NodeDef output_node_def;
|
|
output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
|
|
const auto& output_attr = output_node_def.attr();
|
|
EXPECT_EQ(2.0, output_attr.at("float_attr").f());
|
|
EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
|
|
EXPECT_EQ(17, output_attr.at("int_attr").i());
|
|
EXPECT_EQ(true, output_attr.at("bool_attr").b());
|
|
{
|
|
const auto& list = output_attr.at("list_string_attr").list();
|
|
ASSERT_EQ(5, list.s_size());
|
|
EXPECT_EQ("abcde", list.s(0));
|
|
EXPECT_EQ("1234", list.s(1));
|
|
EXPECT_EQ("", list.s(2));
|
|
EXPECT_EQ("zyxwv", list.s(3));
|
|
EXPECT_EQ("!-.", list.s(4));
|
|
}
|
|
{
|
|
const auto& list = output_attr.at("list_float_attr").list();
|
|
ASSERT_EQ(3, list.f_size());
|
|
EXPECT_EQ(std::numeric_limits<float>::min(), list.f(0));
|
|
EXPECT_EQ(2.0, list.f(1));
|
|
EXPECT_EQ(-std::numeric_limits<float>::max(), list.f(2));
|
|
}
|
|
{
|
|
const auto& list = output_attr.at("list_int_attr").list();
|
|
ASSERT_EQ(4, list.i_size());
|
|
EXPECT_EQ(1, list.i(0));
|
|
EXPECT_EQ(20, list.i(1));
|
|
EXPECT_EQ(1LL << 40, list.i(2));
|
|
EXPECT_EQ(-(1LL << 40), list.i(3));
|
|
}
|
|
}
|
|
|
|
TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
|
|
TensorFlowUnsupportedOperator op;
|
|
op.tensorflow_op = "MyCustomUnsupportedOp";
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
|
|
|
|
::tensorflow::NodeDef output_node_def;
|
|
output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
|
|
EXPECT_TRUE(output_node_def.attr().empty());
|
|
}
|
|
|
|
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
|
|
EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));
|
|
EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
|
|
EXPECT_TRUE(ShouldExportAsFlexOp(true, "EluGrad"));
|
|
EXPECT_TRUE(ShouldExportAsFlexOp(true, "RFFT"));
|
|
EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));
|
|
// While the RandomShuffle op is available on desktop, it is not in the kernel
|
|
// set available on mobile and should be excluded.
|
|
EXPECT_FALSE(ShouldExportAsFlexOp(true, "RandomShuffle"));
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinMirrorPad) {
|
|
MirrorPadOperator op;
|
|
op.mode = MirrorPadMode::kReflect;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("MIRROR_PAD", OperatorType::kMirrorPad), op);
|
|
EXPECT_EQ(op.mode, output_toco_op->mode);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinUnique) {
|
|
UniqueOperator op;
|
|
op.idx_out_type = ArrayDataType::kInt64;
|
|
auto output_toco_op =
|
|
SerializeAndDeserialize(GetOperator("UNIQUE", OperatorType::kUnique), op);
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinSegmentSum) {
|
|
SegmentSumOperator op;
|
|
auto output_toco_op = SerializeAndDeserialize(
|
|
GetOperator("SEGMENT_SUM", OperatorType::kSegmentSum), op);
|
|
ASSERT_NE(nullptr, output_toco_op.get());
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinReverseSequence) {
|
|
ReverseSequenceOperator op;
|
|
op.seq_dim = 3;
|
|
op.batch_dim = 1;
|
|
std::unique_ptr<toco::ReverseSequenceOperator> output_toco_op =
|
|
SerializeAndDeserialize(
|
|
GetOperator("REVERSE_SEQUENCE", OperatorType::kReverseSequence), op);
|
|
EXPECT_EQ(op.seq_dim, output_toco_op->seq_dim);
|
|
EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinMatrixDiag) {
|
|
MatrixDiagOperator op;
|
|
std::unique_ptr<toco::MatrixDiagOperator> output_toco_op =
|
|
SerializeAndDeserialize(
|
|
GetOperator("MATRIX_DIAG", OperatorType::kMatrixDiag), op);
|
|
}
|
|
|
|
TEST_F(OperatorTest, BuiltinMatrixSetDiag) {
|
|
MatrixSetDiagOperator op;
|
|
std::unique_ptr<toco::MatrixSetDiagOperator> output_toco_op =
|
|
SerializeAndDeserialize(
|
|
GetOperator("MATRIX_SET_DIAG", OperatorType::kMatrixSetDiag), op);
|
|
}
|
|
|
|
// Test version for a simple Op with 2 versions and the input type controls the
|
|
// version.
|
|
template <typename Op>
|
|
void SimpleVersioningTest() {
|
|
Op op;
|
|
op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
|
|
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
|
}
|
|
|
|
// Test version for a simple Op with 2 versions and the output type controls the
|
|
// version.
|
|
template <typename Op>
|
|
void SimpleOutputVersioningTest() {
|
|
Op op;
|
|
op.outputs = {"output1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
|
|
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningEqualTest) {
|
|
SimpleVersioningTest<TensorFlowEqualOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningNotEqualTest) {
|
|
SimpleVersioningTest<TensorFlowNotEqualOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningLessTest) {
|
|
SimpleVersioningTest<TensorFlowLessOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningLessEqualTest) {
|
|
SimpleVersioningTest<TensorFlowLessEqualOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningGreaterTest) {
|
|
SimpleVersioningTest<TensorFlowGreaterOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningGreaterEqualTest) {
|
|
SimpleVersioningTest<TensorFlowGreaterEqualOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
|
|
SpaceToBatchNDOperator op;
|
|
op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
|
uint8_array.copy_shape({1, 2, 2, 2});
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
|
|
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
|
int8_array.copy_shape({1, 2, 2, 2});
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
|
|
|
Model float_model;
|
|
Array& float_array = float_model.GetOrCreateArray(op.inputs[0]);
|
|
float_array.copy_shape({1, 2, 2});
|
|
float_array.data_type = ArrayDataType::kFloat;
|
|
OperatorSignature float_signature = {.op = &op, .model = &float_model};
|
|
EXPECT_EQ(base_op->GetVersion(float_signature), 3);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
|
|
SimpleVersioningTest<LogSoftmaxOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningPackTest) {
|
|
SimpleVersioningTest<PackOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningUnpackTest) {
|
|
UnpackOperator op;
|
|
op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model int32_model;
|
|
Array& int32_array = int32_model.GetOrCreateArray(op.inputs[0]);
|
|
int32_array.data_type = ArrayDataType::kInt32;
|
|
OperatorSignature int32_signature = {.op = &op, .model = &int32_model};
|
|
EXPECT_EQ(base_op->GetVersion(int32_signature), 1);
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 2);
|
|
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
|
|
BatchToSpaceNDOperator op;
|
|
op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
uint8_array.copy_shape({1, 2, 2, 2});
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
|
|
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
int8_array.copy_shape({1, 2, 2, 2});
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
|
|
|
Model float_model;
|
|
Array& float_array = float_model.GetOrCreateArray(op.inputs[0]);
|
|
float_array.copy_shape({1, 2, 2});
|
|
float_array.data_type = ArrayDataType::kFloat;
|
|
OperatorSignature float_signature = {.op = &op, .model = &float_model};
|
|
EXPECT_EQ(base_op->GetVersion(float_signature), 3);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningTanhTest) {
|
|
SimpleVersioningTest<TanhOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningStridedSliceTest) {
|
|
StridedSliceOperator op;
|
|
op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
|
uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
|
|
|
|
Model int8_model;
|
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
|
int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
|
|
|
Model bool_model;
|
|
Array& bool_array = bool_model.GetOrCreateArray(op.inputs[0]);
|
|
bool_array.data_type = ArrayDataType::kBool;
|
|
OperatorSignature bool_signature = {.op = &op, .model = &bool_model};
|
|
EXPECT_EQ(base_op->GetVersion(bool_signature), 3);
|
|
|
|
op.start_indices = {0, 0, 0, 0, 0};
|
|
op.stop_indices = {1, 2, 2, 2, 2};
|
|
op.strides = {1, 1, 1, 1, 1};
|
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 4);
|
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 4);
|
|
EXPECT_EQ(base_op->GetVersion(bool_signature), 4);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
|
|
SimpleVersioningTest<SpaceToDepthOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSliceTest) {
|
|
SimpleVersioningTest<SliceOperator>();
|
|
|
|
// Check that a string input results in a version 3 op.
|
|
SliceOperator op;
|
|
op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model string_model;
|
|
Array& string_array = string_model.GetOrCreateArray(op.inputs[0]);
|
|
string_array.data_type = ArrayDataType::kString;
|
|
OperatorSignature string_signature = {.op = &op, .model = &string_model};
|
|
EXPECT_EQ(base_op->GetVersion(string_signature), 3);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningLogisticTest) {
|
|
SimpleVersioningTest<LogisticOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningL2NormTest) {
|
|
SimpleOutputVersioningTest<L2NormalizationOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningMaxTest) {
|
|
SimpleVersioningTest<TensorFlowMaximumOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningMinTest) {
|
|
SimpleVersioningTest<TensorFlowMinimumOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningMeanTest) {
|
|
SimpleVersioningTest<MeanOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSumTest) {
|
|
SimpleVersioningTest<TensorFlowSumOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
|
|
|
|
void SimpleMulVersioningTest(ArrayDataType data_type, float multiplier,
|
|
int version) {
|
|
MulOperator op;
|
|
op.inputs = {"input1", "input2"};
|
|
op.outputs = {"output"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model model;
|
|
Array& input0 = model.GetOrCreateArray(op.inputs[0]);
|
|
Array& input1 = model.GetOrCreateArray(op.inputs[1]);
|
|
Array& output = model.GetOrCreateArray(op.outputs[0]);
|
|
|
|
input0.data_type = data_type;
|
|
input0.GetOrCreateQuantizationParams().scale = 1.0f;
|
|
input1.data_type = data_type;
|
|
input1.GetOrCreateQuantizationParams().scale = 1.0f;
|
|
output.data_type = data_type;
|
|
output.GetOrCreateQuantizationParams().scale = 1.0f / multiplier;
|
|
|
|
OperatorSignature signature = {.op = &op, .model = &model};
|
|
EXPECT_EQ(base_op->GetVersion(signature), version);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningMulTest) {
|
|
SimpleMulVersioningTest(ArrayDataType::kUint8, 0.5f, 1);
|
|
SimpleMulVersioningTest(ArrayDataType::kInt8, 0.5f, 2);
|
|
SimpleMulVersioningTest(ArrayDataType::kInt8, 2.0f, 3);
|
|
}
|
|
|
|
template <typename OpType>
|
|
void SimpleTwoInputsVersioningTest(ArrayDataType data_type, Shape shape1,
|
|
Shape shape2, int version) {
|
|
OpType op;
|
|
op.inputs = {"input1", "input2"};
|
|
op.outputs = {"output"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
|
|
|
Model model;
|
|
Array& input0 = model.GetOrCreateArray(op.inputs[0]);
|
|
Array& input1 = model.GetOrCreateArray(op.inputs[1]);
|
|
Array& output = model.GetOrCreateArray(op.outputs[0]);
|
|
|
|
input0.data_type = data_type;
|
|
input0.copy_shape(shape1);
|
|
input1.data_type = data_type;
|
|
input1.copy_shape(shape2);
|
|
output.data_type = data_type;
|
|
|
|
OperatorSignature signature = {.op = &op, .model = &model};
|
|
EXPECT_EQ(base_op->GetVersion(signature), version);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSubTest) {
|
|
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kUint8,
|
|
{1, 2, 2, 2}, {1, 2, 2, 2}, 1);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kInt8, {1, 2, 2, 2},
|
|
{1, 2, 2, 2}, 2);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kUint8, {1, 2, 2},
|
|
{1, 2, 2}, 1);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kInt8, {1, 2, 2},
|
|
{1, 2, 2}, 2);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kUint8,
|
|
{1, 2, 2, 2}, {1, 2, 2, 1}, 1);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(ArrayDataType::kInt8, {1, 2, 2, 2},
|
|
{1, 2, 2, 1}, 2);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(
|
|
ArrayDataType::kUint8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3);
|
|
SimpleTwoInputsVersioningTest<SubOperator>(
|
|
ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 3);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningDivTest) {
|
|
SimpleTwoInputsVersioningTest<DivOperator>(ArrayDataType::kUint8,
|
|
{1, 2, 2, 2}, {1, 2, 2, 2}, 1);
|
|
SimpleTwoInputsVersioningTest<DivOperator>(ArrayDataType::kInt8, {1, 2, 2},
|
|
{1, 2, 2}, 1);
|
|
SimpleTwoInputsVersioningTest<DivOperator>(ArrayDataType::kUint8,
|
|
{1, 2, 2, 2}, {1, 2, 2, 1}, 1);
|
|
SimpleTwoInputsVersioningTest<DivOperator>(
|
|
ArrayDataType::kInt8, {1, 2, 2, 2, 2}, {1, 2, 2, 2, 1}, 2);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest<PadOperator>(); }
|
|
|
|
TEST_F(OperatorTest, VersioningPadV2Test) {
|
|
SimpleVersioningTest<PadV2Operator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningConcatenationTest) {
|
|
SimpleVersioningTest<ConcatenationOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningSelectTest) {
|
|
SimpleVersioningTest<SelectOperator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningRelu6Test) {
|
|
SimpleVersioningTest<Relu6Operator>();
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningFullyConnectedTest) {
|
|
FullyConnectedOperator fully_connected_op;
|
|
fully_connected_op.inputs = {"input", "weight"};
|
|
fully_connected_op.outputs = {"output"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* op =
|
|
operator_by_type_map.at(fully_connected_op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& input_uint8_array =
|
|
uint8_model.GetOrCreateArray(fully_connected_op.inputs[0]);
|
|
input_uint8_array.data_type = ArrayDataType::kUint8;
|
|
Array& weight_uint8_array =
|
|
uint8_model.GetOrCreateArray(fully_connected_op.inputs[1]);
|
|
weight_uint8_array.data_type = ArrayDataType::kUint8;
|
|
Array& output_uint8_array =
|
|
uint8_model.GetOrCreateArray(fully_connected_op.outputs[0]);
|
|
output_uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &fully_connected_op,
|
|
.model = &uint8_model};
|
|
EXPECT_EQ(op->GetVersion(uint8_signature), 6);
|
|
|
|
Model int8_model;
|
|
Array& input_int8_array =
|
|
int8_model.GetOrCreateArray(fully_connected_op.inputs[0]);
|
|
input_int8_array.data_type = ArrayDataType::kInt8;
|
|
Array& weight_int8_array =
|
|
int8_model.GetOrCreateArray(fully_connected_op.inputs[1]);
|
|
weight_int8_array.data_type = ArrayDataType::kInt8;
|
|
Array& output_int8_array =
|
|
int8_model.GetOrCreateArray(fully_connected_op.outputs[0]);
|
|
output_int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &fully_connected_op,
|
|
.model = &int8_model};
|
|
EXPECT_EQ(op->GetVersion(int8_signature), 6);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningDequantizeTest) {
|
|
DequantizeOperator dequant_op;
|
|
dequant_op.inputs = {"input"};
|
|
dequant_op.outputs = {"output"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* op = operator_by_type_map.at(dequant_op.type).get();
|
|
|
|
Model int16_model;
|
|
Array& input_int16_array = int16_model.GetOrCreateArray(dequant_op.inputs[0]);
|
|
input_int16_array.data_type = ArrayDataType::kInt16;
|
|
OperatorSignature int16_signature = {.op = &dequant_op,
|
|
.model = &int16_model};
|
|
EXPECT_EQ(op->GetVersion(int16_signature), 3);
|
|
|
|
Model float16_model;
|
|
Array& input_float16_array =
|
|
float16_model.GetOrCreateArray(dequant_op.inputs[0]);
|
|
input_float16_array.data_type = ArrayDataType::kFloat16;
|
|
OperatorSignature float16_signature = {.op = &dequant_op,
|
|
.model = &float16_model};
|
|
EXPECT_EQ(op->GetVersion(float16_signature), 3);
|
|
|
|
Model int8_model;
|
|
Array& input_int8_array = int8_model.GetOrCreateArray(dequant_op.inputs[0]);
|
|
input_int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &dequant_op, .model = &int8_model};
|
|
EXPECT_EQ(op->GetVersion(int8_signature), 2);
|
|
|
|
Model float_model;
|
|
Array& input_float_array = float_model.GetOrCreateArray(dequant_op.inputs[0]);
|
|
input_float_array.data_type = ArrayDataType::kFloat;
|
|
OperatorSignature float_signature = {.op = &dequant_op,
|
|
.model = &float_model};
|
|
EXPECT_EQ(op->GetVersion(float_signature), 1);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningConv2DTest) {
|
|
ConvOperator conv_op;
|
|
conv_op.inputs = {"input", "filter"};
|
|
conv_op.outputs = {"output"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* op = operator_by_type_map.at(conv_op.type).get();
|
|
|
|
Model uint8_model;
|
|
Array& input_uint8_array = uint8_model.GetOrCreateArray(conv_op.inputs[0]);
|
|
input_uint8_array.data_type = ArrayDataType::kUint8;
|
|
Array& filter_uint8_array = uint8_model.GetOrCreateArray(conv_op.inputs[1]);
|
|
filter_uint8_array.data_type = ArrayDataType::kUint8;
|
|
Array& output_uint8_array = uint8_model.GetOrCreateArray(conv_op.outputs[0]);
|
|
output_uint8_array.data_type = ArrayDataType::kUint8;
|
|
OperatorSignature uint8_signature = {.op = &conv_op, .model = &uint8_model};
|
|
EXPECT_EQ(op->GetVersion(uint8_signature), 1);
|
|
|
|
Model int8_model;
|
|
Array& input_int8_array = int8_model.GetOrCreateArray(conv_op.inputs[0]);
|
|
input_int8_array.data_type = ArrayDataType::kInt8;
|
|
Array& filter_int8_array = int8_model.GetOrCreateArray(conv_op.inputs[1]);
|
|
filter_int8_array.data_type = ArrayDataType::kInt8;
|
|
Array& output_int8_array = int8_model.GetOrCreateArray(conv_op.outputs[0]);
|
|
output_int8_array.data_type = ArrayDataType::kInt8;
|
|
OperatorSignature int8_signature = {.op = &conv_op, .model = &int8_model};
|
|
EXPECT_EQ(op->GetVersion(int8_signature), 3);
|
|
|
|
Model float_model;
|
|
Array& input_float_array = float_model.GetOrCreateArray(conv_op.inputs[0]);
|
|
input_float_array.data_type = ArrayDataType::kFloat;
|
|
Array& filter_int8_array1 = float_model.GetOrCreateArray(conv_op.inputs[1]);
|
|
filter_int8_array1.data_type = ArrayDataType::kInt8;
|
|
Array& output_float_array = float_model.GetOrCreateArray(conv_op.outputs[0]);
|
|
output_float_array.data_type = ArrayDataType::kFloat;
|
|
OperatorSignature float_signature = {.op = &conv_op, .model = &float_model};
|
|
EXPECT_EQ(op->GetVersion(float_signature), 2);
|
|
}
|
|
|
|
TEST_F(OperatorTest, VersioningFloorDivOperatorTest) {
|
|
FloorDivOperator floordiv_op;
|
|
floordiv_op.inputs = {"input1"};
|
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
|
const BaseOperator* op = operator_by_type_map.at(floordiv_op.type).get();
|
|
|
|
Model int32_model;
|
|
Array& input_int32_array =
|
|
int32_model.GetOrCreateArray(floordiv_op.inputs[0]);
|
|
input_int32_array.data_type = ArrayDataType::kInt32;
|
|
OperatorSignature int32_signature = {.op = &floordiv_op,
|
|
.model = &int32_model};
|
|
EXPECT_EQ(op->GetVersion(int32_signature), 1);
|
|
|
|
Model float_model;
|
|
Array& input_float_array =
|
|
float_model.GetOrCreateArray(floordiv_op.inputs[0]);
|
|
input_float_array.data_type = ArrayDataType::kFloat;
|
|
OperatorSignature float_signature = {.op = &floordiv_op,
|
|
.model = &float_model};
|
|
EXPECT_EQ(op->GetVersion(float_signature), 2);
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace tflite
|
|
|
|
} // namespace toco
|