Add versioning based on input/output types in toco.
To make sure we keep toco versioning as a single source of versioning information. Dequantize is the ops that support incremented version based on the int8 type, currently but the hybrid ops will soon follow. PiperOrigin-RevId: 225958674
This commit is contained in:
parent
5a02334ec9
commit
e4bdb31636
@ -376,7 +376,7 @@ struct Operator {
|
||||
// Output activation arrays. Same comments as for inputs apply here too.
|
||||
std::vector<string> outputs;
|
||||
|
||||
// If true, the array has more outputs than are listed in the 'outputs'
|
||||
// If true, the operator has more outputs than are listed in the 'outputs'
|
||||
// member. These need to be resolved by some graph transformation.
|
||||
// This flag is only here to indicate that an operator should not be
|
||||
// discarded as unused, even if from its 'outputs' member alone it
|
||||
@ -2208,6 +2208,16 @@ class Model {
|
||||
// addresses. See Operator::inputs, Operator::outputs.
|
||||
std::unordered_map<string, std::unique_ptr<Array>> arrays;
|
||||
};
|
||||
|
||||
// OperatorSignature contains the information required to making versioning
|
||||
// decisions.
|
||||
struct OperatorSignature {
|
||||
// The operator.
|
||||
const Operator* op;
|
||||
|
||||
// The model in which the operator resides.
|
||||
const Model* model;
|
||||
};
|
||||
} // namespace toco
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOCO_MODEL_H_
|
||||
|
||||
@ -106,16 +106,17 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
|
||||
namespace details {
|
||||
|
||||
OperatorKey::OperatorKey(
|
||||
const ::toco::Operator& op,
|
||||
const ::toco::OperatorSignature& op_signature,
|
||||
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
|
||||
bool enable_select_tf_ops) {
|
||||
// Get the op name (by Toco definition).
|
||||
const ::toco::Operator& op = *op_signature.op;
|
||||
string name = HelpfulOperatorTypeName(op);
|
||||
|
||||
bool is_builtin = false;
|
||||
const auto& builtin_ops = GetBuiltinOpsMap();
|
||||
if (ops_by_type.count(op.type) != 0) {
|
||||
version_ = ops_by_type.at(op.type)->GetVersion(op);
|
||||
version_ = ops_by_type.at(op.type)->GetVersion(op_signature);
|
||||
name = ops_by_type.at(op.type)->name();
|
||||
is_builtin = (builtin_ops.count(name) > 0);
|
||||
}
|
||||
@ -190,7 +191,8 @@ void LoadOperatorsMap(
|
||||
// First find a list of unique operator types.
|
||||
std::set<OperatorKey> keys;
|
||||
for (const auto& op : model.operators) {
|
||||
keys.insert(OperatorKey(*op, ops_by_type, enable_select_tf_ops));
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
keys.insert(OperatorKey(op_signature, ops_by_type, enable_select_tf_ops));
|
||||
}
|
||||
// Now assign indices to them and fill in the map.
|
||||
int index = 0;
|
||||
@ -301,8 +303,9 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
|
||||
std::map<int, Offset<OperatorCode>> ordered_opcodes;
|
||||
|
||||
for (const auto& op : model.operators) {
|
||||
const details::OperatorKey operator_key =
|
||||
details::OperatorKey(*op, ops_by_type, params.enable_select_tf_ops);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const details::OperatorKey operator_key = details::OperatorKey(
|
||||
op_signature, ops_by_type, params.enable_select_tf_ops);
|
||||
int op_index = operators_map.at(operator_key);
|
||||
|
||||
flatbuffers::Offset<flatbuffers::String> custom_code = 0;
|
||||
@ -349,9 +352,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
|
||||
for (const string& output : op->outputs) {
|
||||
outputs.push_back(tensors_map.at(output));
|
||||
}
|
||||
|
||||
const auto key =
|
||||
details::OperatorKey(*op, ops_by_type, params.enable_select_tf_ops);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type,
|
||||
params.enable_select_tf_ops);
|
||||
int op_index = operators_map.at(key);
|
||||
|
||||
auto tflite_op_it = ops_by_type.find(op->type);
|
||||
|
||||
@ -76,7 +76,7 @@ inline void Export(const Model& model, string* output_file_contents) {
|
||||
|
||||
namespace details {
|
||||
|
||||
// A maps from tensor name to its final position in the TF Lite buffer.
|
||||
// A map from tensor name to its final position in the TF Lite buffer.
|
||||
using TensorsMap = std::unordered_map<string, int>;
|
||||
|
||||
// A key to identify an operator.
|
||||
@ -88,7 +88,7 @@ class OperatorKey {
|
||||
|
||||
// Construct OperatorKey by Toco op.
|
||||
OperatorKey(
|
||||
const ::toco::Operator& op,
|
||||
const ::toco::OperatorSignature& op_signature,
|
||||
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
|
||||
bool enable_select_tf_ops);
|
||||
|
||||
@ -158,7 +158,7 @@ class OperatorKey {
|
||||
std::string flex_tensorflow_op_;
|
||||
};
|
||||
|
||||
// A maps from operator type to its final position in the TF Lite buffer.
|
||||
// A map from OperatorKey to its final position in the TF Lite buffer.
|
||||
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
|
||||
|
||||
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
|
||||
|
||||
@ -301,8 +301,9 @@ class FakeConvolutionOperator
|
||||
OperatorType::kConv) {}
|
||||
|
||||
// Returning the op version according to the op parameters.
|
||||
int GetVersion(const Operator& op) const override {
|
||||
const TocoOperator& conv_op = static_cast<const TocoOperator&>(op);
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const TocoOperator& conv_op =
|
||||
static_cast<const TocoOperator&>(*op_signature.op);
|
||||
if (conv_op.dilation_width_factor != 1 ||
|
||||
conv_op.dilation_height_factor != 1) {
|
||||
// Version 2 if dilation is used.
|
||||
@ -448,22 +449,58 @@ TEST_F(VersionedOpExportTest, Export) {
|
||||
}
|
||||
|
||||
TEST(OperatorKeyTest, TestBuiltinOp) {
|
||||
Model model;
|
||||
auto op = absl::make_unique<ConvOperator>();
|
||||
|
||||
// Test a normal float operation.
|
||||
op->inputs = {"input", "filter"};
|
||||
op->outputs = {"output"};
|
||||
Array& input_array = model.GetOrCreateArray(op->inputs[0]);
|
||||
Array& filter_array = model.GetOrCreateArray(op->inputs[1]);
|
||||
Array& output_array = model.GetOrCreateArray(op->outputs[0]);
|
||||
input_array.data_type = ArrayDataType::kFloat;
|
||||
filter_array.data_type = ArrayDataType::kFloat;
|
||||
output_array.data_type = ArrayDataType::kFloat;
|
||||
|
||||
const auto ops_by_type = BuildOperatorByTypeMap();
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, false);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
|
||||
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CONV_2D);
|
||||
EXPECT_EQ(key.custom_code(), "");
|
||||
EXPECT_EQ(key.version(), 1);
|
||||
}
|
||||
|
||||
TEST(OperatorKeyTest, TestBuiltinOpWithVersionedInputTypes) {
|
||||
Model model;
|
||||
auto op = absl::make_unique<DequantizeOperator>();
|
||||
|
||||
op->inputs = {"input"};
|
||||
op->outputs = {"output"};
|
||||
Array& input_array = model.GetOrCreateArray(op->inputs[0]);
|
||||
Array& output_array = model.GetOrCreateArray(op->outputs[0]);
|
||||
input_array.data_type = ArrayDataType::kInt8;
|
||||
output_array.data_type = ArrayDataType::kFloat;
|
||||
|
||||
const auto ops_by_type = BuildOperatorByTypeMap();
|
||||
|
||||
// Test a signed int8 dequantize operation.
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
|
||||
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_DEQUANTIZE);
|
||||
EXPECT_EQ(key.custom_code(), "");
|
||||
EXPECT_EQ(key.version(), 2);
|
||||
}
|
||||
|
||||
TEST(OperatorKeyTest, TestCustomOp) {
|
||||
Model model;
|
||||
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
|
||||
op->tensorflow_op = "MyCrazyCustomOp";
|
||||
|
||||
const auto ops_by_type = BuildOperatorByTypeMap();
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, false);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
|
||||
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "MyCrazyCustomOp");
|
||||
@ -471,12 +508,14 @@ TEST(OperatorKeyTest, TestCustomOp) {
|
||||
}
|
||||
|
||||
TEST(OperatorKeyTest, TestFlexOp) {
|
||||
Model model;
|
||||
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
|
||||
op->tensorflow_op = "BatchMatMul";
|
||||
|
||||
const auto ops_by_type = BuildOperatorByTypeMap();
|
||||
{
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, false);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
|
||||
// It shouldn't be converted to Flex op if `allow_flex_op` is false.
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "BatchMatMul");
|
||||
@ -488,7 +527,8 @@ TEST(OperatorKeyTest, TestFlexOp) {
|
||||
{
|
||||
// Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
|
||||
// is true.
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, true);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "FlexBatchMatMul");
|
||||
EXPECT_EQ(key.version(), 1);
|
||||
@ -498,11 +538,13 @@ TEST(OperatorKeyTest, TestFlexOp) {
|
||||
}
|
||||
|
||||
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
|
||||
Model model;
|
||||
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
|
||||
op->tensorflow_op = "Merge";
|
||||
|
||||
const auto ops_by_type = BuildOperatorByTypeMap();
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, true);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
|
||||
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "FlexMerge");
|
||||
@ -514,11 +556,13 @@ TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
|
||||
}
|
||||
|
||||
TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
|
||||
Model model;
|
||||
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
|
||||
op->tensorflow_op = "HashTableV2";
|
||||
|
||||
const auto ops_by_type = BuildOperatorByTypeMap();
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, true);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
|
||||
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "HashTableV2");
|
||||
@ -532,6 +576,7 @@ TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
|
||||
|
||||
TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
|
||||
// Test Toco-supported/TFLite-unsupported operators.
|
||||
Model model;
|
||||
// TODO(ycling): The test will be broken if TensorFlowAssert is implemented in
|
||||
// TFLite. Find a more robust way to test the fallback logic.
|
||||
auto op = absl::make_unique<TensorFlowAssertOperator>();
|
||||
@ -541,7 +586,8 @@ TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
|
||||
{
|
||||
// If NodeDef isn't retained in the Toco op, a regular custom op
|
||||
// will be exported.
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, true);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "Assert");
|
||||
EXPECT_EQ(key.version(), 1);
|
||||
@ -556,7 +602,8 @@ TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
|
||||
|
||||
{
|
||||
// If NodeDef is retained in the Toco op, a Flex op will be exported.
|
||||
const auto key = details::OperatorKey(*op, ops_by_type, true);
|
||||
const toco::OperatorSignature op_signature = {op.get(), &model};
|
||||
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "FlexAssert");
|
||||
EXPECT_EQ(key.version(), 1);
|
||||
|
||||
@ -14,19 +14,20 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/toco/tflite/operator.h"
|
||||
|
||||
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
||||
// graph_transformation module.
|
||||
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
||||
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/custom_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/simple_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/types.h"
|
||||
#include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
||||
// graph_transformation module.
|
||||
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/custom_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/simple_operator.h"
|
||||
#include "tensorflow/lite/toco/tflite/types.h"
|
||||
#include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
@ -60,7 +61,9 @@ class AveragePool
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Convolution
|
||||
@ -92,7 +95,9 @@ class Convolution
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class DepthwiseConvolution
|
||||
@ -126,8 +131,9 @@ class DepthwiseConvolution
|
||||
op->dilation_height_factor = options.dilation_h_factor();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override {
|
||||
const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const auto& conv_op =
|
||||
static_cast<const DepthwiseConvOperator&>(*op_signature.op);
|
||||
if (conv_op.dilation_width_factor != 1 ||
|
||||
conv_op.dilation_height_factor != 1) {
|
||||
return 2;
|
||||
@ -155,7 +161,9 @@ class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SpaceToBatchND
|
||||
@ -174,7 +182,9 @@ class SpaceToBatchND
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
|
||||
@ -196,7 +206,9 @@ class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
|
||||
@ -218,7 +230,9 @@ class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class BatchToSpaceND
|
||||
@ -237,7 +251,9 @@ class BatchToSpaceND
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
|
||||
@ -258,7 +274,9 @@ class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
|
||||
op->dst_data_type = DataType::Deserialize(options.out_data_type());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Concatenation
|
||||
@ -278,7 +296,9 @@ class Concatenation
|
||||
op->axis = options.axis();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
|
||||
@ -292,7 +312,9 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
|
||||
op->block_size = m["block_size"].AsInt64();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class FakeQuant
|
||||
@ -315,9 +337,8 @@ class FakeQuant
|
||||
op->num_bits = options.num_bits();
|
||||
op->narrow_range = options.narrow_range();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override {
|
||||
const auto& fq_op = static_cast<const FakeQuantOperator&>(op);
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
|
||||
return fq_op.narrow_range ? 2 : 1;
|
||||
}
|
||||
};
|
||||
@ -369,8 +390,9 @@ class FullyConnected
|
||||
}
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override {
|
||||
const auto& fc_op = static_cast<const FullyConnectedOperator&>(op);
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const auto& fc_op =
|
||||
static_cast<const FullyConnectedOperator&>(*op_signature.op);
|
||||
return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1
|
||||
: 2;
|
||||
}
|
||||
@ -392,7 +414,9 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
|
||||
op->axis = {options.axis()};
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
|
||||
@ -414,7 +438,9 @@ class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
|
||||
op->rank = options.rank();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class L2Normalization
|
||||
@ -436,7 +462,9 @@ class L2Normalization
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
|
||||
@ -465,7 +493,9 @@ class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class LocalResponseNormalization
|
||||
@ -490,7 +520,9 @@ class LocalResponseNormalization
|
||||
op->beta = options.beta();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
|
||||
@ -519,7 +551,9 @@ class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
|
||||
@ -541,7 +575,9 @@ class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
|
||||
ActivationFunction::Deserialize(options.fused_activation_function());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
|
||||
@ -558,7 +594,9 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Tile
|
||||
@ -574,7 +612,9 @@ class Tile
|
||||
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
|
||||
@ -591,7 +631,9 @@ class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Reshape
|
||||
@ -614,7 +656,9 @@ class Reshape
|
||||
options.new_shape()->end());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Softmax
|
||||
@ -633,7 +677,9 @@ class Softmax
|
||||
op->beta = options.beta();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SpaceToDepth
|
||||
@ -653,7 +699,9 @@ class SpaceToDepth
|
||||
op->block_size = options.block_size();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Transpose
|
||||
@ -670,7 +718,9 @@ class Transpose
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
|
||||
@ -713,8 +763,9 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
|
||||
}
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override {
|
||||
const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const auto& lstm_op =
|
||||
static_cast<const LstmCellOperator&>(*op_signature.op);
|
||||
switch (lstm_op.kernel_type) {
|
||||
case LstmCellOperator::KERNEL_FULL:
|
||||
return 1;
|
||||
@ -770,7 +821,9 @@ class UnidirectionalSequenceLstm
|
||||
::tflite::ActivationFunctionType_TANH);
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<bool> GetMutatingInputVariables(
|
||||
const Operator& op) const override {
|
||||
@ -796,7 +849,9 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
|
||||
op->keep_dims = options.keep_dims();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Sum
|
||||
@ -815,7 +870,9 @@ class Sum
|
||||
op->keep_dims = options.keep_dims();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ReduceMax
|
||||
@ -834,7 +891,9 @@ class ReduceMax
|
||||
op->keep_dims = options.keep_dims();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ReduceMin
|
||||
@ -853,7 +912,9 @@ class ReduceMin
|
||||
op->keep_dims = options.keep_dims();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ReduceProd
|
||||
@ -872,7 +933,9 @@ class ReduceProd
|
||||
op->keep_dims = options.keep_dims();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ReduceAny
|
||||
@ -891,7 +954,9 @@ class ReduceAny
|
||||
op->keep_dims = options.keep_dims();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ResizeBilinear
|
||||
@ -911,7 +976,9 @@ class ResizeBilinear
|
||||
op->align_corners = options.align_corners();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ResizeNearestNeighbor
|
||||
@ -932,7 +999,9 @@ class ResizeNearestNeighbor
|
||||
op->align_corners = options.align_corners();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Squeeze
|
||||
@ -955,7 +1024,9 @@ class Squeeze
|
||||
options.squeeze_dims()->end());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Split
|
||||
@ -975,7 +1046,9 @@ class Split
|
||||
op->num_split = options.num_splits();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SplitV
|
||||
@ -995,7 +1068,9 @@ class SplitV
|
||||
op->num_split = options.num_splits();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class StridedSlice
|
||||
@ -1021,7 +1096,9 @@ class StridedSlice
|
||||
op->shrink_axis_mask = options.shrink_axis_mask();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
|
||||
@ -1037,7 +1114,9 @@ class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
|
||||
@ -1056,7 +1135,9 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
|
||||
op->output_data_type = DataType::Deserialize(options.output_type());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
|
||||
@ -1075,7 +1156,9 @@ class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
|
||||
op->output_data_type = DataType::Deserialize(options.output_type());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeConv
|
||||
@ -1100,7 +1183,9 @@ class TransposeConv
|
||||
op->stride_height = options.stride_h();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SparseToDense
|
||||
@ -1121,7 +1206,9 @@ class SparseToDense
|
||||
op->validate_indices = options.validate_indices();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class ExpandDims
|
||||
@ -1139,7 +1226,9 @@ class ExpandDims
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
|
||||
@ -1159,7 +1248,9 @@ class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
|
||||
op->axis = options.axis();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Shape
|
||||
@ -1179,7 +1270,9 @@ class Shape
|
||||
op->output_data_type = DataType::Deserialize(options.out_type());
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
|
||||
@ -1196,7 +1289,9 @@ class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
|
||||
op->axis = options.axis();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class CTCBeamSearchDecoder
|
||||
@ -1217,7 +1312,9 @@ class CTCBeamSearchDecoder
|
||||
op->merge_repeated = m["merge_repeated"].AsBool();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
|
||||
@ -1235,7 +1332,9 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
|
||||
op->axis = options.axis();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class LeakyRelu
|
||||
@ -1253,7 +1352,9 @@ class LeakyRelu
|
||||
op->alpha = options.alpha();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SquaredDifference
|
||||
@ -1272,7 +1373,9 @@ class SquaredDifference
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class MirrorPad
|
||||
@ -1295,7 +1398,9 @@ class MirrorPad
|
||||
: MirrorPadMode::kSymmetric;
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
|
||||
@ -1472,8 +1577,8 @@ class TensorFlowUnsupported : public BaseOperator {
|
||||
node_def.SerializeToString(&op->tensorflow_node_def);
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override {
|
||||
// TODO(ycling): Deisng and implement a way to plumb the version of
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
// TODO(ycling): Design and implement a way to plumb the version of
|
||||
// custom ops.
|
||||
return 1;
|
||||
}
|
||||
@ -1497,11 +1602,13 @@ class Dequantize
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const Operator& op) const override {
|
||||
// TODO(suharshs): Dequantize now supports INT8 in addition to
|
||||
// QUANTIZED_UINT8. When TOCO can create models with INT8, we need
|
||||
// to find a way to see the type here and return version 2. Right now
|
||||
// version 2 will only be added by post training quantization tools.
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// Version 2 supports signed int8 input types.
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
@ -1534,6 +1641,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
ops.push_back(MakeUnique<DepthwiseConvolution>(
|
||||
::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
OperatorType::kDepthwiseConv));
|
||||
ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
|
||||
OperatorType::kDequantize));
|
||||
ops.push_back(
|
||||
MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
|
||||
OperatorType::kFullyConnected));
|
||||
@ -1645,8 +1754,6 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
// when custom ops are exported but SimpleOperator bypasses thoses. To
|
||||
// prevent user confusion we are settling on using SimpleOperator only for
|
||||
// builtins.
|
||||
ops.push_back(MakeUnique<SimpleOperator<DequantizeOperator>>(
|
||||
"DEQUANTIZE", OperatorType::kDequantize));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor));
|
||||
ops.push_back(
|
||||
|
||||
@ -87,15 +87,15 @@ class BaseOperator {
|
||||
const BuiltinOptions* builtin_options,
|
||||
const CustomOptions* custom_options) const = 0;
|
||||
|
||||
// Get the op version by op parameters.
|
||||
// The function need to be overridden to return the op version based on the
|
||||
// Get the op version using the OperatorSignature.
|
||||
// The function needs to be overridden to return the op version based on the
|
||||
// parameters. Note:
|
||||
// * The first version for each op should be 1 (to be consistent with the
|
||||
// default value in Flatbuffer. `return 1;` is okay for newly implemented
|
||||
// ops.
|
||||
// * When multiple versions are defined for an op, this function need to be
|
||||
// * When multiple versions are defined for an op, this function needs to be
|
||||
// overridden. (See example in `operator_test.cc`)
|
||||
virtual int GetVersion(const Operator& op) const = 0;
|
||||
virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
|
||||
|
||||
// Given a Toco `Operator`, return a list of booleans indicating the op
|
||||
// mutates which input variables.
|
||||
|
||||
@ -111,8 +111,6 @@ class OperatorTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(OperatorTest, SimpleOperators) {
|
||||
CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE",
|
||||
OperatorType::kDequantize);
|
||||
CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
|
||||
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
|
||||
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
|
||||
@ -469,6 +467,12 @@ TEST_F(OperatorTest, BuiltinArgMin) {
|
||||
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinDequantize) {
|
||||
DequantizeOperator op;
|
||||
auto output_toco_op = SerializeAndDeserialize(
|
||||
GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinTransposeConv) {
|
||||
TransposeConvOperator op;
|
||||
op.stride_width = 123;
|
||||
|
||||
@ -42,7 +42,9 @@ class SimpleOperator : public BaseOperator {
|
||||
return std::unique_ptr<Operator>(new T);
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user